Commit 15273f0

Check for same task instead of Equality to detect Duplicate Tasks (#8828)
kaxil authored May 16, 2020
1 parent f4edd90 commit 15273f0
Showing 17 changed files with 78 additions and 82 deletions.
2 changes: 1 addition & 1 deletion airflow/example_dags/example_complex.py
@@ -81,7 +81,7 @@
)

create_tag_template_field_result2 = BashOperator(
task_id="create_tag_template_field_result", bash_command="echo create_tag_template_field_result"
task_id="create_tag_template_field_result2", bash_command="echo create_tag_template_field_result"
)

# Delete
7 changes: 3 additions & 4 deletions airflow/models/baseoperator.py
@@ -36,7 +36,7 @@
from sqlalchemy.orm import Session

from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DuplicateTaskIdFound
+from airflow.exceptions import AirflowException
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models.base import Operator
from airflow.models.pool import Pool
@@ -600,9 +600,8 @@ def dag(self, dag: Any):
"The DAG assigned to {} can not be changed.".format(self))
elif self.task_id not in dag.task_dict:
dag.add_task(self)
-elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] != self:
-    raise DuplicateTaskIdFound(
-        "Task id '{}' has already been added to the DAG".format(self.task_id))
+elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
+    dag.add_task(self)

self._dag = dag # pylint: disable=attribute-defined-outside-init

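The heart of the commit is the switch from != to is not above: equality can be overridden by a class's __eq__, so two distinct task objects may compare equal, while identity only ever matches the very same object. A minimal plain-Python sketch of the distinction (the Task class here is a hypothetical stand-in, not Airflow's BaseOperator):

class Task:
    def __init__(self, task_id):
        self.task_id = task_id

    def __eq__(self, other):
        # Value equality: any Task carrying the same task_id compares equal.
        return isinstance(other, Task) and self.task_id == other.task_id


a = Task("t1")
b = Task("t1")

print(a == b)   # True  -- the old != check would mistake b for a re-add of a
print(a is b)   # False -- the new is not check still sees two distinct objects

Under the old check an equal-but-distinct object slipped past duplicate detection; under the new one it is handed to dag.add_task, which raises.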
2 changes: 1 addition & 1 deletion airflow/models/dag.py
@@ -1337,7 +1337,7 @@ def add_task(self, task):
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)

-if task.task_id in self.task_dict and self.task_dict[task.task_id] != task:
+if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task:
raise DuplicateTaskIdFound(
"Task id '{}' has already been added to the DAG".format(task.task_id))
else:
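With dag.add_task now the single gatekeeper, re-registering the same object is harmless, while a second object under an existing task_id still fails. A short sketch of the post-commit behavior, assuming Airflow-1.10-era import paths:

from datetime import datetime

from airflow import DAG
from airflow.exceptions import DuplicateTaskIdFound
from airflow.operators.dummy_operator import DummyOperator

dag = DAG("demo_dag", start_date=datetime(2020, 1, 1))
op = DummyOperator(task_id="t1", dag=dag)

dag.add_task(op)  # same object again: effectively a no-op, task_dict is unchanged

try:
    DummyOperator(task_id="t1", dag=dag)  # different object, same task_id
except DuplicateTaskIdFound as err:
    print(err)  # Task id 't1' has already been added to the DAG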
@@ -181,7 +181,7 @@

# [START howto_operator_gcp_datacatalog_create_tag_template_field_result2]
create_tag_template_field_result2 = BashOperator(
task_id="create_tag_template_field_result",
task_id="create_tag_template_field_result2",
bash_command="echo \"{{ task_instance.xcom_pull('create_tag_template_field') }}\"",
)
# [END howto_operator_gcp_datacatalog_create_tag_template_field_result2]
@@ -100,7 +100,7 @@

# [START howto_operator_gcs_to_gcs_delimiter]
copy_files_with_delimiter = GCSToGCSOperator(
task_id="copy_files_with_wildcard",
task_id="copy_files_with_delimiter",
source_bucket=BUCKET_1_SRC,
source_object="data/",
destination_bucket=BUCKET_1_DST,
18 changes: 1 addition & 17 deletions tests/models/test_dag.py
@@ -979,34 +979,18 @@ def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):

self.assertEqual(dag.task_dict, {op1.task_id: op1})

-# Also verify that DAGs with duplicate task_ids don't raise errors
-with DAG("test_dag_1", start_date=DEFAULT_DATE) as dag1:
-    op3 = DummyOperator(task_id="t3")
-    op4 = BashOperator(task_id="t4", bash_command="sleep 1")
-    op3 >> op4
-
-self.assertEqual(dag1.task_dict, {op3.task_id: op3, op4.task_id: op4})
-
def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
"""Verify tasks with Duplicate task_id raises error"""
with self.assertRaisesRegex(
DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"
):
dag = DAG("test_dag", start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id="t1", dag=dag)
op2 = BashOperator(task_id="t1", bash_command="sleep 1", dag=dag)
op2 = DummyOperator(task_id="t1", dag=dag)
op1 >> op2

self.assertEqual(dag.task_dict, {op1.task_id: op1})

-# Also verify that DAGs with duplicate task_ids don't raise errors
-dag1 = DAG("test_dag_1", start_date=DEFAULT_DATE)
-op3 = DummyOperator(task_id="t3", dag=dag1)
-op4 = DummyOperator(task_id="t4", dag=dag1)
-op3 >> op4
-
-self.assertEqual(dag1.task_dict, {op3.task_id: op3, op4.task_id: op4})
-
def test_duplicate_task_ids_for_same_task_is_allowed(self):
"""Verify that same tasks with Duplicate task_id do not raise error"""
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
16 changes: 8 additions & 8 deletions tests/models/test_taskinstance.py
@@ -373,20 +373,20 @@ def test_ti_updates_with_task(self, session=None):
"""
test that updating the executor_config propogates to the TaskInstance DB
"""
-dag = models.DAG(dag_id='test_run_pooling_task')
-task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
-                     executor_config={'foo': 'bar'},
-                     start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+with models.DAG(dag_id='test_run_pooling_task') as dag:
+    task = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
+                         executor_config={'foo': 'bar'},
+                         start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(
task=task, execution_date=timezone.utcnow())

ti.run(session=session)
tis = dag.get_task_instances()
self.assertEqual({'foo': 'bar'}, tis[0].executor_config)

-task2 = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
-                      executor_config={'bar': 'baz'},
-                      start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
+with models.DAG(dag_id='test_run_pooling_task') as dag:
+    task2 = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
+                          executor_config={'bar': 'baz'},
+                          start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

ti = TI(
task=task2, execution_date=timezone.utcnow())
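This test had been reusing one task_id on a single DAG object for a second, different operator, which the stricter check now rejects; giving each operator its own DAG via a context manager keeps the ids from colliding. A sketch of the pattern, under the same assumed imports as above:

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

with DAG("demo_dag", start_date=datetime(2020, 1, 1)) as dag_one:
    first = DummyOperator(task_id="shared_id")

with DAG("demo_dag", start_date=datetime(2020, 1, 1)) as dag_two:
    second = DummyOperator(task_id="shared_id")  # no DuplicateTaskIdFound

# The duplicate check is scoped per DAG: each task_dict holds its own object.
assert dag_one.task_dict["shared_id"] is first
assert dag_two.task_dict["shared_id"] is second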
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_s3_to_sftp.py
@@ -133,7 +133,7 @@ def test_s3_to_sftp_operation(self):
def delete_remote_resource(self):
# check the remote file content
remove_file_task = SSHOperator(
task_id="test_check_file",
task_id="test_rm_file",
ssh_hook=self.hook,
command="rm {0}".format(self.sftp_path),
do_xcom_push=True,
42 changes: 24 additions & 18 deletions tests/providers/google/cloud/operators/test_mlengine_utils.py
@@ -131,16 +131,18 @@ def test_successful_run(self):
self.assertEqual('err=0.9', result)

def test_failures(self):
-dag = DAG(
-    'test_dag',
-    default_args={
-        'owner': 'airflow',
-        'start_date': DEFAULT_DATE,
-        'end_date': DEFAULT_DATE,
-        'project_id': 'test-project',
-        'region': 'us-east1',
-    },
-    schedule_interval='@daily')
+def create_test_dag(dag_id):
+    dag = DAG(
+        dag_id,
+        default_args={
+            'owner': 'airflow',
+            'start_date': DEFAULT_DATE,
+            'end_date': DEFAULT_DATE,
+            'project_id': 'test-project',
+            'region': 'us-east1',
+        },
+        schedule_interval='@daily')
+    return dag

input_with_model = self.INPUT_MISSING_ORIGIN.copy()
other_params_but_models = {
@@ -151,26 +153,30 @@ def test_failures(self):
'prediction_path': input_with_model['outputPath'],
'metric_fn_and_keys': (self.metric_fn, ['err']),
'validate_fn': (lambda x: 'err=%.1f' % x['err']),
-'dag': dag,
}

with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
-mlengine_operator_utils.create_evaluate_ops(**other_params_but_models)
+mlengine_operator_utils.create_evaluate_ops(
+    dag=create_test_dag('test_dag_1'), **other_params_but_models)

with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
-mlengine_operator_utils.create_evaluate_ops(model_uri='abc', model_name='cde',
-                                            **other_params_but_models)
+mlengine_operator_utils.create_evaluate_ops(
+    dag=create_test_dag('test_dag_2'), model_uri='abc', model_name='cde',
+    **other_params_but_models)

with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
-mlengine_operator_utils.create_evaluate_ops(model_uri='abc', version_name='vvv',
-                                            **other_params_but_models)
+mlengine_operator_utils.create_evaluate_ops(
+    dag=create_test_dag('test_dag_3'), model_uri='abc', version_name='vvv',
+    **other_params_but_models)

with self.assertRaisesRegex(AirflowException, '`metric_fn` param must be callable'):
params = other_params_but_models.copy()
params['metric_fn_and_keys'] = (None, ['abc'])
-mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)
+mlengine_operator_utils.create_evaluate_ops(
+    dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params)

with self.assertRaisesRegex(AirflowException, '`validate_fn` param must be callable'):
params = other_params_but_models.copy()
params['validate_fn'] = None
-mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)
+mlengine_operator_utils.create_evaluate_ops(
+    dag=create_test_dag('test_dag_5'), **other_params_but_models... # see note below
+    dag=create_test_dag('test_dag_5'), model_uri='gs://blah', **params)
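The create_test_dag factory exists because create_evaluate_ops registers the same evaluation task ids on whatever DAG it receives; after this commit, running it twice against one DAG would raise DuplicateTaskIdFound before the assertion under test could fire. A sketch of the fresh-DAG-per-case pattern (make_dag and build_tasks are hypothetical names, not Airflow APIs):

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator


def make_dag(dag_id):
    # A fresh DAG per call, so repeated task_id registration never collides.
    return DAG(dag_id, start_date=datetime(2020, 1, 1))


def build_tasks(dag):
    # Stand-in for create_evaluate_ops: always registers the same task_id.
    return DummyOperator(task_id="evaluate", dag=dag)


build_tasks(make_dag("case_1"))
build_tasks(make_dag("case_2"))  # fine: separate DAG, so no DuplicateTaskIdFound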
4 changes: 2 additions & 2 deletions tests/providers/google/cloud/sensors/test_gcs.py
@@ -205,7 +205,7 @@ def setUp(self):
self.dag = dag

self.sensor = GCSUploadSessionCompleteSensor(
-task_id='sensor',
+task_id='sensor_1',
bucket='test-bucket',
prefix='test-prefix/path',
inactivity_period=12,
@@ -227,7 +227,7 @@ def test_files_deleted_between_pokes_throw_error(self):
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
def test_files_deleted_between_pokes_allow_delete(self):
self.sensor = GCSUploadSessionCompleteSensor(
-task_id='sensor',
+task_id='sensor_2',
bucket='test-bucket',
prefix='test-prefix/path',
inactivity_period=12,
@@ -45,7 +45,7 @@ def setUp(self):

def test_init(self):
operator = FileToWasbOperator(
-task_id='wasb_operator',
+task_id='wasb_operator_1',
dag=self.dag,
**self._config
)
@@ -58,7 +58,7 @@ def test_init(self):
self.assertEqual(operator.retries, self._config['retries'])

operator = FileToWasbOperator(
-task_id='wasb_operator',
+task_id='wasb_operator_2',
dag=self.dag,
load_options={'timeout': 2},
**self._config
@@ -42,7 +42,7 @@ def setUp(self):

def test_init(self):
operator = WasbDeleteBlobOperator(
-task_id='wasb_operator',
+task_id='wasb_operator_1',
dag=self.dag,
**self._config
)
@@ -53,7 +53,7 @@ def test_init(self):
self.assertEqual(operator.ignore_if_missing, False)

operator = WasbDeleteBlobOperator(
-task_id='wasb_operator',
+task_id='wasb_operator_2',
dag=self.dag,
is_prefix=True,
ignore_if_missing=True,
8 changes: 4 additions & 4 deletions tests/providers/microsoft/azure/sensors/test_wasb.py
@@ -43,7 +43,7 @@ def setUp(self):

def test_init(self):
sensor = WasbBlobSensor(
-task_id='wasb_sensor',
+task_id='wasb_sensor_1',
dag=self.dag,
**self._config
)
@@ -54,7 +54,7 @@ def test_init(self):
self.assertEqual(sensor.timeout, self._config['timeout'])

sensor = WasbBlobSensor(
-task_id='wasb_sensor',
+task_id='wasb_sensor_2',
dag=self.dag,
check_options={'timeout': 2},
**self._config
@@ -94,7 +94,7 @@ def setUp(self):

def test_init(self):
sensor = WasbPrefixSensor(
-task_id='wasb_sensor',
+task_id='wasb_sensor_1',
dag=self.dag,
**self._config
)
@@ -105,7 +105,7 @@ def test_init(self):
self.assertEqual(sensor.timeout, self._config['timeout'])

sensor = WasbPrefixSensor(
-task_id='wasb_sensor',
+task_id='wasb_sensor_2',
dag=self.dag,
check_options={'timeout': 2},
**self._config
20 changes: 10 additions & 10 deletions tests/providers/sftp/operators/test_sftp.py
@@ -75,7 +75,7 @@ def test_pickle_file_transfer_put(self):

# put test file to remote
put_test_task = SFTPOperator(
task_id="test_sftp",
task_id="put_test_task",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
@@ -89,7 +89,7 @@ def test_pickle_file_transfer_put(self):

# check the remote file content
check_file_task = SSHOperator(
task_id="test_check_file",
task_id="check_file_task",
ssh_hook=self.hook,
command="cat {0}".format(self.test_remote_filepath),
do_xcom_push=True,
@@ -99,7 +99,7 @@ def test_pickle_file_transfer_put(self):
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
self.assertEqual(
-ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(),
test_local_file_content)

@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
@@ -178,7 +178,7 @@ def test_json_file_transfer_put(self):

# put test file to remote
put_test_task = SFTPOperator(
task_id="test_sftp",
task_id="put_test_task",
ssh_hook=self.hook,
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
@@ -191,7 +191,7 @@ def test_json_file_transfer_put(self):

# check the remote file content
check_file_task = SSHOperator(
task_id="test_check_file",
task_id="check_file_task",
ssh_hook=self.hook,
command="cat {0}".format(self.test_remote_filepath),
do_xcom_push=True,
@@ -201,7 +201,7 @@ def test_json_file_transfer_put(self):
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
ti3.run()
self.assertEqual(
-ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+ti3.xcom_pull(task_ids=check_file_task.task_id, key='return_value').strip(),
b64encode(test_local_file_content).decode('utf-8'))

@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
@@ -362,7 +362,7 @@ def test_arg_checking(self):
with self.assertRaisesRegex(AirflowException,
"Cannot operate without ssh_hook or ssh_conn_id."):
task_0 = SFTPOperator(
task_id="test_sftp",
task_id="test_sftp_0",
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.PUT,
@@ -372,7 +372,7 @@ def test_arg_checking(self):

# if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
task_1 = SFTPOperator(
task_id="test_sftp",
task_id="test_sftp_1",
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
ssh_conn_id=TEST_CONN_ID,
local_filepath=self.test_local_filepath,
@@ -387,7 +387,7 @@ def test_arg_checking(self):
self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)

task_2 = SFTPOperator(
task_id="test_sftp",
task_id="test_sftp_2",
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
local_filepath=self.test_local_filepath,
remote_filepath=self.test_remote_filepath,
@@ -402,7 +402,7 @@ def test_arg_checking(self):

# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
task_3 = SFTPOperator(
task_id="test_sftp",
task_id="test_sftp_3",
ssh_hook=self.hook,
ssh_conn_id=TEST_CONN_ID,
local_filepath=self.test_local_filepath,