Skip to content

Commit

Permalink
Update example DAG for AI Platform operators (#9727)
Browse files Browse the repository at this point in the history
  • Loading branch information
vuppalli authored Jul 9, 2020
1 parent 13a827d commit b230566
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions airflow/providers/google/cloud/example_dags/example_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
from airflow import models
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.mlengine import (
MLEngineCreateVersionOperator, MLEngineDeleteModelOperator, MLEngineDeleteVersionOperator,
MLEngineListVersionsOperator, MLEngineManageModelOperator, MLEngineSetDefaultVersionOperator,
MLEngineStartBatchPredictionJobOperator, MLEngineStartTrainingJobOperator,
MLEngineCreateModelOperator, MLEngineCreateVersionOperator, MLEngineDeleteModelOperator,
MLEngineDeleteVersionOperator, MLEngineGetModelOperator, MLEngineListVersionsOperator,
MLEngineSetDefaultVersionOperator, MLEngineStartBatchPredictionJobOperator,
MLEngineStartTrainingJobOperator,
)
from airflow.providers.google.cloud.utils import mlengine_operator_utils
from airflow.utils.dates import days_ago
Expand Down Expand Up @@ -66,30 +67,26 @@
project_id=PROJECT_ID,
region="us-central1",
job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
runtime_version="1.14",
python_version="3.5",
runtime_version="1.15",
python_version="3.7",
job_dir=JOB_DIR,
package_uris=[TRAINER_URI],
training_python_module=TRAINER_PY_MODULE,
training_args=[],
)

create_model = MLEngineManageModelOperator(
create_model = MLEngineCreateModelOperator(
task_id="create-model",
project_id=PROJECT_ID,
operation='create',
model={
"name": MODEL_NAME,
},
)

get_model = MLEngineManageModelOperator(
get_model = MLEngineGetModelOperator(
task_id="get-model",
project_id=PROJECT_ID,
operation="get",
model={
"name": MODEL_NAME,
}
model_name=MODEL_NAME,
)

get_model_result = BashOperator(
Expand All @@ -105,10 +102,10 @@
"name": "v1",
"description": "First-version",
"deployment_uri": '{}/keras_export/'.format(JOB_DIR),
"runtime_version": "1.14",
"runtime_version": "1.15",
"machineType": "mls1-c1-m2",
"framework": "TENSORFLOW",
"pythonVersion": "3.5"
"pythonVersion": "3.7"
}
)

Expand All @@ -120,10 +117,10 @@
"name": "v2",
"description": "Second version",
"deployment_uri": SAVED_MODEL_PATH,
"runtime_version": "1.14",
"runtime_version": "1.15",
"machineType": "mls1-c1-m2",
"framework": "TENSORFLOW",
"pythonVersion": "3.5"
"pythonVersion": "3.7"
}
)

Expand All @@ -148,7 +145,7 @@
prediction = MLEngineStartBatchPredictionJobOperator(
task_id="prediction",
project_id=PROJECT_ID,
job_id="prediciton-{{ ts_nodash }}-{{ params.model_name }}",
job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
region="us-central1",
model_name=MODEL_NAME,
data_format="TEXT",
Expand Down Expand Up @@ -203,13 +200,13 @@ def validate_err_and_count(summary: Dict) -> Dict:
return summary

evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
task_prefix="evalueate-ops", # pylint: disable=too-many-arguments
task_prefix="evaluate-ops",
data_format="TEXT",
input_paths=[PREDICTION_INPUT],
prediction_path=PREDICTION_OUTPUT,
metric_fn_and_keys=get_metric_fn_and_keys(),
validate_fn=validate_err_and_count,
batch_prediction_job_id="evalueate-ops-{{ ts_nodash }}-{{ params.model_name }}",
batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
project_id=PROJECT_ID,
region="us-central1",
dataflow_options={
Expand Down

0 comments on commit b230566

Please sign in to comment.
  翻译: