Override project in DataprocSubmitJobOperator (#14981)
SamWheating authored Mar 28, 2021
1 parent ec962b0 commit 099c490
Showing 2 changed files with 37 additions and 2 deletions.
airflow/providers/google/cloud/operators/dataproc.py (5 additions, 2 deletions)
@@ -858,6 +858,9 @@ class DataprocJobBaseOperator(BaseOperator):
     :type job_name: str
     :param cluster_name: The name of the DataProc cluster.
     :type cluster_name: str
+    :param project_id: The ID of the Google Cloud project the cluster belongs to;
+        if not specified, the project will be inferred from the provided GCP connection.
+    :type project_id: str
     :param dataproc_properties: Map for the Hive properties. Ideal to put in
         default arguments (templated)
     :type dataproc_properties: dict
@@ -912,6 +915,7 @@ def __init__(
         *,
         job_name: str = '{{task.task_id}}_{{ds_nodash}}',
         cluster_name: str = "cluster-1",
+        project_id: Optional[str] = None,
         dataproc_properties: Optional[Dict] = None,
         dataproc_jars: Optional[List[str]] = None,
         gcp_conn_id: str = 'google_cloud_default',
@@ -943,9 +947,8 @@ def __init__(

         self.job_error_states = job_error_states if job_error_states is not None else {'ERROR'}
         self.impersonation_chain = impersonation_chain
-
         self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
-        self.project_id = self.hook.project_id
+        self.project_id = self.hook.project_id if project_id is None else project_id
         self.job_template = None
         self.job = None
         self.dataproc_job_id = None
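With this change, the connection-derived project can be overridden per task. A minimal usage sketch follows; the DAG name, region, cluster name, and query are illustrative assumptions, not taken from this commit:

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitSparkSqlJobOperator
from airflow.utils.dates import days_ago

with DAG("dataproc_project_override_example", start_date=days_ago(1), schedule_interval=None) as dag:
    submit_sql = DataprocSubmitSparkSqlJobOperator(
        task_id="submit_sql",
        region="us-central1",
        cluster_name="cluster-1",
        query="SHOW DATABASES;",
        # New in this commit: submit the job to a project other than the
        # one inferred from the google_cloud_default connection.
        project_id="other-project",
    )

If project_id is omitted, behavior is unchanged: the operator keeps using self.hook.project_id, the project resolved from the GCP connection.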
tests/providers/google/cloud/operators/test_dataproc.py (32 additions, 0 deletions)
@@ -781,6 +781,12 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
}
other_project_job = {
"reference": {"project_id": "other-project", "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
}

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
@@ -813,6 +819,32 @@ def test_execute(self, mock_hook, mock_uuid):
             job_id=self.job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
         )
 
+    @mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_override_project_id(self, mock_hook, mock_uuid):
+        mock_uuid.return_value = self.job_id
+        mock_hook.return_value.project_id = GCP_PROJECT
+        mock_hook.return_value.wait_for_job.return_value = None
+        mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id
+
+        op = DataprocSubmitSparkSqlJobOperator(
+            project_id="other-project",
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            gcp_conn_id=GCP_CONN_ID,
+            query=self.query,
+            variables=self.variables,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={})
+        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+        mock_hook.return_value.submit_job.assert_called_once_with(
+            project_id="other-project", job=self.other_project_job, location=GCP_LOCATION
+        )
+        mock_hook.return_value.wait_for_job.assert_called_once_with(
+            job_id=self.job_id, location=GCP_LOCATION, project_id="other-project"
+        )
+
     @mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_builder(self, mock_hook, mock_uuid):
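A note on the mocking pattern the new test relies on: patching DataprocHook at the operator module's import path lets the test dictate the connection-inferred project, so the explicit override can be asserted against a known baseline. A hedged sketch; the DATAPROC_PATH value shown is an assumption mirroring how the constant is used above:

from unittest import mock

# Assumed patch target; the test module builds its targets the same way.
DATAPROC_PATH = "airflow.providers.google.cloud.operators.dataproc.{}"

with mock.patch(DATAPROC_PATH.format("DataprocHook")) as mock_hook:
    # Stands in for the project inferred from the GCP connection; the
    # operator only falls back to it when no explicit project_id is given.
    mock_hook.return_value.project_id = "project-from-connection"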
