Skip to content

Commit

Permalink
Run Dataflow for ML Engine summary in venv (#7809)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj committed Mar 24, 2020
1 parent 0c6af43 commit 1982c3f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
10 changes: 6 additions & 4 deletions airflow/providers/google/cloud/utils/mlengine_operator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,20 +225,22 @@ def validate_err_and_count(summary):
metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode()
evaluate_summary = DataflowCreatePythonJobOperator(
task_id=(task_prefix + "-summary"),
py_options=["-m"],
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
py_file=os.path.join(os.path.dirname(__file__), 'mlengine_prediction_summary.py'),
dataflow_default_options=dataflow_options,
options={
"prediction_path": prediction_path,
"metric_fn_encoded": metric_fn_encoded,
"metric_keys": ','.join(metric_keys)
},
py_interpreter=py_interpreter,
py_requirements=[
'apache-beam[gcp]>=2.14.0'
],
dag=dag)
evaluate_summary.set_upstream(evaluate_prediction)

def apply_validate_fn(*args, **kwargs):
prediction_path = kwargs["templates_dict"]["prediction_path"]
def apply_validate_fn(*args, templates_dict, **kwargs):
prediction_path = templates_dict["prediction_path"]
scheme, bucket, obj, _, _ = urlsplit(prediction_path)
if scheme != "gs" or not bucket or not obj:
raise ValueError("Wrong format prediction_path: {}".format(prediction_path))
Expand Down
32 changes: 17 additions & 15 deletions airflow/providers/google/cloud/utils/mlengine_prediction_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def metric_fn(inst):
import argparse
import base64
import json
import logging
import os

import apache_beam as beam
Expand Down Expand Up @@ -156,23 +157,24 @@ def run(argv=None):
raise ValueError("--metric_fn_encoded must be an encoded callable.")
metric_keys = known_args.metric_keys.split(",")

with beam.Pipeline(
options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
# This is apache-beam ptransform's convention
with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
# pylint: disable=no-value-for-parameter
_ = (pipe
| "ReadPredictionResult" >> beam.io.ReadFromText(
os.path.join(known_args.prediction_path,
"prediction.results-*-of-*"),
coder=JsonCoder())
| "Summary" >> MakeSummary(metric_fn, metric_keys)
| "Write" >> beam.io.WriteToText(
os.path.join(known_args.prediction_path,
"prediction.summary.json"),
shard_name_template='', # without trailing -NNNNN-of-NNNNN.
coder=JsonCoder()))
# pylint: enable=no-value-for-parameter
prediction_result_pattern = os.path.join(known_args.prediction_path, "prediction.results-*-of-*")
prediction_summary_path = os.path.join(known_args.prediction_path, "prediction.summary.json")
# This is apache-beam ptransform's convention
_ = (
pipe | "ReadPredictionResult" >> beam.io.ReadFromText(
prediction_result_pattern, coder=JsonCoder())
| "Summary" >> MakeSummary(metric_fn, metric_keys)
| "Write" >> beam.io.WriteToText(
prediction_summary_path,
shard_name_template='', # without trailing -NNNNN-of-NNNNN.
coder=JsonCoder())
)


if __name__ == "__main__":
    # A Dataflow job prints nothing to the screen by default, so raise the root logger
    # to INFO to make the job's progress trackable. This module is executed in its own
    # separate process, which is why adjusting the root logger here is safe.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    run()
8 changes: 5 additions & 3 deletions tests/providers/google/cloud/operators/test_mlengine_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import unittest
from unittest.mock import ANY, patch

import mock

from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.providers.google.cloud.utils import mlengine_operator_utils
Expand Down Expand Up @@ -110,10 +112,10 @@ def test_successful_run(self):
'metric_keys': 'err',
'metric_fn_encoded': self.metric_fn_encoded,
},
dataflow='airflow.providers.google.cloud.utils.mlengine_prediction_summary',
py_options=['-m'],
dataflow=mock.ANY,
py_options=[],
py_requirements=['apache-beam[gcp]>=2.14.0'],
py_interpreter='python3',
py_requirements=[],
py_system_site_packages=False,
on_new_job_id_callback=ANY
)
Expand Down

0 comments on commit 1982c3f

Please sign in to comment.
  翻译: