Handle multiple connections using exceptions (#32365)
VladaZakharova committed Aug 2, 2023
1 parent df74553 commit 0c894db
Showing 5 changed files with 478 additions and 31 deletions.
79 changes: 51 additions & 28 deletions airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -16,13 +16,15 @@
# under the License.
from __future__ import annotations

import random
import shlex
import time
from functools import cached_property
from io import StringIO
from typing import Any

from google.api_core.retry import exponential_sleep_generator
from googleapiclient.errors import HttpError
from paramiko.ssh_exception import SSHException

from airflow import AirflowException
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
@@ -82,6 +84,8 @@ class ComputeEngineSSHHook(SSHHook):
keys are managed using instance metadata
:param expire_time: The maximum amount of time in seconds before the private key expires
:param gcp_conn_id: The connection id to use when fetching connection information
:param max_retries: Maximum number of times the hook retries establishing a connection to the instance.
Can be increased or decreased depending on how many parallel SSH connections target the instance.
"""

conn_name_attr = "gcp_conn_id"
@@ -109,6 +113,7 @@ def __init__(
use_oslogin: bool = True,
expire_time: int = 300,
cmd_timeout: int | ArgNotSet = NOTSET,
max_retries: int = 10,
**kwargs,
) -> None:
if kwargs.get("delegate_to") is not None:
@@ -129,6 +134,7 @@ def __init__(
self.expire_time = expire_time
self.gcp_conn_id = gcp_conn_id
self.cmd_timeout = cmd_timeout
self.max_retries = max_retries
self._conn: Any | None = None

@cached_property
@@ -225,40 +231,59 @@ def get_conn(self) -> paramiko.SSHClient:
hostname = self.hostname

privkey, pubkey = self._generate_ssh_key(self.user)
if self.use_oslogin:
user = self._authorize_os_login(pubkey)
else:
user = self.user
self._authorize_compute_engine_instance_metadata(pubkey)

proxy_command = None
if self.use_iap_tunnel:
proxy_command_args = [
"gcloud",
"compute",
"start-iap-tunnel",
str(self.instance_name),
"22",
"--listen-on-stdin",
f"--project={self.project_id}",
f"--zone={self.zone}",
"--verbosity=warning",
]
proxy_command = " ".join(shlex.quote(arg) for arg in proxy_command_args)

sshclient = self._connect_to_instance(user, hostname, privkey, proxy_command)

max_delay = 10
sshclient = None
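# Retry the whole connection setup, not just the SSH handshake: when many tasks
# connect to the same instance concurrently, updating the instance metadata can
# fail with 412 PRECONDITION FAILED, and the SSH handshake itself can fail
# transiently, so both cases are retried up to max_retries times with a random
# delay between attempts.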
for retry in range(self.max_retries + 1):
try:
if self.use_oslogin:
user = self._authorize_os_login(pubkey)
else:
user = self.user
self._authorize_compute_engine_instance_metadata(pubkey)
proxy_command = None
if self.use_iap_tunnel:
proxy_command_args = [
"gcloud",
"compute",
"start-iap-tunnel",
str(self.instance_name),
"22",
"--listen-on-stdin",
f"--project={self.project_id}",
f"--zone={self.zone}",
"--verbosity=warning",
]
proxy_command = " ".join(shlex.quote(arg) for arg in proxy_command_args)
sshclient = self._connect_to_instance(user, hostname, privkey, proxy_command)
break
except (HttpError, AirflowException, SSHException) as exc:
if (isinstance(exc, HttpError) and exc.resp.status == 412) or (
isinstance(exc, AirflowException) and "412 PRECONDITION FAILED" in str(exc)
):
self.log.info("Error occurred when trying to update instance metadata: %s", exc)
elif isinstance(exc, SSHException):
self.log.info("Error occurred when establishing SSH connection using Paramiko: %s", exc)
else:
raise
if retry == self.max_retries:
raise AirflowException("Maximum retries exceeded. Aborting operation.")
delay = random.randint(0, max_delay)
self.log.info(f"Failed establish SSH connection, waiting {delay} seconds to retry...")
time.sleep(delay)
if not sshclient:
raise AirflowException("Unable to establish SSH connection.")
return sshclient

def _connect_to_instance(self, user, hostname, pkey, proxy_command) -> paramiko.SSHClient:
self.log.info("Opening remote connection to host: username=%s, hostname=%s", user, hostname)
max_time_to_wait = 10
for time_to_wait in exponential_sleep_generator(initial=1, maximum=max_time_to_wait):
max_time_to_wait = 5
for time_to_wait in range(max_time_to_wait + 1):
try:
client = _GCloudAuthorizedSSHClient(self._compute_hook)
# Default is RejectPolicy
# No known host checking since we are not storing privatekey
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

client.connect(
hostname=hostname,
username=user,
@@ -268,8 +293,6 @@ def _connect_to_instance(self, user, hostname, pkey, proxy_command) -> paramiko.SSHClient:
)
return client
except paramiko.SSHException:
# exponential_sleep_generator is an infinite generator, so we need to
# check the end condition.
if time_to_wait == max_time_to_wait:
raise
self.log.info("Failed to connect. Waiting %ds to retry", time_to_wait)
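For context, a minimal sketch of how a DAG task could opt into the new retry behavior; the DAG id, instance name, zone, and retry count below are illustrative placeholders, not values from this commit:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator

# Hypothetical DAG; the instance name, zone and retry count are placeholders.
with DAG(dag_id="example_compute_ssh_retry", start_date=datetime(2021, 1, 1), catchup=False) as dag:
    echo_metadata = SSHOperator(
        task_id="echo_metadata",
        ssh_hook=ComputeEngineSSHHook(
            instance_name="example-instance",
            zone="europe-west1-b",
            use_oslogin=False,
            use_iap_tunnel=False,
            max_retries=5,  # give up after 5 retries instead of the default 10
        ),
        command="echo metadata_without_iap_tunnel1",
    )

When several such tasks target the same instance in parallel, the hook now retries metadata updates and SSH handshakes instead of failing on the first 412 or SSHException.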
123 changes: 122 additions & 1 deletion tests/providers/google/cloud/hooks/test_compute_ssh.py
@@ -17,10 +17,15 @@
from __future__ import annotations

import json
import logging
from unittest import mock

import httplib2
import pytest
from googleapiclient.errors import HttpError
from paramiko.ssh_exception import SSHException

from airflow import AirflowException
from airflow.models import Connection
from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook

@@ -99,7 +104,45 @@ def test_get_conn_default_configuration(
]
)

mock_compute_hook.return_value.set_instance_metadata.assert_not_called()
@pytest.mark.parametrize(
"exception_type, error_message",
[(SSHException, r"Error occurred when establishing SSH connection using Paramiko")],
)
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance")
def test_get_conn_default_configuration_test_exceptions(
self,
mock_connect,
mock_ssh_client,
mock_paramiko,
mock_os_login_hook,
mock_compute_hook,
exception_type,
error_message,
caplog,
):
mock_paramiko.SSHException = Exception
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"

mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP

mock_os_login_hook.return_value._get_credentials_email.return_value = "test-example@example.org"
mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [
mock.MagicMock(username="test-username")
]

hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE)
mock_connect.side_effect = [exception_type, mock_ssh_client]

with caplog.at_level(logging.INFO):
hook.get_conn()
assert error_message in caplog.text
assert "Failed establish SSH connection" in caplog.text

@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@@ -159,6 +202,49 @@ def test_get_conn_authorize_using_instance_metadata(

mock_os_login_hook.return_value.import_ssh_public_key.assert_not_called()

@pytest.mark.parametrize(
"exception_type, error_message",
[
(
HttpError(resp=httplib2.Response({"status": 412}), content=b"Error content"),
r"Error occurred when trying to update instance metadata",
),
(
AirflowException("412 PRECONDITION FAILED"),
r"Error occurred when trying to update instance metadata",
),
],
)
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
def test_get_conn_authorize_using_instance_metadata_test_exception(
self,
mock_ssh_client,
mock_paramiko,
mock_os_login_hook,
mock_compute_hook,
exception_type,
error_message,
caplog,
):
mock_paramiko.SSHException = Exception
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"

mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP

mock_compute_hook.return_value.get_instance_info.return_value = {"metadata": {}}
mock_compute_hook.return_value.set_instance_metadata.side_effect = [exception_type, None]

hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False)
with caplog.at_level(logging.INFO):
hook.get_conn()
assert error_message in caplog.text
assert "Failed establish SSH connection" in caplog.text

@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@@ -274,6 +360,41 @@ def test_get_conn_iap_tunnel(self, mock_ssh_client, mock_paramiko, mock_os_login
f"--zone={TEST_ZONE} --verbosity=warning"
)

@pytest.mark.parametrize(
"exception_type, error_message",
[(SSHException, r"Error occurred when establishing SSH connection using Paramiko")],
)
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance")
def test_get_conn_iap_tunnel_test_exception(
self,
mock_connect,
mock_ssh_client,
mock_paramiko,
mock_os_login_hook,
mock_compute_hook,
exception_type,
error_message,
caplog,
):
del mock_os_login_hook
mock_paramiko.SSHException = Exception

mock_compute_hook.return_value.project_id = TEST_PROJECT_ID

hook = ComputeEngineSSHHook(
instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False, use_iap_tunnel=True
)
mock_connect.side_effect = [exception_type, mock_ssh_client]

with caplog.at_level(logging.INFO):
hook.get_conn()
assert error_message in caplog.text
assert "Failed establish SSH connection" in caplog.text

@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
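The new retry tests can be selected by name with pytest; assuming a provider-source checkout, an invocation along these lines should run just the exception-handling cases:

pytest tests/providers/google/cloud/hooks/test_compute_ssh.py -k "exception" -v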
@@ -74,7 +74,7 @@
schedule_interval="@once",
start_date=datetime(2021, 1, 1),
catchup=False,
tags=["example"],
tags=["example", "compute-ssh"],
) as dag:
# [START howto_operator_gce_insert]
gce_instance_insert = ComputeEngineInsertInstanceOperator(
@@ -95,7 +95,7 @@
project_id=PROJECT_ID,
use_oslogin=False,
use_iap_tunnel=False,
cmd_timeout=100,
cmd_timeout=1,
),
command="echo metadata_without_iap_tunnel1",
)
(Diffs for the remaining changed files are not shown.)
