Skip to content

Commit

Permalink
Fix doc errors in google provider files. (#11713)
Browse files Browse the repository at this point in the history
These files aren't _currently_ rendered/parsed by autoapi, but I was
exploring making them parseable and ran into some Sphinx formatting
errors.

The `Args:` change is because pydocstyle thinks that is a special word, but
we don't want it to be.
  • Loading branch information
ashb committed Oct 21, 2020
1 parent 53e6062 commit 2bfc53b
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 84 deletions.
41 changes: 23 additions & 18 deletions airflow/providers/google/cloud/utils/mlengine_operator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def create_evaluate_ops( # pylint: disable=too-many-arguments
Callers will provide two python callables, metric_fn and validate_fn, in
order to customize the evaluation behavior as they wish.
- metric_fn receives a dictionary per instance derived from json in the
batch prediction result. The keys might vary depending on the model.
It should return a tuple of metrics.
Expand All @@ -93,24 +94,26 @@ def create_evaluate_ops( # pylint: disable=too-many-arguments
Typical examples are like this:
def get_metric_fn_and_keys():
import math # imports should be outside of the metric_fn below.
def error_and_squared_error(inst):
label = float(inst['input_label'])
classes = float(inst['classes']) # 0 or 1
err = abs(classes-label)
squared_err = math.pow(classes-label, 2)
return (err, squared_err) # returns a tuple.
return error_and_squared_error, ['err', 'mse'] # key order must match.
def validate_err_and_count(summary):
if summary['err'] > 0.2:
raise ValueError('Too high err>0.2; summary=%s' % summary)
if summary['mse'] > 0.05:
raise ValueError('Too high mse>0.05; summary=%s' % summary)
if summary['count'] < 1000:
raise ValueError('Too few instances<1000; summary=%s' % summary)
return summary
.. code-block:: python
def get_metric_fn_and_keys():
import math # imports should be outside of the metric_fn below.
def error_and_squared_error(inst):
label = float(inst['input_label'])
classes = float(inst['classes']) # 0 or 1
err = abs(classes-label)
squared_err = math.pow(classes-label, 2)
return (err, squared_err) # returns a tuple.
return error_and_squared_error, ['err', 'mse'] # key order must match.
def validate_err_and_count(summary):
if summary['err'] > 0.2:
raise ValueError('Too high err>0.2; summary=%s' % summary)
if summary['mse'] > 0.05:
raise ValueError('Too high mse>0.05; summary=%s' % summary)
if summary['count'] < 1000:
raise ValueError('Too few instances<1000; summary=%s' % summary)
return summary
For the details on the other BatchPrediction-related arguments (project_id,
job_id, region, data_format, input_paths, prediction_path, model_uri),
Expand All @@ -131,8 +134,10 @@ def validate_err_and_count(summary):
:type prediction_path: str
:param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
- metric_fn is a function that accepts a dictionary (for an instance),
and returns a tuple of metric(s) that it calculates.
- metric_keys is a list of strings to denote the key of each metric.
:type metric_fn_and_keys: tuple of a function and a list[str]
Expand Down
142 changes: 80 additions & 62 deletions airflow/providers/google/cloud/utils/mlengine_prediction_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,71 +16,89 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
A template called by DataFlowPythonOperator to summarize BatchPrediction.
"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
It accepts a user function to calculate the metric(s) per instance in
the prediction results, then aggregates to output as a summary.
Args:
--prediction_path:
The GCS folder that contains BatchPrediction results, containing
prediction.results-NNNNN-of-NNNNN files in the json format.
Output will be also stored in this folder, as 'prediction.summary.json'.
--metric_fn_encoded:
An encoded function that calculates and returns a tuple of metric(s)
for a given instance (as a dictionary). It should be encoded
via base64.b64encode(dill.dumps(fn, recurse=True)).
--metric_keys:
A comma-separated key(s) of the aggregated metric(s) in the summary
output. The order and the size of the keys must match to the output
of metric_fn.
The summary will have an additional key, 'count', to represent the
total number of instances, so the keys shouldn't include 'count'.
# Usage example:
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
def get_metric_fn():
import math # all imports must be outside of the function to be passed.
def metric_fn(inst):
label = float(inst["input_label"])
classes = float(inst["classes"])
prediction = float(inst["scores"][1])
log_loss = math.log(1 + math.exp(
-(label * 2 - 1) * math.log(prediction / (1 - prediction))))
squared_err = (classes-label)**2
return (log_loss, squared_err)
return metric_fn
metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
DataflowCreatePythonJobOperator(
task_id="summary-prediction",
py_options=["-m"],
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
options={
"prediction_path": prediction_path,
"metric_fn_encoded": metric_fn_encoded,
"metric_keys": "log_loss,mse"
},
dataflow_default_options={
"project": "xxx", "region": "us-east1",
"staging_location": "gs://yy", "temp_location": "gs://zz",
})
>> dag
# When the input file is like the following:
{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
# The output file will be:
{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
# To test outside of the dag:
subprocess.check_call(["python",
"-m",
"airflow.providers.google.cloud.utils.mlengine_prediction_summary",
"--prediction_path=gs://...",
"--metric_fn_encoded=" + metric_fn_encoded,
"--metric_keys=log_loss,mse",
"--runner=DataflowRunner",
"--staging_location=gs://...",
"--temp_location=gs://...",
])
It accepts the following arguments:
- ``--prediction_path``:
The GCS folder that contains BatchPrediction results, containing
prediction.results-NNNNN-of-NNNNN files in the json format.
Output will also be stored in this folder, as 'prediction.summary.json'.
- ``--metric_fn_encoded``:
An encoded function that calculates and returns a tuple of metric(s)
for a given instance (as a dictionary). It should be encoded
via base64.b64encode(dill.dumps(fn, recurse=True)).
- ``--metric_keys``:
A comma-separated key(s) of the aggregated metric(s) in the summary
output. The order and the size of the keys must match the output
of metric_fn.
The summary will have an additional key, 'count', to represent the
total number of instances, so the keys shouldn't include 'count'.
Usage example:
.. code-block:: python
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
def get_metric_fn():
import math # all imports must be outside of the function to be passed.
def metric_fn(inst):
label = float(inst["input_label"])
classes = float(inst["classes"])
prediction = float(inst["scores"][1])
log_loss = math.log(1 + math.exp(
-(label * 2 - 1) * math.log(prediction / (1 - prediction))))
squared_err = (classes-label)**2
return (log_loss, squared_err)
return metric_fn
metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
DataflowCreatePythonJobOperator(
task_id="summary-prediction",
py_options=["-m"],
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
options={
"prediction_path": prediction_path,
"metric_fn_encoded": metric_fn_encoded,
"metric_keys": "log_loss,mse"
},
dataflow_default_options={
"project": "xxx", "region": "us-east1",
"staging_location": "gs://yy", "temp_location": "gs://zz",
}
) >> dag
When the input file is like the following::
{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
The output file will be::
{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
To test outside of the dag:
.. code-block:: python
subprocess.check_call(["python",
"-m",
"airflow.providers.google.cloud.utils.mlengine_prediction_summary",
"--prediction_path=gs://...",
"--metric_fn_encoded=" + metric_fn_encoded,
"--metric_keys=log_loss,mse",
"--runner=DataflowRunner",
"--staging_location=gs://...",
"--temp_location=gs://...",
])
"""

import argparse
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/google/common/utils/id_token_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def _load_credentials_from_file(
:param filename: The full path to the credentials file.
:type filename: str
:return Loaded credentials
:rtype google.auth.credentials.Credentials
:return: Loaded credentials
:rtype: google.auth.credentials.Credentials
:raise google.auth.exceptions.DefaultCredentialsError: if the file is in the wrong format or is missing.
"""
if not os.path.exists(filename):
Expand Down Expand Up @@ -184,8 +184,8 @@ def get_default_id_token_credentials(
is running on Compute Engine. If not specified, then it will use the standard library http client
to make requests.
:type request: google.auth.transport.Request
:return the current environment's credentials.
:rtype google.auth.credentials.Credentials
:return: the current environment's credentials.
:rtype: google.auth.credentials.Credentials
:raises ~google.auth.exceptions.DefaultCredentialsError:
If no credentials were found, or if the credentials found were invalid.
"""
Expand Down

0 comments on commit 2bfc53b

Please sign in to comment.