Skip to content

Commit

Permalink
Add support for impersonation in GCP hooks (#9915)
Browse files Browse the repository at this point in the history
Co-authored-by: Kamil Olszewski <kamil.olszewski@polidea.com>
  • Loading branch information
olchas and Kamil Olszewski authored Jul 21, 2020
1 parent 4eddce2 commit 5eacc16
Show file tree
Hide file tree
Showing 49 changed files with 791 additions and 161 deletions.
13 changes: 10 additions & 3 deletions airflow/providers/google/cloud/hooks/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ class CloudAutoMLHook(GoogleBaseHook):
"""

def __init__(
self, gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None
):
super().__init__(gcp_conn_id, delegate_to)
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self._client = None # type: Optional[AutoMlClient]

@staticmethod
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
def __init__(self,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
use_legacy_sql: bool = True,
location: Optional[str] = None,
bigquery_conn_id: Optional[str] = None,
Expand All @@ -73,7 +74,10 @@ def __init__(self,
"the gcp_conn_id parameter.", DeprecationWarning, stacklevel=2)
gcp_conn_id = bigquery_conn_id
super().__init__(
gcp_conn_id=gcp_conn_id, delegate_to=delegate_to)
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self.use_legacy_sql = use_legacy_sql
self.location = location
self.running_job_id = None # type: Optional[str]
Expand Down
11 changes: 9 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery_dts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,16 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
_conn = None # type: Optional[Resource]

def __init__(
self, gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(gcp_conn_id=gcp_conn_id, delegate_to=delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)

@staticmethod
def _disable_auto_scheduling(config: Union[dict, TransferConfig]) -> TransferConfig:
Expand Down
15 changes: 12 additions & 3 deletions airflow/providers/google/cloud/hooks/bigtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""
This module contains a Google Cloud Bigtable Hook.
"""
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Sequence, Union

from google.cloud.bigtable import Client
from google.cloud.bigtable.cluster import Cluster
Expand All @@ -39,8 +39,17 @@ class BigtableHook(GoogleBaseHook):
"""

# pylint: disable=too-many-arguments
def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None) -> None:
super().__init__(gcp_conn_id, delegate_to)
def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self._client = None

def _get_client(self, project_id: str):
Expand Down
25 changes: 20 additions & 5 deletions airflow/providers/google/cloud/hooks/cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Hook for Google Cloud Build service"""

import time
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence, Union

from googleapiclient.discovery import build

Expand All @@ -41,10 +41,19 @@ class CloudBuildHook(GoogleBaseHook):
:type api_version: str
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant the
Service Account Token Creator IAM role to the directly preceding identity, with the first
account from the list granting this role to the originating account.
:type impersonation_chain: Union[str, Sequence[str]]
"""

_conn = None # type: Optional[Any]
Expand All @@ -53,9 +62,15 @@ def __init__(
self,
api_version: str = "v1",
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(gcp_conn_id, delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)

self.api_version = api_version

def get_conn(self):
Expand Down
26 changes: 22 additions & 4 deletions airflow/providers/google/cloud/hooks/cloud_memorystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,32 @@ class CloudMemorystoreHook(GoogleBaseHook):
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant the
Service Account Token Creator IAM role to the directly preceding identity, with the first
account from the list granting this role to the originating account.
:type impersonation_chain: Union[str, Sequence[str]]
"""

def __init__(self, gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None):
super().__init__(gcp_conn_id, delegate_to)
def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self._client = None # type: Optional[CloudRedisClient]

def get_conn(self,):
Expand Down
13 changes: 9 additions & 4 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import time
import uuid
from subprocess import PIPE, Popen
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from urllib.parse import quote_plus

import requests
Expand Down Expand Up @@ -80,10 +80,15 @@ class CloudSQLHook(GoogleBaseHook):
def __init__(
self,
api_version: str,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(gcp_conn_id, delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self.api_version = api_version
self._conn = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import warnings
from copy import deepcopy
from datetime import timedelta
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Sequence, Set, Union

from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
Expand Down Expand Up @@ -129,10 +129,15 @@ class CloudDataTransferServiceHook(GoogleBaseHook):
def __init__(
self,
api_version: str = 'v1',
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(gcp_conn_id, delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self.api_version = api_version
self._conn = None

Expand Down
11 changes: 8 additions & 3 deletions airflow/providers/google/cloud/hooks/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""

import time
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence, Union

from googleapiclient.discovery import build

Expand Down Expand Up @@ -54,9 +54,14 @@ def __init__(
self,
api_version: str = 'v1',
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(gcp_conn_id, delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self.api_version = api_version

def get_conn(self):
Expand Down
26 changes: 22 additions & 4 deletions airflow/providers/google/cloud/hooks/datacatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,32 @@ class CloudDataCatalogHook(GoogleBaseHook):
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant the
Service Account Token Creator IAM role to the directly preceding identity, with the first
account from the list granting this role to the originating account.
:type impersonation_chain: Union[str, Sequence[str]]
"""

def __init__(self, gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None) -> None:
super().__init__(gcp_conn_id, delegate_to)
def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self._client: Optional[DataCatalogClient] = None

def get_conn(self) -> DataCatalogClient:
Expand Down
11 changes: 8 additions & 3 deletions airflow/providers/google/cloud/hooks/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import warnings
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, cast

from googleapiclient.discovery import build

Expand Down Expand Up @@ -413,12 +413,17 @@ class DataflowHook(GoogleBaseHook):

def __init__(
self,
gcp_conn_id: str = 'google_cloud_default',
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
poll_sleep: int = 10
) -> None:
self.poll_sleep = poll_sleep
super().__init__(gcp_conn_id, delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)

def get_conn(self):
"""
Expand Down
9 changes: 7 additions & 2 deletions airflow/providers/google/cloud/hooks/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import json
import os
from time import monotonic, sleep
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from urllib.parse import quote, urlencode

import google.auth
Expand Down Expand Up @@ -64,8 +64,13 @@ def __init__(
api_version: str = "v1beta1",
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
super().__init__(gcp_conn_id, delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self.api_version = api_version

def wait_for_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
11 changes: 8 additions & 3 deletions airflow/providers/google/cloud/hooks/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import time
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from googleapiclient.discovery import build

Expand All @@ -42,8 +42,9 @@ class DatastoreHook(GoogleBaseHook):

def __init__(
self,
gcp_conn_id: str = 'google_cloud_default',
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
api_version: str = 'v1',
datastore_conn_id: Optional[str] = None
) -> None:
Expand All @@ -52,7 +53,11 @@ def __init__(
"The datastore_conn_id parameter has been deprecated. You should pass "
"the gcp_conn_id parameter.", DeprecationWarning, stacklevel=2)
gcp_conn_id = datastore_conn_id
super().__init__(gcp_conn_id=gcp_conn_id, delegate_to=delegate_to)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
impersonation_chain=impersonation_chain,
)
self.connection = None
self.api_version = api_version

Expand Down
Loading

0 comments on commit 5eacc16

Please sign in to comment.
  翻译: