Skip to content

Commit

Permalink
Add support for service account impersonation with computeEngineSSHHo…
Browse files Browse the repository at this point in the history
…ok (google provider) and IAP tunnel (#35136)



---------

Co-authored-by: gcazalet <gcazalet@solocal.com>
  • Loading branch information
ginolegigot and gcazalet authored Nov 25, 2023
1 parent c905fe8 commit 770f164
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
18 changes: 15 additions & 3 deletions airflow/providers/google/cloud/hooks/compute_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class ComputeEngineSSHHook(SSHHook):
:param gcp_conn_id: The connection id to use when fetching connection information
:param max_retries: Maximum number of retries the process will try to establish connection to instance.
Could be decreased/increased by user based on the amount of parallel SSH connections to the instance.
:param impersonation_chain: Optional. The service account email to impersonate using short-term
credentials. The provided service account must grant the originating account
the Service Account Token Creator IAM role and must have sufficient rights to perform the request.
"""

conn_name_attr = "gcp_conn_id"
Expand Down Expand Up @@ -114,15 +117,17 @@ def __init__(
expire_time: int = 300,
cmd_timeout: int | ArgNotSet = NOTSET,
max_retries: int = 10,
impersonation_chain: str | None = None,
**kwargs,
) -> None:
if kwargs.get("delegate_to") is not None:
raise RuntimeError(
"The `delegate_to` parameter has been deprecated before and finally removed in this version"
" of Google Provider. You MUST convert it to `impersonate_chain`"
" of Google Provider. You MUST convert it to `impersonation_chain`"
)
# Ignore original constructor
# super().__init__()
self.gcp_conn_id = gcp_conn_id
self.instance_name = instance_name
self.zone = zone
self.user = user
Expand All @@ -132,9 +137,9 @@ def __init__(
self.use_iap_tunnel = use_iap_tunnel
self.use_oslogin = use_oslogin
self.expire_time = expire_time
self.gcp_conn_id = gcp_conn_id
self.cmd_timeout = cmd_timeout
self.max_retries = max_retries
self.impersonation_chain = impersonation_chain
self._conn: Any | None = None

@cached_property
Expand All @@ -143,7 +148,12 @@ def _oslogin_hook(self) -> OSLoginHook:

@cached_property
def _compute_hook(self) -> ComputeEngineHook:
    """Build the ``ComputeEngineHook`` used by this SSH hook.

    The hook is always created with this instance's ``gcp_conn_id``. The
    ``impersonation_chain`` keyword is forwarded only when a truthy value
    was configured; otherwise it is left out of the call entirely.

    :return: A ``ComputeEngineHook`` instance (cached by ``cached_property``).
    """
    hook_kwargs = {"gcp_conn_id": self.gcp_conn_id}
    if self.impersonation_chain:
        # Only pass the keyword when impersonation was actually requested.
        hook_kwargs["impersonation_chain"] = self.impersonation_chain
    return ComputeEngineHook(**hook_kwargs)

def _load_connection_config(self):
def _boolify(value):
Expand Down Expand Up @@ -254,6 +264,8 @@ def get_conn(self) -> paramiko.SSHClient:
f"--zone={self.zone}",
"--verbosity=warning",
]
if self.impersonation_chain:
proxy_command_args.append(f"--impersonate-service-account={self.impersonation_chain}")
proxy_command = " ".join(shlex.quote(arg) for arg in proxy_command_args)
sshclient = self._connect_to_instance(user, hostname, privkey, proxy_command)
break
Expand Down
36 changes: 36 additions & 0 deletions tests/providers/google/cloud/hooks/test_compute_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
EXTERNAL_IP = "192.3.3.3"
TEST_PUB_KEY = "root:NAME AYZ root"
TEST_PUB_KEY2 = "root:NAME MNJ root"
IMPERSONATION_CHAIN = "SERVICE_ACCOUNT"


class TestComputeEngineHookWithPassedProjectId:
Expand Down Expand Up @@ -363,6 +364,41 @@ def test_get_conn_iap_tunnel(self, mock_ssh_client, mock_paramiko, mock_os_login
f"--zone={TEST_ZONE} --verbosity=warning"
)

@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
def test_get_conn_iap_tunnel_with_impersonation_chain(
    self, mock_ssh_client, mock_paramiko, mock_os_login_hook, mock_compute_hook
):
    """Check the IAP tunnel proxy command when ``impersonation_chain`` is set.

    With ``use_iap_tunnel=True`` and an ``impersonation_chain`` supplied, the
    ``gcloud compute start-iap-tunnel`` command handed to
    ``paramiko.ProxyCommand`` must end with the
    ``--impersonate-service-account`` flag, and the SSH client must connect
    through that proxy command.
    """
    # This mock is patched in but never inspected by the test.
    del mock_os_login_hook
    mock_paramiko.SSHException = Exception
    mock_compute_hook.return_value.project_id = TEST_PROJECT_ID

    expected_proxy_command = (
        f"gcloud compute start-iap-tunnel {TEST_INSTANCE_NAME} 22 "
        f"--listen-on-stdin --project={TEST_PROJECT_ID} "
        f"--zone={TEST_ZONE} --verbosity=warning --impersonate-service-account={IMPERSONATION_CHAIN}"
    )

    ssh_hook = ComputeEngineSSHHook(
        instance_name=TEST_INSTANCE_NAME,
        zone=TEST_ZONE,
        use_oslogin=False,
        use_iap_tunnel=True,
        impersonation_chain=IMPERSONATION_CHAIN,
    )
    client = ssh_hook.get_conn()

    # get_conn() must hand back the patched SSH client instance.
    assert mock_ssh_client.return_value == client

    # The connection has to be made through the ProxyCommand socket.
    mock_ssh_client.return_value.connect.assert_called_once_with(
        hostname=mock.ANY,
        look_for_keys=mock.ANY,
        pkey=mock.ANY,
        sock=mock_paramiko.ProxyCommand.return_value,
        username=mock.ANY,
    )
    mock_paramiko.ProxyCommand.assert_called_once_with(expected_proxy_command)

@pytest.mark.parametrize(
"exception_type, error_message",
[(SSHException, r"Error occurred when establishing SSH connection using Paramiko")],
Expand Down

0 comments on commit 770f164

Please sign in to comment.
  翻译: