Add deferrable mode to GCSObjectUpdateSensor (#30579)
phanikumv authored Apr 22, 2023
1 parent 6a89ba3 commit 9e49d91
Showing 7 changed files with 416 additions and 5 deletions.
39 changes: 37 additions & 2 deletions airflow/providers/google/cloud/sensors/gcs.py
@@ -29,7 +29,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
-from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger, GCSCheckBlobUpdateTimeTrigger
from airflow.sensors.base import BaseSensorOperator, poke_mode_only

if TYPE_CHECKING:
@@ -184,6 +184,7 @@ class GCSObjectUpdateSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run sensor in deferrable mode
"""

template_fields: Sequence[str] = (
@@ -200,6 +201,7 @@ def __init__(
ts_func: Callable = ts_function,
google_cloud_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
**kwargs,
) -> None:

@@ -209,6 +211,7 @@ def __init__(
self.ts_func = ts_func
self.google_cloud_conn_id = google_cloud_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
self.log.info("Sensor checks existence of : %s, %s", self.bucket, self.object)
@@ -218,6 +221,38 @@ def poke(self, context: Context) -> bool:
)
return hook.is_updated_after(self.bucket, self.object, self.ts_func(context))

def execute(self, context: Context) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
if self.deferrable is False:
super().execute(context)
else:
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=GCSCheckBlobUpdateTimeTrigger(
bucket=self.bucket,
object_name=self.object,
target_date=self.ts_func(context),
poke_interval=self.poke_interval,
google_cloud_conn_id=self.google_cloud_conn_id,
hook_params={
"delegate_to": self.delegate_to,
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
"""Callback for when the trigger fires."""
if event:
if event["status"] == "success":
self.log.info(
"Checking last updated time for object %s in bucket : %s", self.object, self.bucket
)
return event["message"]
raise AirflowException(event["message"])
raise AirflowException("No event received in trigger callback")


class GCSObjectsWithPrefixExistenceSensor(BaseSensorOperator):
"""
@@ -265,7 +300,7 @@ def __init__(
self.impersonation_chain = impersonation_chain

def poke(self, context: Context) -> bool:
self.log.info("Sensor checks existence of objects: %s, %s", self.bucket, self.prefix)
self.log.info("Checking for existence of object: %s, %s", self.bucket, self.prefix)
hook = GCSHook(
gcp_conn_id=self.google_cloud_conn_id,
impersonation_chain=self.impersonation_chain,
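For orientation, here is a rough sketch of the deferral round-trip that the new ``execute`` method relies on. This is a simplified illustration, not Airflow's actual scheduler/triggerer machinery, and ``run_deferrable_task`` is a hypothetical helper:

import asyncio

from airflow.exceptions import TaskDeferred


def run_deferrable_task(operator, context):
    # Simplified: execute() raises TaskDeferred when deferrable=True; the
    # triggerer then runs the trigger until it yields its first TriggerEvent,
    # and the task resumes at the method named in method_name.
    try:
        return operator.execute(context)
    except TaskDeferred as deferral:

        async def first_event():
            async for event in deferral.trigger.run():
                return event

        event = asyncio.run(first_event())
        return getattr(operator, deferral.method_name)(context, event=event.payload)
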
105 changes: 104 additions & 1 deletion airflow/providers/google/cloud/triggers/gcs.py
@@ -18,12 +18,14 @@
from __future__ import annotations

import asyncio
from datetime import datetime
from typing import Any, AsyncIterator

from aiohttp import ClientSession

from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone


class GCSBlobTrigger(BaseTrigger):
@@ -65,7 +67,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
"""Simple loop until the relevant file/folder is found."""
"""loop until the relevant file/folder is found."""
try:
hook = self._get_async_hook()
while True:
@@ -97,3 +99,104 @@ async def _object_exists(self, hook: GCSAsyncHook, bucket_name: str, object_name
if object_response:
return "success"
return "pending"


class GCSCheckBlobUpdateTimeTrigger(BaseTrigger):
"""
A trigger that makes an async call to GCS to check whether the object is updated in a bucket.
:param bucket: google cloud storage bucket name cloud storage where the objects are residing.
:param object_name: the file or folder present in the bucket
:param target_date: context datetime to compare with blob object updated time
:param poke_interval: polling period in seconds to check for file/folder
:param google_cloud_conn_id: reference to the Google Connection
:param hook_params: dict object that has delegate_to and impersonation_chain
"""

def __init__(
self,
bucket: str,
object_name: str,
target_date: datetime,
poke_interval: float,
google_cloud_conn_id: str,
hook_params: dict[str, Any],
):
super().__init__()
self.bucket = bucket
self.object_name = object_name
self.target_date = target_date
self.poke_interval = poke_interval
self.google_cloud_conn_id: str = google_cloud_conn_id
self.hook_params = hook_params

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes GCSCheckBlobUpdateTimeTrigger arguments and classpath."""
return (
"airflow.providers.google.cloud.triggers.gcs.GCSCheckBlobUpdateTimeTrigger",
{
"bucket": self.bucket,
"object_name": self.object_name,
"target_date": self.target_date,
"poke_interval": self.poke_interval,
"google_cloud_conn_id": self.google_cloud_conn_id,
"hook_params": self.hook_params,
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
"""Loop until the object updated time is greater than target datetime"""
try:
hook = self._get_async_hook()
while True:
status, res = await self._is_blob_updated_after(
hook=hook,
bucket_name=self.bucket,
object_name=self.object_name,
target_date=self.target_date,
)
if status:
yield TriggerEvent(res)
await asyncio.sleep(self.poke_interval)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})

def _get_async_hook(self) -> GCSAsyncHook:
return GCSAsyncHook(gcp_conn_id=self.google_cloud_conn_id, **self.hook_params)

async def _is_blob_updated_after(
self, hook: GCSAsyncHook, bucket_name: str, object_name: str, target_date: datetime
) -> tuple[bool, dict[str, Any]]:
"""
Checks if the object in the bucket is updated.
:param hook: GCSAsyncHook Hook class
:param bucket_name: The Google Cloud Storage bucket where the object is.
:param object_name: The name of the blob_name to check in the Google cloud
storage bucket.
:param target_date: context datetime to compare with blob object updated time
"""
async with ClientSession() as session:
client = await hook.get_storage_client(session)
bucket = client.get_bucket(bucket_name)
blob = await bucket.get_blob(blob_name=object_name)
if blob is None:
res = {
"message": f"Object ({object_name}) not found in Bucket ({bucket_name})",
"status": "error",
}
return True, res

blob_updated_date = blob.updated  # type: ignore[attr-defined]
# The blob's updated time comes back as an RFC 3339 string, so parse it
# into a timezone-aware datetime before comparing.
blob_updated_time = datetime.strptime(blob_updated_date, "%Y-%m-%dT%H:%M:%S.%fZ").replace(
tzinfo=timezone.utc
)

if blob_updated_time is not None:
if not target_date.tzinfo:
target_date = target_date.replace(tzinfo=timezone.utc)
self.log.info("Verify object date: %s > %s", blob_updated_time, target_date)
if blob_updated_time > target_date:
return True, {"status": "success", "message": "success"}
return False, {"status": "pending", "message": "pending"}
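
As a standalone illustration of the timestamp comparison the trigger performs (the values below are made up, and the stdlib ``timezone.utc`` stands in for ``airflow.utils.timezone.utc``):

from datetime import datetime, timezone

# GCS reports a blob's updated time as an RFC 3339 string:
blob_updated_date = "2023-04-22T10:15:30.123456Z"
blob_updated_time = datetime.strptime(blob_updated_date, "%Y-%m-%dT%H:%M:%S.%fZ").replace(
    tzinfo=timezone.utc
)

# A naive target datetime is assumed to be UTC before comparing.
target_date = datetime(2023, 4, 22, 9, 0)
if not target_date.tzinfo:
    target_date = target_date.replace(tzinfo=timezone.utc)

print(blob_updated_time > target_date)  # True -> the trigger yields a success event
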
10 changes: 10 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/gcs.rst
@@ -225,6 +225,16 @@ Use the :class:`~airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor
:start-after: [START howto_sensor_object_update_exists_task]
:end-before: [END howto_sensor_object_update_exists_task]

You can set the ``deferrable`` param to ``True`` if you want this sensor to run asynchronously, leading to more
efficient utilization of resources in your Airflow deployment. However, the triggerer component needs to be
enabled for this functionality to work.

.. exampleinclude:: /../../tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_object_update_exists_task_async]
:end-before: [END howto_sensor_object_update_exists_task_async]

More information
""""""""""""""""

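For reference, a minimal DAG using the new flag might look like the following sketch (the DAG id, bucket, and object names are hypothetical, not taken from the example file above):

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.gcs import GCSObjectUpdateSensor

with DAG(
    dag_id="example_gcs_object_update_deferrable",  # hypothetical
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    # With deferrable=True the worker slot is released while the
    # triggerer polls GCS asynchronously.
    GCSObjectUpdateSensor(
        task_id="wait_for_object_update",
        bucket="my-bucket",  # hypothetical
        object="data/report.csv",  # hypothetical
        deferrable=True,
    )
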
44 changes: 43 additions & 1 deletion tests/providers/google/cloud/sensors/test_gcs.py
@@ -34,7 +34,8 @@
GCSUploadSessionCompleteSensor,
ts_function,
)
-from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger, GCSCheckBlobUpdateTimeTrigger
+from tests.providers.google.cloud.utils.airflow_util import create_context

TEST_BUCKET = "TEST_BUCKET"

@@ -225,6 +226,47 @@ def test_should_pass_argument_to_hook(self, mock_hook):
assert result is True


class TestGCSObjectUpdateSensorAsync:
OPERATOR = GCSObjectUpdateSensor(
task_id="gcs-obj-update",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
deferrable=True,
)

def test_gcs_object_update_sensor_async(self, context):
"""
Asserts that the task is deferred and a GCSCheckBlobUpdateTimeTrigger
is fired when GCSObjectUpdateSensor is executed in deferrable mode.
"""

with pytest.raises(TaskDeferred) as exc:
self.OPERATOR.execute(create_context(self.OPERATOR))
assert isinstance(
exc.value.trigger, GCSCheckBlobUpdateTimeTrigger
), "Trigger is not a GCSCheckBlobUpdateTimeTrigger"

def test_gcs_object_update_sensor_async_execute_failure(self, context):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.OPERATOR.execute_complete(
context=context, event={"status": "error", "message": "test failure message"}
)

def test_gcs_object_update_sensor_async_execute_complete(self, context):
"""Asserts that logging occurs as expected"""

with mock.patch.object(self.OPERATOR.log, "info") as mock_log_info:
self.OPERATOR.execute_complete(
context=context, event={"status": "success", "message": "Job completed"}
)
mock_log_info.assert_called_with(
"Checking last updated time for object %s in bucket : %s", TEST_OBJECT, TEST_BUCKET
)


class TestGoogleCloudStoragePrefixSensor:
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
def test_should_pass_arguments_to_hook(self, mock_hook):
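The new tests take a ``context`` argument, presumably supplied by a fixture defined elsewhere in the test suite; a minimal stand-in along these lines (hypothetical, not the real fixture) would satisfy them:

from unittest import mock

import pytest


@pytest.fixture
def context():
    # Bare-bones stand-in for the Airflow task context these tests receive.
    return {"ti": mock.MagicMock()}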