Refactor Dataproc Trigger (#29364)
pankajastro committed Feb 20, 2023
1 parent ee0a56a commit 6ef5ba9
Showing 3 changed files with 74 additions and 125 deletions.
4 changes: 1 addition & 3 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -877,12 +877,10 @@ def execute(self, context: Context) -> None:
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
request_id=self.request_id,
retry=self.retry,
end_time=end_time,
metadata=self.metadata,
impersonation_chain=self.impersonation_chain,
polling_interval=self.polling_interval_seconds,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
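
The deferring operator resumes in execute_complete with the trigger's event payload. That method sits outside this hunk, so the following is only a hedged sketch of how the {"status", "message"} events yielded by DataprocDeleteClusterTrigger (below) might be consumed; the real operator's handling may differ:

from airflow.exceptions import AirflowException

def execute_complete(self, context, event=None):
    # The trigger yields {"status": "success" | "error", "message": ...}.
    if event is None:
        raise AirflowException("Trigger returned no event")
    if event.get("status") == "error":
        raise AirflowException(event["message"])
    self.log.info("Cluster %s deleted.", self.cluster_name)
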
191 changes: 71 additions & 120 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -20,7 +20,6 @@

import asyncio
import time
import warnings
from typing import Any, AsyncIterator, Sequence

from google.api_core.exceptions import NotFound
@@ -31,40 +30,58 @@
from airflow.triggers.base import BaseTrigger, TriggerEvent


class DataprocSubmitTrigger(BaseTrigger):
"""
Trigger that periodically polls information from Dataproc API to verify job status.
Implementation leverages asynchronous transport.
"""
class DataprocBaseTrigger(BaseTrigger):
"""Base class for Dataproc triggers"""

def __init__(
self,
job_id: str,
region: str,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
delegate_to: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: int = 30,
):
super().__init__()
self.region = region
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.job_id = job_id
self.project_id = project_id
self.region = region
self.polling_interval_seconds = polling_interval_seconds
if delegate_to:
warnings.warn(
"'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
)
self.delegate_to = delegate_to
self.hook = DataprocAsyncHook(
delegate_to=self.delegate_to,

def get_async_hook(self):
return DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
delegate_to=self.delegate_to,
)
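
The base class centralizes the shared connection arguments and builds the hook in get_async_hook() rather than in __init__ (where the old DataprocSubmitTrigger created it), so the client is only constructed once run() executes on the triggerer. A minimal sketch of the serialize round trip this supports, with hypothetical values:

trigger = DataprocSubmitTrigger(
    job_id="sample-job-id",  # hypothetical identifiers
    project_id="sample-project",
    region="us-central1",
    polling_interval_seconds=30,
)
classpath, kwargs = trigger.serialize()
# The triggerer process reimports `classpath` and rebuilds the trigger from
# `kwargs`, so every serialized key must be accepted by __init__ (base or subclass).
rehydrated = DataprocSubmitTrigger(**kwargs)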


class DataprocSubmitTrigger(DataprocBaseTrigger):
"""
DataprocSubmitTrigger runs on the trigger worker to poll the status of a submitted Dataproc job.
:param job_id: The ID of a Dataproc job.
:param project_id: Google Cloud Project where the job is running
:param region: The Cloud Dataproc region in which to handle the request.
:param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param polling_interval_seconds: polling period in seconds to check for the status
"""

def __init__(self, job_id: str, delegate_to: str | None = None, **kwargs):
self.job_id = job_id
self.delegate_to = delegate_to
super().__init__(delegate_to=self.delegate_to, **kwargs)

def serialize(self):
return (
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
@@ -81,7 +98,9 @@ def serialize(self):

async def run(self):
while True:
job = await self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
job = await self.get_async_hook().get_job(
project_id=self.project_id, region=self.region, job_id=self.job_id
)
state = job.status.state
self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
if state in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
@@ -93,28 +112,28 @@ async def run(self):
yield TriggerEvent({"job_id": self.job_id, "job_state": state})
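
run() is an async generator that yields a single TriggerEvent once the job reaches a terminal state (ERROR, DONE, or CANCELLED). A self-contained sketch of driving it outside Airflow, e.g. for local experimentation, assuming a trigger instance configured as above:

import asyncio

async def first_event(trigger):
    async for event in trigger.run():
        return event  # yielded only once the job is in a terminal state

event = asyncio.run(first_event(trigger))
print(event.payload)  # e.g. {"job_id": "...", "job_state": ...}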


class DataprocClusterTrigger(BaseTrigger):
class DataprocClusterTrigger(DataprocBaseTrigger):
"""
Trigger that periodically polls information from Dataproc API to verify status.
Implementation leverages asynchronous transport.
DataprocClusterTrigger runs on the trigger worker to poll the status of a Dataproc cluster.
:param cluster_name: The name of the cluster.
:param project_id: Google Cloud Project where the job is running
:param region: The Cloud Dataproc region in which to handle the request.
:param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param polling_interval_seconds: polling period in seconds to check for the status
"""

def __init__(
self,
cluster_name: str,
region: str,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: int = 10,
):
super().__init__()
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
def __init__(self, cluster_name: str, **kwargs):
super().__init__(**kwargs)
self.cluster_name = cluster_name
self.project_id = project_id
self.region = region
self.polling_interval_seconds = polling_interval_seconds

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -130,9 +149,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
hook = self._get_hook()
while True:
cluster = await hook.get_cluster(
cluster = await self.get_async_hook().get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)
state = cluster.status.state
@@ -146,14 +164,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})

def _get_hook(self) -> DataprocAsyncHook:
return DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)


class DataprocBatchTrigger(BaseTrigger):
class DataprocBatchTrigger(DataprocBaseTrigger):
"""
DataprocBatchTrigger runs on the trigger worker to poll the status of a Dataproc batch workload.
@@ -172,22 +184,9 @@ class DataprocBatchTrigger(BaseTrigger):
:param polling_interval_seconds: polling period in seconds to check for the status
"""

def __init__(
self,
batch_id: str,
region: str,
project_id: str | None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: float = 5.0,
):
super().__init__()
def __init__(self, batch_id: str, **kwargs):
super().__init__(**kwargs)
self.batch_id = batch_id
self.project_id = project_id
self.region = region
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_interval_seconds = polling_interval_seconds

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes DataprocBatchTrigger arguments and classpath."""
@@ -204,13 +203,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
)

async def run(self):
hook = DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

while True:
batch = await hook.get_batch(
batch = await self.get_async_hook().get_batch(
project_id=self.project_id, region=self.region, batch_id=self.batch_id
)
state = batch.state
@@ -223,9 +217,9 @@ async def run(self):
yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})
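
A hedged sketch of how a consumer might map the yielded batch_state back to a task outcome; Batch.State comes from google.cloud.dataproc_v1, and the real operator's handling may differ:

from airflow.exceptions import AirflowException
from google.cloud.dataproc_v1 import Batch

def handle_batch_event(event: dict) -> str:
    if event["batch_state"] == Batch.State.FAILED:
        raise AirflowException(f"Batch {event['batch_id']} failed")
    if event["batch_state"] == Batch.State.CANCELLED:
        raise AirflowException(f"Batch {event['batch_id']} was cancelled")
    return event["batch_id"]  # SUCCEEDED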


class DataprocDeleteClusterTrigger(BaseTrigger):
class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
"""
Asynchronously checks the status of a cluster.
DataprocDeleteClusterTrigger runs on the trigger worker to poll for completion of a cluster delete operation.
:param cluster_name: The name of the cluster
:param end_time: The time (seconds since epoch) until which the trigger will keep checking the cluster status
@@ -241,30 +235,20 @@ class DataprocDeleteClusterTrigger(BaseTrigger):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account.
:param polling_interval: Time in seconds to sleep between checks of cluster status
:param polling_interval_seconds: Time in seconds to sleep between checks of cluster status
"""

def __init__(
self,
cluster_name: str,
end_time: float,
project_id: str | None = None,
region: str | None = None,
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval: float = 5.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.cluster_name = cluster_name
self.end_time = end_time
self.project_id = project_id
self.region = region
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_interval = polling_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes DataprocDeleteClusterTrigger arguments and classpath."""
@@ -278,16 +262,15 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"metadata": self.metadata,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval": self.polling_interval,
"polling_interval_seconds": self.polling_interval_seconds,
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
"""Wait until cluster is deleted completely"""
hook = self._get_hook()
while self.end_time > time.time():
try:
cluster = await hook.get_cluster(
cluster = await self.get_async_hook().get_cluster(
region=self.region, # type: ignore[arg-type]
cluster_name=self.cluster_name,
project_id=self.project_id, # type: ignore[arg-type]
@@ -296,52 +279,26 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
self.log.info(
"Cluster status is %s. Sleeping for %s seconds.",
cluster.status.state,
self.polling_interval,
self.polling_interval_seconds,
)
await asyncio.sleep(self.polling_interval)
await asyncio.sleep(self.polling_interval_seconds)
except NotFound:
yield TriggerEvent({"status": "success", "message": ""})
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})
yield TriggerEvent({"status": "error", "message": "Timeout"})

def _get_hook(self) -> DataprocAsyncHook:
return DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
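
Note that the success condition in DataprocDeleteClusterTrigger.run is inverted relative to the other triggers: deletion is complete exactly when get_cluster raises NotFound, while finding the cluster means the delete is still in progress. The same pattern in isolation, as a sketch:

from google.api_core.exceptions import NotFound

async def cluster_is_gone(hook, project_id: str, region: str, cluster_name: str) -> bool:
    try:
        await hook.get_cluster(
            project_id=project_id, region=region, cluster_name=cluster_name
        )
        return False  # cluster still exists; deletion in progress
    except NotFound:
        return True  # cluster no longer found; deletion finished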


class DataprocWorkflowTrigger(BaseTrigger):
class DataprocWorkflowTrigger(DataprocBaseTrigger):
"""
Trigger that periodically polls information from Dataproc API to verify status.
Implementation leverages asynchronous transport.
"""

def __init__(
self,
template_name: str,
name: str,
region: str,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
delegate_to: str | None = None,
polling_interval_seconds: int = 10,
):
super().__init__()
self.gcp_conn_id = gcp_conn_id
def __init__(self, template_name: str, name: str, **kwargs: Any):
super().__init__(**kwargs)
self.template_name = template_name
self.name = name
self.impersonation_chain = impersonation_chain
self.project_id = project_id
self.region = region
self.polling_interval_seconds = polling_interval_seconds
self.delegate_to = delegate_to
if delegate_to:
warnings.warn(
"'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
)

def serialize(self):
return (
@@ -359,7 +316,7 @@ def serialize(self):
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
hook = self._get_hook()
hook = self.get_async_hook()
while True:
try:
operation = await hook.get_operation(region=self.region, operation_name=self.name)
@@ -394,9 +351,3 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
"message": str(e),
}
)

def _get_hook(self) -> DataprocAsyncHook: # type: ignore[override]
return DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
4 changes: 2 additions & 2 deletions tests/providers/google/cloud/triggers/test_dataproc.py
@@ -302,7 +302,7 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, w
}

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger._get_hook")
@async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
async def test_async_workflow_triggers_on_success_should_execute_successfully(
self, mock_hook, workflow_trigger, async_get_operation
):
@@ -322,7 +322,7 @@ async def test_async_workflow_triggers_on_success_should_execute_successfully(
assert expected_event == actual_event

@pytest.mark.asyncio
@async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger._get_hook")
@async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
async def test_async_workflow_triggers_on_error(self, mock_hook, workflow_trigger, async_get_operation):
mock_hook.return_value.get_operation.return_value = async_get_operation(
name=TEST_OPERATION_NAME, done=True, response={}, error=Status(message="test_error")
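
Pointing the patch at DataprocBaseTrigger.get_async_hook means a single target now covers every trigger subclass, since they all inherit the method. A hedged sketch of the same idea using stdlib unittest.mock (the tests above use their own async-mock helper):

from unittest import mock

with mock.patch(
    "airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook"
) as mock_get_hook:
    # Every trigger subclass instantiated here receives the same mocked hook.
    mock_get_hook.return_value.get_job = mock.AsyncMock()  # awaitable stub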
