[misc] Get rid of pass statement in conditions (#27775)
Taragolis authored Dec 3, 2022
1 parent 0930d16 commit 4a3a429
Showing 7 changed files with 82 additions and 61 deletions.
14 changes: 9 additions & 5 deletions airflow/providers/amazon/aws/hooks/batch_client.py
@@ -363,11 +363,15 @@ def get_job_description(self, job_id: str) -> dict:
return self.parse_job_description(job_id, response)

except botocore.exceptions.ClientError as err:
-                error = err.response.get("Error", {})
-                if error.get("Code") == "TooManyRequestsException":
-                    pass  # allow it to retry, if possible
-                else:
-                    raise AirflowException(f"AWS Batch job ({job_id}) description error: {err}")
+                # Allow it to retry in case of exceeded quota limit of requests to AWS API
+                if err.response.get("Error", {}).get("Code") != "TooManyRequestsException":
+                    raise
+                self.log.warning(
+                    "Ignored TooManyRequestsException error, original message: %r. "
+                    "Please consider to setup retries mode in boto3, "
+                    "check Amazon Provider AWS Connection documentation for more details.",
+                    str(err),
+                )

retries += 1
if retries >= self.status_retries:
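Note on the batch_client.py change: the new warning points users at boto3's own retry handling. A minimal sketch of that configuration, assuming only the standard botocore Config API (the connection-extra wiring the message mentions lives in the Amazon provider docs, not in this diff):

    from botocore.config import Config

    # "standard" mode retries throttling errors such as TooManyRequestsException
    # with exponential backoff; "adaptive" adds client-side rate limiting.
    retry_config = Config(retries={"max_attempts": 10, "mode": "standard"})

With retries handled at the botocore layer, the hook's status_retries loop becomes a second line of defense rather than the only one.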
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/pubsub.py
@@ -538,9 +538,9 @@ def acknowledge(
:param metadata: (Optional) Additional metadata that is provided to the method.
"""
if ack_ids is not None and messages is None:
-            pass
+            pass  # use ack_ids as is
elif ack_ids is None and messages is not None:
-            ack_ids = [message.ack_id for message in messages]
+            ack_ids = [message.ack_id for message in messages]  # extract ack_ids from messages
else:
raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided")

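Note on the pubsub.py change: acknowledge() requires exactly one of ack_ids and messages. An equivalent XOR-style guard, shown only as an illustrative alternative to the committed if/elif/else:

    def resolve_ack_ids(ack_ids=None, messages=None):
        # Exactly one of the two arguments may be provided.
        if (ack_ids is None) == (messages is None):
            raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided")
        return ack_ids if ack_ids is not None else [m.ack_id for m in messages]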
10 changes: 4 additions & 6 deletions airflow/providers/imap/hooks/imap.py
@@ -182,14 +182,12 @@ def download_mail_attachments(
self._create_files(mail_attachments, local_output_directory)

def _handle_not_found_mode(self, not_found_mode: str) -> None:
-        if not_found_mode == "raise":
+        if not_found_mode not in ("raise", "warn", "ignore"):
+            self.log.error('Invalid "not_found_mode" %s', not_found_mode)
+        elif not_found_mode == "raise":
            raise AirflowException("No mail attachments found!")
-        if not_found_mode == "warn":
+        elif not_found_mode == "warn":
            self.log.warning("No mail attachments found!")
-        elif not_found_mode == "ignore":
-            pass  # Do not notify if the attachment has not been found.
-        else:
-            self.log.error('Invalid "not_found_mode" %s', not_found_mode)

def _retrieve_mails_attachments_by_name(
self, name: str, check_regex: bool, latest_only: bool, mail_folder: str, mail_filter: str
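Note on the imap.py change: validating not_found_mode up front lets the valid-but-silent "ignore" value fall through the elif chain without a pass branch. A dispatch-table variant (illustrative only, not what the commit does) makes the no-op explicit:

    import logging

    log = logging.getLogger(__name__)

    def _raise() -> None:
        raise RuntimeError("No mail attachments found!")

    _ACTIONS = {
        "raise": _raise,
        "warn": lambda: log.warning("No mail attachments found!"),
        "ignore": lambda: None,  # deliberately do nothing
    }

    def handle_not_found(mode: str) -> None:
        action = _ACTIONS.get(mode)
        if action is None:
            log.error('Invalid "not_found_mode" %s', mode)
        else:
            action()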
8 changes: 2 additions & 6 deletions airflow/providers/sftp/hooks/sftp.py
@@ -386,12 +386,8 @@ def get_file_by_pattern(self, path, fnmatch_pattern) -> str:
:param fnmatch_pattern: The pattern that will be matched with `fnmatch`
:return: string containing the first found file, or an empty string if none matched
"""
-        files_list = self.list_directory(path)
-
-        for file in files_list:
-            if not fnmatch(file, fnmatch_pattern):
-                pass
-            else:
+        for file in self.list_directory(path):
+            if fnmatch(file, fnmatch_pattern):
return file

return ""
19 changes: 9 additions & 10 deletions airflow/providers/ssh/hooks/ssh.py
@@ -290,17 +290,16 @@ def get_conn(self) -> paramiko.SSHClient:
known_hosts = os.path.expanduser("~/.ssh/known_hosts")
if not self.allow_host_key_change and os.path.isfile(known_hosts):
client.load_host_keys(known_hosts)
-        else:
-            if self.host_key is not None:
-                client_host_keys = client.get_host_keys()
-                if self.port == SSH_PORT:
-                    client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
-                else:
-                    client_host_keys.add(
-                        f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
-                    )
+
+        elif self.host_key is not None:
+            # Get host key from connection extra if it not set or None then we fallback to system host keys
+            client_host_keys = client.get_host_keys()
+            if self.port == SSH_PORT:
+                client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
            else:
-                pass  # will fallback to system host keys if none explicitly specified in conn extra
+                client_host_keys.add(
+                    f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
+                )

connect_kwargs: dict[str, Any] = dict(
hostname=self.remote_host,
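Note on the ssh.py change: the "[host]:port" entry format follows the OpenSSH known_hosts convention for non-default ports, which paramiko's HostKeys container understands. A hedged sketch; the hostname and the generated key below are stand-ins for values parsed from the connection extra:

    import paramiko

    client = paramiko.SSHClient()
    host_keys = client.get_host_keys()
    key = paramiko.RSAKey.generate(2048)  # stand-in for the configured host key
    host_keys.add("example.com", key.get_name(), key)         # default port 22
    host_keys.add("[example.com]:2222", key.get_name(), key)  # non-default port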
46 changes: 27 additions & 19 deletions tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -17,12 +17,11 @@
# under the License.
from __future__ import annotations

-import unittest
+import logging
from unittest import mock

import botocore.exceptions
import pytest
-from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
@@ -36,7 +35,7 @@
LOG_STREAM_NAME = "test/stream/d56a66bb98a14c4593defa1548686edf"


-class TestBatchClient(unittest.TestCase):
+class TestBatchClient:

MAX_RETRIES = 2
STATUS_RETRIES = 3
@@ -45,7 +44,7 @@ class TestBatchClient(unittest.TestCase):
@mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
@mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
-    def setUp(self, get_client_type_mock):
+    def setup_method(self, method, get_client_type_mock):
self.get_client_type_mock = get_client_type_mock
self.batch_client = BatchClientHook(
max_retries=self.MAX_RETRIES,
@@ -135,13 +134,17 @@ def test_poll_job_complete_raises_for_max_retries(self):
self.client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
assert self.client_mock.describe_jobs.call_count == self.MAX_RETRIES + 1

-    def test_poll_job_status_hit_api_throttle(self):
+    def test_poll_job_status_hit_api_throttle(self, caplog):
self.client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError(
error_response={"Error": {"Code": "TooManyRequestsException"}},
operation_name="get job description",
)
with pytest.raises(AirflowException) as ctx:
-            self.batch_client.poll_for_job_complete(JOB_ID)
+            with caplog.at_level(level=logging.getLevelName("WARNING")):
+                self.batch_client.poll_for_job_complete(JOB_ID)
+            log_record = caplog.records[0]
+            assert "Ignored TooManyRequestsException error" in log_record.message
+
msg = f"AWS Batch job ({JOB_ID}) description error"
assert msg in str(ctx.value)
# It should retry when this client error occurs
@@ -153,10 +156,10 @@ def test_poll_job_status_with_client_error(self):
error_response={"Error": {"Code": "InvalidClientTokenId"}},
operation_name="get job description",
)
-        with pytest.raises(AirflowException) as ctx:
+        with pytest.raises(botocore.exceptions.ClientError) as ctx:
self.batch_client.poll_for_job_complete(JOB_ID)
msg = f"AWS Batch job ({JOB_ID}) description error"
assert msg in str(ctx.value)

assert ctx.value.response["Error"]["Code"] == "InvalidClientTokenId"
# It will not retry when this client error occurs
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])

@@ -272,7 +275,7 @@ def test_job_awslogs_user_defined(self):
assert awslogs["awslogs_group"] == "/test/batch/job"
assert awslogs["awslogs_region"] == "ap-southeast-2"

-    def test_job_no_awslogs_stream(self):
+    def test_job_no_awslogs_stream(self, caplog):
self.client_mock.describe_jobs.return_value = {
"jobs": [
{
@@ -281,11 +284,13 @@ def test_job_no_awslogs_stream(self, caplog):
}
]
}
-        with self.assertLogs(level="WARNING") as capture_logs:
+        with caplog.at_level(level=logging.getLevelName("WARNING")):
            assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
-        assert len(capture_logs.records) == 1
+        assert len(caplog.records) == 1
+        log_record = caplog.records[0]
+        assert "doesn't create AWS CloudWatch Stream" in log_record.message

-    def test_job_splunk_logs(self):
+    def test_job_splunk_logs(self, caplog):
self.client_mock.describe_jobs.return_value = {
"jobs": [
{
@@ -299,16 +304,18 @@ def test_job_splunk_logs(self, caplog):
}
]
}
-        with self.assertLogs(level="WARNING") as capture_logs:
+        with caplog.at_level(level=logging.getLevelName("WARNING")):
            assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
-        assert len(capture_logs.records) == 1
+        assert len(caplog.records) == 1
+        log_record = caplog.records[0]
+        assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in log_record.message


-class TestBatchClientDelays(unittest.TestCase):
+class TestBatchClientDelays:
@mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
@mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
@mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
-    def setUp(self):
+    def setup_method(self, method):
self.batch_client = BatchClientHook(aws_conn_id="airflow_test", region_name=AWS_REGION)
# We're mocking all actual AWS calls and don't need a connection. This
# avoids an Airflow warning about connection cannot be found.
@@ -360,7 +367,8 @@ def test_delay_with_float(self, mock_sleep, mock_uniform):
mock_uniform.assert_called_once_with(4.0, 6.0) # in add_jitter
mock_sleep.assert_called_once_with(mock_uniform.return_value)

-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "tries, lower, upper",
[
(0, 0, 1),
(1, 0, 2),
@@ -373,7 +381,7 @@ def test_delay_with_float(self, mock_sleep, mock_uniform):
(8, 8, 25),
(9, 10, 31),
(45, 200, 600), # > 40 tries invokes maximum delay allowed
-        ]
+        ],
)
def test_exponential_delay(self, tries, lower, upper):
result = self.batch_client.exponential_delay(tries)
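Note on the test migration above: it swaps two unittest idioms for their pytest equivalents: assertLogs becomes the caplog fixture, and parameterized.expand becomes pytest.mark.parametrize. A condensed sketch with hypothetical names:

    import logging
    import pytest

    def emit_warning():
        logging.getLogger(__name__).warning("Ignored TooManyRequestsException error")

    def test_warning_logged(caplog):
        with caplog.at_level(logging.WARNING):
            emit_warning()
        assert "TooManyRequestsException" in caplog.records[0].message

    @pytest.mark.parametrize("tries, lower, upper", [(0, 0, 1), (1, 0, 2)])
    def test_delay_bounds(tries, lower, upper):
        assert lower <= upper  # placeholder for the real bound check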
42 changes: 29 additions & 13 deletions tests/providers/google/cloud/hooks/test_pubsub.py
@@ -61,23 +61,24 @@ def mock_init(
pass


+def _generate_messages(count) -> list[ReceivedMessage]:
+    return [
+        ReceivedMessage(
+            ack_id=str(i),
+            message={
+                "data": f"Message {i}".encode(),
+                "attributes": {"type": "generated message"},
+            },
+        )
+        for i in range(1, count + 1)
+    ]
+
+
class TestPubSubHook(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
self.pubsub_hook = PubSubHook(gcp_conn_id="test")

-    def _generate_messages(self, count) -> list[ReceivedMessage]:
-        return [
-            ReceivedMessage(
-                ack_id=str(i),
-                message={
-                    "data": f"Message {i}".encode(),
-                    "attributes": {"type": "generated message"},
-                },
-            )
-            for i in range(1, count + 1)
-        ]
-
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook.get_credentials")
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.PublisherClient")
def test_publisher_client_creation(self, mock_client, mock_get_creds):
@@ -478,7 +479,7 @@ def test_acknowledge_by_message_objects(self):
self.pubsub_hook.acknowledge(
project_id=TEST_PROJECT,
subscription=TEST_SUBSCRIPTION,
-            messages=self._generate_messages(3),
+            messages=_generate_messages(3),
)
ack_method.assert_called_once_with(
request=dict(
@@ -490,6 +491,21 @@ def test_acknowledge_by_message_objects(self):
metadata=(),
)

+    @parameterized.expand([(None, None), ([1, 2, 3], _generate_messages(3))])
+    @mock.patch(PUBSUB_STRING.format("PubSubHook.subscriber_client"))
+    def test_acknowledge_fails_on_method_args_validation(self, ack_ids, messages, mock_service):
+        ack_method = mock_service.acknowledge
+
+        error_message = r"One and only one of 'ack_ids' and 'messages' arguments have to be provided"
+        with pytest.raises(ValueError, match=error_message):
+            self.pubsub_hook.acknowledge(
+                project_id=TEST_PROJECT,
+                subscription=TEST_SUBSCRIPTION,
+                ack_ids=ack_ids,
+                messages=messages,
+            )
+        ack_method.assert_not_called()
+
@parameterized.expand(
[
(exception,)
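Note on the test_pubsub.py change: moving _generate_messages to module scope is what allows it to be called inside the @parameterized.expand(...) decorator; decorator arguments are evaluated at class-definition time, before any instance exists, so an instance method is not available there. Minimal illustration with hypothetical names:

    import unittest

    from parameterized import parameterized

    def make_cases():
        return [(1,), (2,)]

    class TestExample(unittest.TestCase):
        # Decorator arguments run when the class body is executed, so
        # make_cases must already exist at module scope.
        @parameterized.expand(make_cases())
        def test_case(self, value):
            self.assertIn(value, (1, 2))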
