Deferrable mode for CreateBatchPredictionJobOperator (#37818)
moiseenkov committed Mar 6, 2024
1 parent 46ee631 commit ec220a8
Showing 13 changed files with 989 additions and 122 deletions.
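Before the per-file diffs, a minimal sketch of how the new deferrable mode might be used from a DAG. This is illustrative only: the project, region, display name, and model resource name are placeholders, and the input/output source arguments a real job needs are omitted for brevity rather than taken from this commit.

# Illustrative DAG snippet only -- resource names are placeholders and the
# data source/destination arguments required for a real job are omitted.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job import (
    CreateBatchPredictionJobOperator,
)

with DAG(
    dag_id="example_batch_prediction_deferrable",
    start_date=datetime(2024, 3, 1),
    schedule=None,
):
    create_batch_prediction_job = CreateBatchPredictionJobOperator(
        task_id="create_batch_prediction_job",
        project_id="my-project",                     # placeholder
        region="us-central1",                        # placeholder
        job_display_name="my-batch-prediction-job",  # placeholder
        model_name="projects/my-project/locations/us-central1/models/123",  # placeholder
        deferrable=True,    # release the worker slot and wait on the triggerer
        poll_interval=60,   # seconds between status checks while deferred
    )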
257 changes: 254 additions & 3 deletions airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py

Large diffs are not rendered by default.

@@ -25,7 +25,6 @@
from __future__ import annotations

import asyncio
from functools import lru_cache
from typing import TYPE_CHECKING, Sequence

from google.api_core.client_options import ClientOptions
@@ -35,7 +34,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook

if TYPE_CHECKING:
from google.api_core.operation import Operation
@@ -431,9 +430,11 @@ def delete_hyperparameter_tuning_job(
return result


class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
class HyperparameterTuningJobAsyncHook(GoogleBaseAsyncHook):
"""Async hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""

sync_hook_class = HyperparameterTuningJobHook

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
@@ -446,16 +447,15 @@ def __init__(
**kwargs,
)

@lru_cache
def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient:
async def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient:
"""
Retrieve Vertex AI async client.
:return: Google Cloud Vertex AI client object.
"""
endpoint = f"{region}-aiplatform.googleapis.com:443" if region and region != "global" else None
return JobServiceAsyncClient(
credentials=self.get_credentials(),
credentials=(await self.get_sync_hook()).get_credentials(),
client_info=CLIENT_INFO,
client_options=ClientOptions(api_endpoint=endpoint),
)
@@ -479,7 +479,7 @@ async def get_hyperparameter_tuning_job(
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client: JobServiceAsyncClient = self.get_job_service_client(region=location)
client: JobServiceAsyncClient = await self.get_job_service_client(region=location)
job_name = client.hyperparameter_tuning_job_path(project_id, location, job_id)

result = await client.get_hyperparameter_tuning_job(
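The hunks above switch the async hook from GoogleBaseHook to GoogleBaseAsyncHook: the subclass names a sync_hook_class and awaits get_sync_hook() whenever it needs credentials, which is also why the @lru_cache on the client getter is gone. A rough sketch of that pattern follows, with a made-up hook that is not part of this commit.

# Hypothetical hook pair showing the GoogleBaseAsyncHook pattern used above;
# MyJobHook, MyJobAsyncHook, and get_client are invented names.
from airflow.providers.google.common.hooks.base_google import (
    GoogleBaseAsyncHook,
    GoogleBaseHook,
)


class MyJobHook(GoogleBaseHook):
    """Synchronous hook that owns connection and credential handling."""


class MyJobAsyncHook(GoogleBaseAsyncHook):
    """Async counterpart that delegates auth to the sync hook."""

    sync_hook_class = MyJobHook  # instantiated lazily by GoogleBaseAsyncHook

    async def get_client(self):
        sync_hook = await self.get_sync_hook()      # built in a thread executor
        credentials = sync_hook.get_credentials()   # reuse the sync hook's auth
        # A real hook would hand these credentials to an async client such as
        # JobServiceAsyncClient, as the diff above does.
        return credentials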
airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
@@ -20,22 +20,27 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform_v1.types import BatchPredictionJob

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobHook
from airflow.providers.google.cloud.links.vertex_ai import (
VertexAIBatchPredictionJobLink,
VertexAIBatchPredictionJobListLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.vertex_ai import CreateBatchPredictionJobTrigger

if TYPE_CHECKING:
from google.api_core.retry import Retry
from google.cloud.aiplatform import Model, explain
from google.cloud.aiplatform import BatchPredictionJob as BatchPredictionJobObject, Model, explain

from airflow.utils.context import Context

@@ -131,7 +136,7 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
If this is set, then all resources created by the BatchPredictionJob will be encrypted with the
provided encryption key.
Overrides encryption_spec_key_name set in aiplatform.init.
:param sync: Whether to execute this method synchronously. If False, this method will be executed in
:param sync: (Deprecated) Whether to execute this method synchronously. If False, this method will be executed in
concurrent Future and any downstream object will be immediately returned and synced when the
Future has completed.
:param create_request_timeout: Optional. The timeout for the create request in seconds.
@@ -154,6 +159,8 @@ class CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Optional. Run operator in the deferrable mode.
:param poll_interval: Interval size which defines how often job status is checked in deferrable mode.
"""

template_fields = ("region", "project_id", "model_name", "impersonation_chain")
@@ -188,6 +195,8 @@ def __init__(
batch_size: int | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 10,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -217,15 +226,24 @@ def __init__(
self.batch_size = batch_size
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook: BatchPredictionJobHook | None = None
self.deferrable = deferrable
self.poll_interval = poll_interval

def execute(self, context: Context):
self.log.info("Creating Batch prediction job")
self.hook = BatchPredictionJobHook(
@cached_property
def hook(self) -> BatchPredictionJobHook:
return BatchPredictionJobHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
result = self.hook.create_batch_prediction_job(

def execute(self, context: Context):
warnings.warn(
"The 'sync' parameter is deprecated and will be removed after 28.08.2024.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.log.info("Creating Batch prediction job")
batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
region=self.region,
project_id=self.project_id,
job_display_name=self.job_display_name,
@@ -247,26 +265,62 @@ def execute(self, context: Context):
explanation_parameters=self.explanation_parameters,
labels=self.labels,
encryption_spec_key_name=self.encryption_spec_key_name,
sync=self.sync,
create_request_timeout=self.create_request_timeout,
batch_size=self.batch_size,
)

batch_prediction_job = result.to_dict()
batch_prediction_job_id = self.hook.extract_batch_prediction_job_id(batch_prediction_job)
batch_prediction_job.wait_for_resource_creation()
batch_prediction_job_id = batch_prediction_job.name
self.log.info("Batch prediction job was created. Job id: %s", batch_prediction_job_id)

self.xcom_push(context, key="batch_prediction_job_id", value=batch_prediction_job_id)
VertexAIBatchPredictionJobLink.persist(
context=context, task_instance=self, batch_prediction_job_id=batch_prediction_job_id
)
return batch_prediction_job

if self.deferrable:
self.defer(
trigger=CreateBatchPredictionJobTrigger(
conn_id=self.gcp_conn_id,
project_id=self.project_id,
location=self.region,
job_id=batch_prediction_job.name,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)

batch_prediction_job.wait_for_completion()
self.log.info("Batch prediction job was completed. Job id: %s", batch_prediction_job_id)
return batch_prediction_job.to_dict()

def on_kill(self) -> None:
"""Act as a callback called when the operator is killed; cancel any running job."""
if self.hook:
self.hook.cancel_batch_prediction_job()

def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]:
if event and event["status"] == "error":
raise AirflowException(event["message"])
job: dict[str, Any] = event["job"]
self.log.info("Batch prediction job %s created and completed successfully.", job["name"])
job_id = self.hook.extract_batch_prediction_job_id(job)
self.xcom_push(
context,
key="batch_prediction_job_id",
value=job_id,
)
self.xcom_push(
context,
key="training_conf",
value={
"training_conf_id": job_id,
"region": self.region,
"project_id": self.project_id,
},
)
return event["job"]


class DeleteBatchPredictionJobOperator(GoogleCloudBaseOperator):
"""
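As a reading aid, the XCom keys pushed by CreateBatchPredictionJobOperator above (batch_prediction_job_id in both paths, training_conf only from execute_complete in deferrable mode) could be consumed downstream roughly like this; the task id is a placeholder, not something defined in this commit.

# Hypothetical downstream TaskFlow task; "create_batch_prediction_job" is a
# placeholder task id.
from airflow.decorators import task


@task
def report_batch_prediction(ti=None):
    # Pushed in both the synchronous and the deferrable path.
    job_id = ti.xcom_pull(
        task_ids="create_batch_prediction_job", key="batch_prediction_job_id"
    )
    # Pushed only by execute_complete(), i.e. when deferrable=True.
    training_conf = ti.xcom_pull(
        task_ids="create_batch_prediction_job", key="training_conf"
    )
    print(f"Batch prediction job {job_id} finished; conf: {training_conf}")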
91 changes: 76 additions & 15 deletions airflow/providers/google/cloud/triggers/vertex_ai.py
@@ -16,19 +16,37 @@
# under the License.
from __future__ import annotations

from typing import Any, AsyncIterator, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence

from google.cloud.aiplatform_v1 import HyperparameterTuningJob, JobState
from google.cloud.aiplatform_v1 import BatchPredictionJob, HyperparameterTuningJob, JobState, types

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook
from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
HyperparameterTuningJobAsyncHook,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from proto import Message

class CreateHyperparameterTuningJobTrigger(BaseTrigger):
"""CreateHyperparameterTuningJobTrigger run on the trigger worker to perform create operation."""

class BaseVertexAIJobTrigger(BaseTrigger):
"""Base class for Vertex AI job triggers.
This trigger polls the Vertex AI job and checks its status.
In order to use it properly, you must:
- implement the following methods `_wait_job()`.
- override required `job_type_verbose_name` attribute to provide meaningful message describing your
job type.
- override required `job_serializer_class` attribute to provide proto.Message class that will be used
to serialize your job with `to_dict()` class method.
"""

job_type_verbose_name: str = "Vertex AI Job"
job_serializer_class: Message = None

statuses_success = {
JobState.JOB_STATE_PAUSED,
@@ -51,10 +69,13 @@ def __init__(
self.job_id = job_id
self.poll_interval = poll_interval
self.impersonation_chain = impersonation_chain
self.trigger_class_path = (
f"airflow.providers.google.cloud.triggers.vertex_ai.{self.__class__.__name__}"
)

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"airflow.providers.google.cloud.triggers.vertex_ai.CreateHyperparameterTuningJobTrigger",
self.trigger_class_path,
{
"conn_id": self.conn_id,
"project_id": self.project_id,
@@ -66,14 +87,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
)

async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self._get_async_hook()
try:
job = await hook.wait_hyperparameter_tuning_job(
project_id=self.project_id,
location=self.location,
job_id=self.job_id,
poll_interval=self.poll_interval,
)
job = await self._wait_job()
except AirflowException as ex:
yield TriggerEvent(
{
@@ -84,16 +99,62 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
return

status = "success" if job.state in self.statuses_success else "error"
message = f"Hyperparameter tuning job {job.name} completed with status {job.state.name}"
message = f"{self.job_type_verbose_name} {job.name} completed with status {job.state.name}"
yield TriggerEvent(
{
"status": status,
"message": message,
"job": HyperparameterTuningJob.to_dict(job),
"job": self._serialize_job(job),
}
)

def _get_async_hook(self) -> HyperparameterTuningJobAsyncHook:
async def _wait_job(self) -> Any:
"""Awaits a Vertex AI job instance for a status examination."""
raise NotImplementedError

def _serialize_job(self, job: Any) -> Any:
return self.job_serializer_class.to_dict(job)


class CreateHyperparameterTuningJobTrigger(BaseVertexAIJobTrigger):
"""CreateHyperparameterTuningJobTrigger run on the trigger worker to perform create operation."""

job_type_verbose_name = "Hyperparameter Tuning Job"
job_serializer_class = HyperparameterTuningJob

@cached_property
def async_hook(self) -> HyperparameterTuningJobAsyncHook:
return HyperparameterTuningJobAsyncHook(
gcp_conn_id=self.conn_id, impersonation_chain=self.impersonation_chain
)

async def _wait_job(self) -> types.HyperparameterTuningJob:
job: types.HyperparameterTuningJob = await self.async_hook.wait_hyperparameter_tuning_job(
project_id=self.project_id,
location=self.location,
job_id=self.job_id,
poll_interval=self.poll_interval,
)
return job


class CreateBatchPredictionJobTrigger(BaseVertexAIJobTrigger):
"""CreateBatchPredictionJobTrigger run on the trigger worker to perform create operation."""

job_type_verbose_name = "Batch Prediction Job"
job_serializer_class = BatchPredictionJob

@cached_property
def async_hook(self) -> BatchPredictionJobAsyncHook:
return BatchPredictionJobAsyncHook(
gcp_conn_id=self.conn_id, impersonation_chain=self.impersonation_chain
)

async def _wait_job(self) -> types.BatchPredictionJob:
job: types.BatchPredictionJob = await self.async_hook.wait_batch_prediction_job(
project_id=self.project_id,
location=self.location,
job_id=self.job_id,
poll_interval=self.poll_interval,
)
return job
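One consequence of moving serialization into the base class is that trigger_class_path is derived from the subclass name, so new subclasses round-trip through the triggerer without overriding serialize(). A quick illustration with placeholder argument values (not taken from this commit):

# Placeholder values; only the shape of serialize() matters here.
from airflow.providers.google.cloud.triggers.vertex_ai import (
    CreateBatchPredictionJobTrigger,
)

trigger = CreateBatchPredictionJobTrigger(
    conn_id="google_cloud_default",
    project_id="my-project",   # placeholder
    location="us-central1",    # placeholder
    job_id="1234567890",       # placeholder
    poll_interval=60,
    impersonation_chain=None,
)

classpath, kwargs = trigger.serialize()
# classpath == "airflow.providers.google.cloud.triggers.vertex_ai.CreateBatchPredictionJobTrigger"
# kwargs holds the constructor arguments (conn_id, project_id, location, job_id,
# poll_interval, impersonation_chain), enough for the triggerer to rebuild the trigger.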
2 changes: 1 addition & 1 deletion airflow/providers/google/provider.yaml
@@ -107,7 +107,7 @@ dependencies:
- google-api-python-client>=1.6.0
- google-auth>=1.0.0
- google-auth-httplib2>=0.0.1
- google-cloud-aiplatform>=1.22.1
- google-cloud-aiplatform>=1.42.1
- google-cloud-automl>=2.12.0
- google-cloud-bigquery-datatransfer>=3.13.0
- google-cloud-bigtable>=2.17.0
10 changes: 10 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -288,6 +288,16 @@ The operator returns batch prediction job id in :ref:`XCom <concepts:xcom>` unde
:start-after: [START how_to_cloud_vertex_ai_create_batch_prediction_job_operator]
:end-before: [END how_to_cloud_vertex_ai_create_batch_prediction_job_operator]

The :class:`~airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.CreateBatchPredictionJobOperator`
also provides deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_create_batch_prediction_job_operator_def]
:end-before: [END how_to_cloud_vertex_ai_create_batch_prediction_job_operator_def]


To delete batch prediction job you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.DeleteBatchPredictionJobOperator`.

2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -525,7 +525,7 @@
"google-api-python-client>=1.6.0",
"google-auth-httplib2>=0.0.1",
"google-auth>=1.0.0",
"google-cloud-aiplatform>=1.22.1",
"google-cloud-aiplatform>=1.42.1",
"google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -727,7 +727,7 @@ google = [ # source: airflow/providers/google/provider.yaml
"google-api-python-client>=1.6.0",
"google-auth-httplib2>=0.0.1",
"google-auth>=1.0.0",
"google-cloud-aiplatform>=1.22.1",
"google-cloud-aiplatform>=1.42.1",
"google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
Expand Down
(The remaining changed files, mostly tests and system test examples, are not rendered here.)
