Skip to content

Commit

Permalink
Rename DeleteCustomTrainingJobOperator's fields' names to comply wi…
Browse files Browse the repository at this point in the history
…th templated fields validation (#38048)

Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>

Rename `DeleteCustomTrainingJobOperator`'s fields' name to comply with templated fields validation
  • Loading branch information
shahar1 authored Mar 15, 2024
1 parent 60b95c7 commit 83060e1
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 5 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ repos:
exclude: |
(?x)^(
^airflow\/providers\/google\/cloud\/operators\/mlengine.py$|
^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/custom_job.py$|
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$|
^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/auto_ml\.py$|
Expand Down
28 changes: 25 additions & 3 deletions airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

from typing import TYPE_CHECKING, Sequence

from deprecated import deprecated
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform_v1.types.dataset import Dataset
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
from airflow.providers.google.cloud.links.vertex_ai import (
VertexAIModelLink,
Expand Down Expand Up @@ -1328,7 +1330,7 @@ class DeleteCustomTrainingJobOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
"""

template_fields = ("training_pipeline", "custom_job", "region", "project_id", "impersonation_chain")
template_fields = ("training_pipeline_id", "custom_job_id", "region", "project_id", "impersonation_chain")

def __init__(
self,
Expand All @@ -1345,8 +1347,8 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
self.training_pipeline = training_pipeline_id
self.custom_job = custom_job_id
self.training_pipeline_id = training_pipeline_id
self.custom_job_id = custom_job_id
self.region = region
self.project_id = project_id
self.retry = retry
Expand All @@ -1355,6 +1357,26 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

@property
@deprecated(
reason="`training_pipeline` is deprecated and will be removed in the future. "
"Please use `training_pipeline_id` instead.",
category=AirflowProviderDeprecationWarning,
)
def training_pipeline(self):
"""Alias for ``training_pipeline_id``, used for compatibility (deprecated)."""
return self.training_pipeline_id

@property
@deprecated(
reason="`custom_job` is deprecated and will be removed in the future. "
"Please use `custom_job_id` instead.",
category=AirflowProviderDeprecationWarning,
)
def custom_job(self):
"""Alias for ``custom_job_id``, used for compatibility (deprecated)."""
return self.custom_job_id

def execute(self, context: Context):
hook = CustomJobHook(
gcp_conn_id=self.gcp_conn_id,
Expand Down
32 changes: 31 additions & 1 deletion tests/providers/google/cloud/operators/test_vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from google.api_core.gapic_v1.method import DEFAULT
from google.api_core.retry import Retry

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
CreateAutoMLForecastingTrainingJobOperator,
CreateAutoMLImageTrainingJobOperator,
Expand Down Expand Up @@ -84,6 +84,7 @@
ListPipelineJobOperator,
RunPipelineJobOperator,
)
from airflow.utils import timezone

VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}"
VERTEX_AI_LINKS_PATH = "airflow.providers.google.cloud.links.vertex_ai.{}"
Expand Down Expand Up @@ -477,6 +478,35 @@ def test_execute(self, mock_hook):
metadata=METADATA,
)

@pytest.mark.db_test
def test_templating(self, create_task_instance_of_operator):
ti = create_task_instance_of_operator(
DeleteCustomTrainingJobOperator,
# Templated fields
training_pipeline_id="{{ 'training-pipeline-id' }}",
custom_job_id="{{ 'custom_job_id' }}",
region="{{ 'region' }}",
project_id="{{ 'project_id' }}",
impersonation_chain="{{ 'impersonation-chain' }}",
# Other parameters
dag_id="test_template_body_templating_dag",
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
ti.render_templates()
task: DeleteCustomTrainingJobOperator = ti.task
assert task.training_pipeline_id == "training-pipeline-id"
assert task.custom_job_id == "custom_job_id"
assert task.region == "region"
assert task.project_id == "project_id"
assert task.impersonation_chain == "impersonation-chain"

# Deprecated aliases
with pytest.warns(AirflowProviderDeprecationWarning):
assert task.training_pipeline == "training-pipeline-id"
with pytest.warns(AirflowProviderDeprecationWarning):
assert task.custom_job == "custom_job_id"


class TestVertexAIListCustomTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
Expand Down

0 comments on commit 83060e1

Please sign in to comment.
  翻译: