Skip to content

Commit

Permalink
Update secrets backends to use get_conn_value instead of get_conn_uri (
Browse files Browse the repository at this point in the history
…#22348)

In #19857 we enabled storing connections as JSON instead of URI and renamed get_conn_uri to get_conn_value to be consistent with this change.  The method get_conn_uri is now deprecated and should warn when used.
  • Loading branch information
dstandish authored Mar 24, 2022
1 parent f06b395 commit 7ab45d4
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 32 deletions.
31 changes: 29 additions & 2 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@

import ast
import json
import re
import sys
import warnings
from typing import Optional
from urllib.parse import urlencode

import boto3

from airflow.version import version as airflow_version

if sys.version_info >= (3, 8):
from functools import cached_property
else:
Expand All @@ -34,6 +38,11 @@
from airflow.utils.log.logging_mixin import LoggingMixin


def _parse_version(val):
val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
return tuple(int(x) for x in val.split('.'))


class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection or Variables from AWS Secrets Manager
Expand Down Expand Up @@ -173,9 +182,9 @@ def get_uri_from_secret(self, secret):

return connection

def get_conn_uri(self, conn_id: str):
def get_conn_value(self, conn_id: str):
"""
Get Connection Value
Get serialized representation of Connection
:param conn_id: connection id
"""
Expand All @@ -199,6 +208,24 @@ def get_conn_uri(self, conn_id: str):

return connection

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Return URI representation of Connection conn_id.
As of Airflow version 2.3.0 this method is deprecated.
It is kept as a backward-compatible wrapper that delegates to ``get_conn_value``.
:param conn_id: the connection id
:return: the value returned by ``get_conn_value`` for ``conn_id``
"""
# Only warn when the running Airflow version is 2.3 or newer.
if _parse_version(airflow_version) >= (2, 3):
warnings.warn(
f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
"in a future release. Please use method `get_conn_value` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.get_conn_value(conn_id)

def get_variable(self, key: str) -> Optional[str]:
"""
Get Airflow Variable from Environment Variable
Expand Down
29 changes: 28 additions & 1 deletion airflow/providers/amazon/aws/secrets/systems_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
# specific language governing permissions and limitations
# under the License.
"""Objects relating to sourcing connections from AWS SSM Parameter Store"""
import re
import sys
import warnings
from typing import Optional

import boto3

from airflow.version import version as airflow_version

if sys.version_info >= (3, 8):
from functools import cached_property
else:
Expand All @@ -30,6 +34,11 @@
from airflow.utils.log.logging_mixin import LoggingMixin


def _parse_version(val):
val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
return tuple(int(x) for x in val.split('.'))


class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection or Variables from AWS SSM Parameter Store
Expand Down Expand Up @@ -86,7 +95,7 @@ def client(self):
session = boto3.Session(profile_name=self.profile_name)
return session.client("ssm", **self.kwargs)

def get_conn_uri(self, conn_id: str) -> Optional[str]:
def get_conn_value(self, conn_id: str) -> Optional[str]:
"""
Get param value
Expand All @@ -97,6 +106,24 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:

return self._get_secret(self.connections_prefix, conn_id)

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Return URI representation of Connection conn_id.
As of Airflow version 2.3.0 this method is deprecated.
It is kept as a backward-compatible wrapper that delegates to ``get_conn_value``.
:param conn_id: the connection id
:return: the value returned by ``get_conn_value`` for ``conn_id``
"""
# Only warn when the running Airflow version is 2.3 or newer.
if _parse_version(airflow_version) >= (2, 3):
warnings.warn(
f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
"in a future release. Please use method `get_conn_value` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.get_conn_value(conn_id)

def get_variable(self, key: str) -> Optional[str]:
"""
Get Airflow Variable from Environment Variable
Expand Down
30 changes: 28 additions & 2 deletions airflow/providers/google/cloud/secrets/secret_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

"""Objects relating to sourcing connections from Google Cloud Secrets Manager"""
import logging
import re
import warnings
from typing import Optional

from google.auth.exceptions import DefaultCredentialsError
Expand All @@ -26,12 +28,18 @@
from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.version import version as airflow_version

log = logging.getLogger(__name__)

SECRET_ID_PATTERN = r"^[a-zA-Z0-9-_]*$"


def _parse_version(val):
val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
return tuple(int(x) for x in val.split('.'))


class CloudSecretManagerBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection object from Google Cloud Secrets Manager
Expand Down Expand Up @@ -121,9 +129,9 @@ def _is_valid_prefix_and_sep(self) -> bool:
prefix = self.connections_prefix + self.sep
return _SecretManagerClient.is_valid_secret_name(prefix)

def get_conn_uri(self, conn_id: str) -> Optional[str]:
def get_conn_value(self, conn_id: str) -> Optional[str]:
"""
Get secret value from the SecretManager.
Get serialized representation of Connection
:param conn_id: connection id
"""
Expand All @@ -132,6 +140,24 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:

return self._get_secret(self.connections_prefix, conn_id)

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Return URI representation of Connection conn_id.
As of Airflow version 2.3.0 this method is deprecated.
It is kept as a backward-compatible wrapper that delegates to ``get_conn_value``.
:param conn_id: the connection id
:return: the value returned by ``get_conn_value`` for ``conn_id``
"""
# Only warn when the running Airflow version is 2.3 or newer.
if _parse_version(airflow_version) >= (2, 3):
warnings.warn(
f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
"in a future release. Please use method `get_conn_value` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.get_conn_value(conn_id)

def get_variable(self, key: str) -> Optional[str]:
"""
Get Airflow Variable from Environment Variable
Expand Down
13 changes: 11 additions & 2 deletions airflow/providers/hashicorp/secrets/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""Objects relating to sourcing connections & variables from Hashicorp Vault"""
import warnings
from typing import TYPE_CHECKING, Optional

from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient
Expand Down Expand Up @@ -168,14 +169,22 @@ def get_response(self, conn_id: str) -> Optional[dict]:

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Get secret value from Vault. Store the secret in the form of URI
Get serialized representation of connection
:param conn_id: The connection id
:rtype: str
:return: The connection uri retrieved from the secret
"""
response = self.get_response(conn_id)

# Since VaultBackend implements `get_connection`, `get_conn_uri` is not used. So we
# don't need to implement (or direct users to use) method `get_conn_value` instead
warnings.warn(
f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
"in a future release.",
PendingDeprecationWarning,
stacklevel=2,
)
response = self.get_response(conn_id)
return response.get("conn_uri") if response else None

# Make sure connection is imported this way for type checking, otherwise when importing
Expand Down
31 changes: 29 additions & 2 deletions airflow/providers/microsoft/azure/secrets/key_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import sys
import warnings
from typing import Optional

from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient

from airflow.version import version as airflow_version

if sys.version_info >= (3, 8):
from functools import cached_property
else:
Expand All @@ -30,6 +34,11 @@
from airflow.utils.log.logging_mixin import LoggingMixin


def _parse_version(val):
val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
return tuple(int(x) for x in val.split('.'))


class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Airflow Connections or Variables from Azure Key Vault secrets.
Expand Down Expand Up @@ -100,9 +109,9 @@ def client(self) -> SecretClient:
client = SecretClient(vault_url=self.vault_url, credential=credential, **self.kwargs)
return client

def get_conn_uri(self, conn_id: str) -> Optional[str]:
def get_conn_value(self, conn_id: str) -> Optional[str]:
"""
Get an Airflow Connection URI from an Azure Key Vault secret
Get a serialized representation of Airflow Connection from an Azure Key Vault secret
:param conn_id: The Airflow connection id to retrieve
"""
Expand All @@ -111,6 +120,24 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:

return self._get_secret(self.connections_prefix, conn_id)

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Return URI representation of Connection conn_id.
As of Airflow version 2.3.0 this method is deprecated.
It is kept as a backward-compatible wrapper that delegates to ``get_conn_value``.
:param conn_id: the connection id
:return: the value returned by ``get_conn_value`` for ``conn_id``
"""
# Only warn when the running Airflow version is 2.3 or newer.
if _parse_version(airflow_version) >= (2, 3):
warnings.warn(
f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
"in a future release. Please use method `get_conn_value` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.get_conn_value(conn_id)

def get_variable(self, key: str) -> Optional[str]:
"""
Get an Airflow Variable from an Azure Key Vault secret.
Expand Down
15 changes: 4 additions & 11 deletions docs/apache-airflow/security/secrets/secrets-backend/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,28 +99,21 @@ Roll your own secrets backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A secrets backend is a subclass of :py:class:`airflow.secrets.BaseSecretsBackend` and must implement either
:py:meth:`~airflow.secrets.BaseSecretsBackend.get_connection` or :py:meth:`~airflow.secrets.BaseSecretsBackend.get_conn_uri`.
:py:meth:`~airflow.secrets.BaseSecretsBackend.get_connection` or :py:meth:`~airflow.secrets.BaseSecretsBackend.get_conn_value`.

After writing your backend class, provide the fully qualified class name in the ``backend`` key in the ``[secrets]``
section of ``airflow.cfg``.

Additional arguments to your SecretsBackend can be configured in ``airflow.cfg`` by supplying a JSON string to ``backend_kwargs``, which will be passed to the ``__init__`` of your SecretsBackend.
See :ref:`Configuration <secrets_backend_configuration>` for more details, and :ref:`SSM Parameter Store <ssm_parameter_store_secrets>` for an example.

.. note::

If you are rolling your own secrets backend, you don't strictly need to use airflow's URI format. But
doing so makes it easier to switch between environment variables, the metastore, and your secrets backend.

Adapt to non-Airflow compatible secret formats for connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The default implementation of Secret backend requires use of an Airflow-specific format of storing
secrets for connections. Currently most community provided implementations require the connections to
be stored as URIs (with the possibility of adding more friendly formats in the future)
:doc:`apache-airflow-providers:core-extensions/secrets-backends`. However some organizations may prefer
to keep the credentials (passwords/tokens etc) in other formats --
for example when you want the same credentials to be used across multiple clients, or when you want to
use built-in mechanism of rotating the credentials that do not work well with the Airflow-specific format.
be stored as JSON or the Airflow Connection URI format (see
:doc:`apache-airflow-providers:core-extensions/secrets-backends`). However some organizations may need to store the credentials (passwords/tokens etc) in some other way, for example if the same credentials store needs to be used for multiple data platforms, or if you are using a service with a built-in mechanism of rotating the credentials that does not work with the Airflow-specific format.
In this case you will need to roll your own secret backend as described in the previous chapter,
possibly extending existing secret backend and adapt it to the scheme used by your organization.
possibly extending an existing secrets backend and adapting it to the scheme used by your organization.
6 changes: 3 additions & 3 deletions tests/providers/amazon/aws/secrets/test_secrets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@


class TestSecretsManagerBackend(TestCase):
@mock.patch("airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend.get_conn_uri")
def test_aws_secrets_manager_get_connections(self, mock_get_uri):
mock_get_uri.return_value = "scheme://user:pass@host:100"
@mock.patch("airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend.get_conn_value")
def test_aws_secrets_manager_get_connections(self, mock_get_value):
mock_get_value.return_value = "scheme://user:pass@host:100"
conn_list = SecretsManagerBackend().get_connections("fake_conn")
conn = conn_list[0]
assert conn.host == 'host'
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/amazon/aws/secrets/test_systems_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
class TestSsmSecrets(TestCase):
@mock.patch(
"airflow.providers.amazon.aws.secrets.systems_manager."
"SystemsManagerParameterStoreBackend.get_conn_uri"
"SystemsManagerParameterStoreBackend.get_conn_value"
)
def test_aws_ssm_get_connections(self, mock_get_uri):
mock_get_uri.return_value = "scheme://user:pass@host:100"
def test_aws_ssm_get_connections(self, mock_get_value):
mock_get_value.return_value = "scheme://user:pass@host:100"
conn_list = SystemsManagerParameterStoreBackend().get_connections("fake_conn")
conn = conn_list[0]
assert conn.host == 'host'
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/google/cloud/secrets/test_secret_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def test_get_conn_uri(self, connections_prefix, mock_client_callable, mock_get_c
mock_client.secret_version_path.assert_called_once_with(PROJECT_ID, secret_id, "latest")

@mock.patch(MODULE_NAME + ".get_credentials_and_project_id")
@mock.patch(MODULE_NAME + ".CloudSecretManagerBackend.get_conn_uri")
def test_get_connections(self, mock_get_uri, mock_get_creds):
@mock.patch(MODULE_NAME + ".CloudSecretManagerBackend.get_conn_value")
def test_get_connections(self, mock_get_value, mock_get_creds):
mock_get_creds.return_value = CREDENTIALS, PROJECT_ID
mock_get_uri.return_value = CONN_URI
mock_get_value.return_value = CONN_URI
conns = CloudSecretManagerBackend().get_connections(conn_id=CONN_ID)
assert isinstance(conns, list)
assert isinstance(conns[0], Connection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@


class TestAzureKeyVaultBackend(TestCase):
@mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.get_conn_uri')
def test_get_connections(self, mock_get_uri):
mock_get_uri.return_value = 'scheme://user:pass@host:100'
@mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.get_conn_value')
def test_get_connections(self, mock_get_value):
mock_get_value.return_value = 'scheme://user:pass@host:100'
conn_list = AzureKeyVaultBackend().get_connections('fake_conn')
conn = conn_list[0]
assert conn.host == 'host'
Expand Down

0 comments on commit 7ab45d4

Please sign in to comment.
  翻译: