Skip to content

Commit

Permalink
Add service_account to Google ML Engine operator (#11619)
Browse files Browse the repository at this point in the history
  • Loading branch information
DBCerigo authored Oct 19, 2020
1 parent ae06ad0 commit 2d854c3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
13 changes: 13 additions & 0 deletions airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,13 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
:param job_dir: A Google Cloud Storage path in which to store training
outputs and other data needed for training. (templated)
:type job_dir: str
:param service_account: Optional service account to use when running the training application.
(templated)
The specified service account must have the `iam.serviceAccounts.actAs` role. The
Google-managed Cloud ML Engine service account must have the `iam.serviceAccountAdmin` role
for the specified service account.
If set to None or missing, the Google-managed Cloud ML Engine service account will be used.
:type service_account: str
:param project_id: The Google Cloud project name within which MLEngine training job should run.
If set to None or missing, the default project_id from the Google Cloud connection is used.
(templated)
Expand Down Expand Up @@ -1156,6 +1163,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
'_runtime_version',
'_python_version',
'_job_dir',
'_service_account',
'_impersonation_chain',
]

Expand All @@ -1176,6 +1184,7 @@ def __init__(
runtime_version: Optional[str] = None,
python_version: Optional[str] = None,
job_dir: Optional[str] = None,
service_account: Optional[str] = None,
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
Expand All @@ -1197,6 +1206,7 @@ def __init__(
self._runtime_version = runtime_version
self._python_version = python_version
self._job_dir = job_dir
self._service_account = service_account
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
self._mode = mode
Expand Down Expand Up @@ -1244,6 +1254,9 @@ def execute(self, context):
if self._job_dir:
training_request['trainingInput']['jobDir'] = self._job_dir

if self._service_account:
training_request['trainingInput']['serviceAccount'] = self._service_account

if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
training_request['trainingInput']['masterType'] = self._master_type

Expand Down
2 changes: 2 additions & 0 deletions tests/providers/google/cloud/operators/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
training_input['trainingInput']['runtimeVersion'] = '1.6'
training_input['trainingInput']['pythonVersion'] = '3.5'
training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training'
training_input['trainingInput']['serviceAccount'] = 'test@serviceaccount.com'

success_response = self.TRAINING_INPUT.copy()
success_response['state'] = 'SUCCEEDED'
Expand All @@ -423,6 +424,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
runtime_version='1.6',
python_version='3.5',
job_dir='gs://some-bucket/jobs/test_training',
service_account='test@serviceaccount.com',
**self.TRAINING_DEFAULT_ARGS,
)
training_op.execute(MagicMock())
Expand Down

0 comments on commit 2d854c3

Please sign in to comment.
  翻译: