Implement deferrable mode for BeamRunJavaPipelineOperator (#36122)
moiseenkov committed Dec 19, 2023
1 parent 4758185 commit 881d88b
Showing 10 changed files with 793 additions and 182 deletions.
19 changes: 19 additions & 0 deletions airflow/providers/apache/beam/hooks/beam.py
@@ -508,6 +508,25 @@ async def start_python_pipeline_async(
)
return return_code

async def start_java_pipeline_async(self, variables: dict, jar: str, job_class: str | None = None):
"""
Start Apache Beam Java pipeline.
:param variables: Variables passed to the job.
:param jar: Name of the jar for the pipeline.
:param job_class: Name of the java class for the pipeline.
:return: Beam command execution return code.
"""
if "labels" in variables:
variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))

command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]
return_code = await self.start_pipeline_async(
variables=variables,
command_prefix=command_prefix,
)
return return_code

async def start_pipeline_async(
self,
variables: dict,
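
Before moving to the operator changes below, here is a minimal sketch of how this new hook method might be driven from an async caller such as a trigger. The runner choice, labels, jar path, and class name are illustrative assumptions, not values from this commit:

import asyncio

from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook, BeamRunnerType


async def run_jar() -> int:
    # The hook JSON-encodes any "labels" dict and builds either
    # ["java", "-cp", jar, job_class] or ["java", "-jar", jar].
    hook = BeamAsyncHook(runner=BeamRunnerType.DirectRunner)
    return await hook.start_java_pipeline_async(
        variables={"labels": {"team": "data"}},  # hypothetical labels
        jar="/tmp/wordcount.jar",                # hypothetical local jar
        job_class="org.example.WordCount",       # hypothetical entry class
    )


return_code = asyncio.run(run_jar())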
163 changes: 106 additions & 57 deletions airflow/providers/apache/beam/operators/beam.py
@@ -34,7 +34,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger
from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
from airflow.providers.google.cloud.hooks.dataflow import (
DataflowHook,
process_line_and_extract_dataflow_job_id_callback,
@@ -239,6 +239,22 @@ def _init_pipeline_options(
check_job_status_callback,
)

def execute_complete(self, context: Context, event: dict[str, Any]):
"""
Execute when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(
"%s completed with response %s ",
self.task_id,
event["message"],
)
return {"dataflow_job_id": self.dataflow_job_id}
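
For reference, the event consumed by execute_complete() is the payload of a TriggerEvent yielded by the Beam triggers. A sketch of the two shapes this handler distinguishes; the message strings are illustrative, not taken from the triggers' source:

from airflow.triggers.base import TriggerEvent

# Outcome shapes the handler above expects; message text is illustrative.
TriggerEvent({"status": "success", "message": "Pipeline has finished SUCCESSFULLY"})
TriggerEvent({"status": "error", "message": "Operation failed"})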


class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
"""
@@ -323,7 +339,7 @@ def __init__(
self.deferrable = deferrable

def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
"""Execute the Apache Beam Python Pipeline."""
(
self.is_dataflow,
self.dataflow_job_name,
@@ -408,7 +424,7 @@ async def execute_async(self, context: Context):
)
with self.dataflow_hook.provide_authorized_gcloud():
self.defer(
trigger=BeamPipelineTrigger(
trigger=BeamPythonPipelineTrigger(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
)
else:
self.defer(
trigger=BeamPipelineTrigger(
trigger=BeamPythonPipelineTrigger(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]):
"""
Execute when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(
"%s completed with response %s ",
self.task_id,
event["message"],
)
return {"dataflow_job_id": self.dataflow_job_id}

def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id)
@@ -509,6 +509,7 @@ def __init__(
pipeline_options: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
dataflow_config: DataflowConfiguration | dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(
)
self.jar = jar
self.job_class = job_class
self.deferrable = deferrable

def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
"""Execute the Apache Beam Python Pipeline."""
(
is_dataflow,
dataflow_job_name,
pipeline_options,
process_line_callback,
self.is_dataflow,
self.dataflow_job_name,
self.pipeline_options,
self.process_line_callback,
_,
) = self._init_pipeline_options()

if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
if self.deferrable:
asyncio.run(self.execute_async(context))
else:
return self.execute_sync(context)

def execute_sync(self, context: Context):
"""Execute the Apache Beam Pipeline."""
with ExitStack() as exit_stack:
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar))
self.jar = tmp_gcs_file.name

if is_dataflow and self.dataflow_hook:
is_running = False
if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob:
is_running = (
# The reason for disable=no-value-for-parameter is that the project_id parameter
# is required but is not passed here; moreover, it cannot be passed here.
# This method is wrapped by the @_fallback_to_project_id_from_variables decorator,
# which falls back to the project_id value from variables and raises an error if
# project_id is defined both in variables and as a parameter (here it is already
# defined in variables).
self.dataflow_hook.is_job_dataflow_running(
name=self.dataflow_config.job_name,
variables=pipeline_options,
)
if self.is_dataflow and self.dataflow_hook:
is_running = self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun
while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
# The reason for disable=no-value-for-parameter is that the project_id parameter
# is required but is not passed here; moreover, it cannot be passed here.
# This method is wrapped by the @_fallback_to_project_id_from_variables decorator,
# which falls back to the project_id value from variables and raises an error if
# project_id is defined both in variables and as a parameter (here it is already
# defined in variables).
is_running = self.dataflow_hook.is_job_dataflow_running(
name=self.dataflow_config.job_name,
variables=self.pipeline_options,
)
while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
# The reason for disable=no-value-for-parameter is that the project_id parameter
# is required but is not passed here; moreover, it cannot be passed here.
# This method is wrapped by the @_fallback_to_project_id_from_variables decorator,
# which falls back to the project_id value from variables and raises an error if
# project_id is defined both in variables and as a parameter (here it is already
# defined in variables).

is_running = self.dataflow_hook.is_job_dataflow_running(
name=self.dataflow_config.job_name,
variables=pipeline_options,
)

if not is_running:
pipeline_options["jobName"] = dataflow_job_name
self.pipeline_options["jobName"] = self.dataflow_job_name
with self.dataflow_hook.provide_authorized_gcloud():
self.beam_hook.start_java_pipeline(
variables=pipeline_options,
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=process_line_callback,
process_line_callback=self.process_line_callback,
)
if dataflow_job_name and self.dataflow_config.location:
if self.dataflow_job_name and self.dataflow_config.location:
multiple_jobs = self.dataflow_config.multiple_jobs or False
DataflowJobLink.persist(
self,
Expand All @@ -585,7 +580,7 @@ def execute(self, context: Context):
self.dataflow_job_id,
)
self.dataflow_hook.wait_for_done(
job_name=dataflow_job_name,
job_name=self.dataflow_job_name,
location=self.dataflow_config.location,
job_id=self.dataflow_job_id,
multiple_jobs=multiple_jobs,
Expand All @@ -594,11 +589,65 @@ def execute(self, context: Context):
return {"dataflow_job_id": self.dataflow_job_id}
else:
self.beam_hook.start_java_pipeline(
variables=pipeline_options,
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=process_line_callback,
process_line_callback=self.process_line_callback,
)

async def execute_async(self, context: Context):
# Get the event loop (created by asyncio.run() in execute()) to offload blocking I/O
loop = asyncio.get_event_loop()
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
# Run the synchronous enter_context() method in a separate thread using the
# default executor (None). run_in_executor() resolves to the file object
# produced by the GCS hook's provide_file() context manager, so downloading
# the jar does not block the event loop (see the standalone sketch after
# this method).
create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None, contextlib.ExitStack().enter_context, create_tmp_file_call
)
self.jar = tmp_gcs_file.name

if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
with self.dataflow_hook.provide_authorized_gcloud():
self.pipeline_options["jobName"] = self.dataflow_job_name
self.defer(
trigger=BeamJavaPipelineTrigger(
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
runner=self.runner,
check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
job_name=self.dataflow_job_name,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.dataflow_config.impersonation_chain,
poll_sleep=self.dataflow_config.poll_sleep,
cancel_timeout=self.dataflow_config.cancel_timeout,
),
method_name="execute_complete",
)
else:
self.defer(
trigger=BeamJavaPipelineTrigger(
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
runner=self.runner,
check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
),
method_name="execute_complete",
)
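
The jar-staging code above uses a general asyncio pattern worth seeing in isolation: entering a synchronous context manager in a worker thread so the event loop stays free. A minimal, self-contained sketch of that pattern, with a stdlib temporary file standing in for the GCS hook's provide_file():

import asyncio
import contextlib
import tempfile


async def enter_cm_in_thread() -> str:
    # Same pattern as execute_async() above: enter a synchronous context
    # manager in the default thread executor so the coroutine is not blocked.
    loop = asyncio.get_event_loop()
    stack = contextlib.ExitStack()
    cm = tempfile.NamedTemporaryFile(suffix=".jar")  # stands in for provide_file()
    tmp_file = await loop.run_in_executor(None, stack.enter_context, cm)
    try:
        return tmp_file.name
    finally:
        stack.close()  # exits the context manager, deleting the temp file


print(asyncio.run(enter_cm_in_thread()))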

def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
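
Taken together, these changes let the Java operator run in deferrable mode. A hypothetical DAG sketch follows; the bucket, jar, class, and region are placeholders, not values from this commit:

import pendulum

from airflow import DAG
from airflow.providers.apache.beam.hooks.beam import BeamRunnerType
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator

with DAG(
    dag_id="example_beam_java_deferrable",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as dag:
    run_java = BeamRunJavaPipelineOperator(
        task_id="run_java_pipeline",
        runner=BeamRunnerType.DataflowRunner,
        jar="gs://my-bucket/wordcount.jar",      # placeholder GCS jar
        job_class="org.example.WordCount",       # placeholder main class
        pipeline_options={"tempLocation": "gs://my-bucket/tmp/"},
        dataflow_config={"location": "us-central1"},
        deferrable=True,  # the mode added by this commit
    )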
(Diff for the remaining 8 changed files not loaded.)