Skip to content

Commit

Permalink
Avoid to use functools.lru_cache in class methods in google provi…
Browse files Browse the repository at this point in the history
…der (#38652)
  • Loading branch information
Taragolis authored Apr 1, 2024
1 parent 1cac59e commit d3dc88f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/compute_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _authorize_compute_engine_instance_metadata(self, pubkey):
)

def _authorize_os_login(self, pubkey):
username = self._oslogin_hook._get_credentials_email()
username = self._oslogin_hook._get_credentials_email
self.log.info("Importing SSH public key using OSLogin: user=%s", username)
expiration = int((time.time() + self.expire_time) * 1000000)
ssh_public_key = {"key": pubkey, "expiration_time_usec": expiration}
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _get_access_token(self) -> str:
credentials.refresh(auth_req)
return credentials.token

@functools.lru_cache(maxsize=None)
@functools.cached_property
def _get_credentials_email(self) -> str:
"""
Return the email address associated with the currently logged in account.
Expand Down
56 changes: 34 additions & 22 deletions tests/providers/google/cloud/hooks/test_compute_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook
from airflow.providers.google.cloud.hooks.os_login import OSLoginHook

pytestmark = pytest.mark.db_test

Expand All @@ -48,22 +49,35 @@ def test_delegate_to_runtime_error(self):
with pytest.raises(RuntimeError):
ComputeEngineSSHHook(gcp_conn_id="gcpssh", delegate_to="delegate_to")

def test_os_login_hook(self, mocker):
mock_os_login_hook = mocker.patch.object(OSLoginHook, "__init__", return_value=None, spec=OSLoginHook)

# Default values
assert ComputeEngineSSHHook()._oslogin_hook
mock_os_login_hook.assert_called_with(gcp_conn_id="google_cloud_default")

# Custom conn_id
assert ComputeEngineSSHHook(gcp_conn_id="gcpssh")._oslogin_hook
mock_os_login_hook.assert_called_with(gcp_conn_id="gcpssh")

@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
def test_get_conn_default_configuration(
self, mock_ssh_client, mock_paramiko, mock_os_login_hook, mock_compute_hook
):
mock_paramiko.SSHException = Exception
def test_get_conn_default_configuration(self, mock_ssh_client, mock_paramiko, mock_compute_hook, mocker):
mock_paramiko.SSHException = RuntimeError
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"

mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP

mock_os_login_hook.return_value._get_credentials_email.return_value = "test-example@example.org"
mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [
mock_os_login_hook = mocker.patch.object(
ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook"
)
type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock(
return_value="test-example@example.org"
)
mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [
mock.MagicMock(username="test-username")
]

Expand All @@ -83,16 +97,10 @@ def test_get_conn_default_configuration(
),
]
)
mock_os_login_hook.assert_has_calls(
[
mock.call(gcp_conn_id="google_cloud_default"),
mock.call()._get_credentials_email(),
mock.call().import_ssh_public_key(
ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY},
project_id="test-project-id",
user=mock_os_login_hook.return_value._get_credentials_email.return_value,
),
]
mock_os_login_hook.import_ssh_public_key.assert_called_once_with(
ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY},
project_id="test-project-id",
user="test-example@example.org",
)
mock_ssh_client.assert_has_calls(
[
Expand All @@ -113,7 +121,6 @@ def test_get_conn_default_configuration(
[(SSHException, r"Error occurred when establishing SSH connection using Paramiko")],
)
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance")
Expand All @@ -122,21 +129,26 @@ def test_get_conn_default_configuration_test_exceptions(
mock_connect,
mock_ssh_client,
mock_paramiko,
mock_os_login_hook,
mock_compute_hook,
exception_type,
error_message,
caplog,
mocker,
):
mock_paramiko.SSHException = Exception
mock_paramiko.SSHException = RuntimeError
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"

mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP

mock_os_login_hook.return_value._get_credentials_email.return_value = "test-example@example.org"
mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [
mock_os_login_hook = mocker.patch.object(
ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook"
)
type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock(
return_value="test-example@example.org"
)
mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [
mock.MagicMock(username="test-username")
]

Expand Down

0 comments on commit d3dc88f

Please sign in to comment.
  翻译: