Add type annotations for AWS operators and hooks (#11434)
Co-authored-by: Tomek Urbaszek <turbaszek@gmail.com>
potix2 and turbaszek authored Oct 16, 2020
1 parent 3c10ca6 commit 0823d46
Showing 15 changed files with 124 additions and 108 deletions.
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
@@ -29,6 +29,7 @@
from typing import Any, Dict, Optional, Tuple, Union

import boto3
from botocore.credentials import ReadOnlyCredentials
from botocore.config import Config
from cached_property import cached_property

@@ -393,7 +394,7 @@ def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
session, _ = self._get_credentials(region_name)
return session

def get_credentials(self, region_name: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
"""
Get the underlying `botocore.Credentials` object.
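
With this change, get_credentials returns botocore's ReadOnlyCredentials namedtuple rather than a plain tuple, so callers read named fields instead of unpacking positionally. A minimal usage sketch; the connection id, client type and region below are illustrative, not part of this commit:

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

# Illustrative only: assumes an 'aws_default' connection is configured.
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='s3')
credentials = hook.get_credentials(region_name='eu-west-1')
# ReadOnlyCredentials is a namedtuple, so fields are read by name
# rather than by tuple position.
print(credentials.access_key, credentials.secret_key, credentials.token)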
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -25,7 +25,7 @@
from copy import copy
from os.path import getsize
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, IO
from uuid import uuid4

from airflow.models import BaseOperator
@@ -34,11 +34,11 @@
from airflow.utils.decorators import apply_defaults


def _convert_item_to_json_bytes(item):
def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes:
return (json.dumps(item) + '\n').encode('utf-8')


def _upload_file_to_s3(file_obj, bucket_name, s3_key_prefix):
def _upload_file_to_s3(file_obj: IO, bucket_name: str, s3_key_prefix: str) -> None:
s3_client = S3Hook().get_conn()
file_obj.seek(0)
s3_client.upload_file(
@@ -102,7 +102,7 @@ def __init__(
s3_key_prefix: str = '',
process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.file_size = file_size
self.process_func = process_func
@@ -111,7 +111,7 @@ def __init__(
self.s3_bucket_name = s3_bucket_name
self.s3_key_prefix = s3_key_prefix

def execute(self, context):
def execute(self, context) -> None:
table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name)
scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
err = None
@@ -126,7 +126,7 @@ def execute(self, context):
_upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix)
f.close()

def _scan_dynamodb_and_upload_to_s3(self, temp_file, scan_kwargs, table):
def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
while True:
response = table.scan(**scan_kwargs)
items = response['Items']
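
process_func is now typed as Callable[[Dict[str, Any]], bytes], so a custom serializer passed to DynamoDBToS3Operator can be checked statically. A sketch of one such callable; the name and output format are made up for illustration:

import json
from typing import Any, Dict

def compact_json_bytes(item: Dict[str, Any]) -> bytes:
    # Satisfies the annotated contract: one DynamoDB item in, bytes out.
    return json.dumps(item, separators=(',', ':')).encode('utf-8') + b'\n'

A checker such as mypy can now flag a process_func whose parameters or return type do not match this signature.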
28 changes: 14 additions & 14 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -19,7 +19,7 @@
This module contains Google Cloud Storage to S3 operator.
"""
import warnings
from typing import Iterable, Optional, Sequence, Union, Dict
from typing import Iterable, Optional, Sequence, Union, Dict, List, cast

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -100,21 +100,21 @@ class GCSToS3Operator(BaseOperator):
def __init__(
self,
*, # pylint: disable=too-many-arguments
bucket,
prefix=None,
delimiter=None,
gcp_conn_id='google_cloud_default',
google_cloud_storage_conn_id=None,
delegate_to=None,
dest_aws_conn_id=None,
dest_s3_key=None,
dest_verify=None,
replace=False,
bucket: str,
prefix: Optional[str] = None,
delimiter: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
google_cloud_storage_conn_id: Optional[str] = None,
delegate_to: Optional[str] = None,
dest_aws_conn_id: str = 'aws_default',
dest_s3_key: str,
dest_verify: Optional[Union[str, bool]] = None,
replace: bool = False,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
dest_s3_extra_args: Optional[Dict] = None,
s3_acl_policy: Optional[str] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)

if google_cloud_storage_conn_id:
@@ -139,7 +139,7 @@ def __init__(
self.dest_s3_extra_args = dest_s3_extra_args or {}
self.s3_acl_policy = s3_acl_policy

def execute(self, context):
def execute(self, context) -> List[str]:
# list all files in an Google Cloud Storage bucket
hook = GCSHook(
google_cloud_storage_conn_id=self.gcp_conn_id,
@@ -183,7 +183,7 @@ def execute(self, context):
self.log.info("Saving file to %s", dest_key)

s3_hook.load_bytes(
file_bytes, key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
cast(bytes, file_bytes), key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy
)

self.log.info("All done, uploaded %d files to S3", len(files))
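
The cast(bytes, file_bytes) call only affects static type checking; it performs no runtime conversion. A small stand-alone sketch of that behaviour, with a placeholder payload standing in for the downloaded object:

from typing import cast

file_bytes = b'example payload'   # stand-in for the downloaded object
data = cast(bytes, file_bytes)    # no-op at runtime, hint for the type checker
assert data is file_bytes         # same object, no copy or conversion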
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
@@ -72,17 +72,17 @@ class GlacierToGCSOperator(BaseOperator):
def __init__(
self,
*,
aws_conn_id="aws_default",
gcp_conn_id="google_cloud_default",
aws_conn_id: str = "aws_default",
gcp_conn_id: str = "google_cloud_default",
vault_name: str,
bucket_name: str,
object_name: str,
gzip: bool,
chunk_size=1024,
delegate_to=None,
chunk_size: int = 1024,
delegate_to: Optional[str] = None,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.gcp_conn_id = gcp_conn_id
@@ -94,7 +94,7 @@ def __init__(
self.delegate_to = delegate_to
self.impersonation_chain = google_impersonation_chain

def execute(self, context):
def execute(self, context) -> str:
glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id)
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
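
With the constructor now annotated, passing the wrong type (for example a string chunk_size) becomes visible to a type checker. A hedged instantiation sketch using only the parameters shown above; the task id, vault, bucket and object names are placeholders:

from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator

transfer = GlacierToGCSOperator(
    task_id='glacier_to_gcs_example',   # placeholder task id
    vault_name='example-vault',         # placeholder Glacier vault
    bucket_name='example-gcs-bucket',   # placeholder GCS bucket
    object_name='backups/archive.gz',   # placeholder destination object
    gzip=True,
    chunk_size=1024,                    # int, as annotated
    aws_conn_id='aws_default',
    gcp_conn_id='google_cloud_default',
)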
44 changes: 22 additions & 22 deletions airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -23,8 +23,8 @@
import sys
from typing import Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.models.xcom import MAX_XCOM_SIZE
from airflow.models import BaseOperator, TaskInstance
from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook
from airflow.utils.decorators import apply_defaults
@@ -98,20 +98,20 @@ class GoogleApiToS3Operator(BaseOperator):
def __init__(
self,
*,
google_api_service_name,
google_api_service_version,
google_api_endpoint_path,
google_api_endpoint_params,
s3_destination_key,
google_api_response_via_xcom=None,
google_api_endpoint_params_via_xcom=None,
google_api_endpoint_params_via_xcom_task_ids=None,
google_api_pagination=False,
google_api_num_retries=0,
s3_overwrite=False,
gcp_conn_id='google_cloud_default',
delegate_to=None,
aws_conn_id='aws_default',
google_api_service_name: str,
google_api_service_version: str,
google_api_endpoint_path: str,
google_api_endpoint_params: dict,
s3_destination_key: str,
google_api_response_via_xcom: Optional[str] = None,
google_api_endpoint_params_via_xcom: Optional[str] = None,
google_api_endpoint_params_via_xcom_task_ids: Optional[str] = None,
google_api_pagination: bool = False,
google_api_num_retries: int = 0,
s3_overwrite: bool = False,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
aws_conn_id: str = 'aws_default',
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
@@ -132,7 +132,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.google_impersonation_chain = google_impersonation_chain

def execute(self, context):
def execute(self, context) -> None:
"""
Transfers Google APIs json data to S3.
@@ -151,7 +151,7 @@ def execute(self, context):
if self.google_api_response_via_xcom:
self._expose_google_api_response_via_xcom(context['task_instance'], data)

def _retrieve_data_from_google_api(self):
def _retrieve_data_from_google_api(self) -> dict:
google_discovery_api_hook = GoogleDiscoveryApiHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -167,21 +167,21 @@ def _retrieve_data_from_google_api(self):
)
return google_api_response

def _load_data_to_s3(self, data):
def _load_data_to_s3(self, data: dict) -> None:
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
s3_hook.load_string(
string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite
)

def _update_google_api_endpoint_params_via_xcom(self, task_instance):
def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None:
google_api_endpoint_params = task_instance.xcom_pull(
task_ids=self.google_api_endpoint_params_via_xcom_task_ids,
key=self.google_api_endpoint_params_via_xcom,
)
self.google_api_endpoint_params.update(google_api_endpoint_params)

def _expose_google_api_response_via_xcom(self, task_instance, data):
def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None:
if sys.getsizeof(data) < MAX_XCOM_SIZE:
task_instance.xcom_push(key=self.google_api_response_via_xcom, value=data)
task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
else:
raise RuntimeError('The size of the downloaded data is too large to push to XCom!')
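
The `or XCOM_RETURN_KEY` fallback appears to exist mainly to satisfy the type checker, since google_api_response_via_xcom is Optional[str] while xcom_push expects a str key; execute only takes this path when a key is configured. A hedged sketch of a downstream consumer pulling the pushed response; the task id and key are illustrative, not part of this commit:

# Illustrative downstream callable, e.g. for a PythonOperator.
def read_api_response(**context):
    response = context['ti'].xcom_pull(
        task_ids='google_api_to_s3_example',  # placeholder upstream task id
        key='google_api_response',            # value given as google_api_response_via_xcom
    )
    print(response)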
23 changes: 12 additions & 11 deletions airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -21,6 +21,7 @@
"""

import json
from typing import Optional, Callable

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook
@@ -64,18 +65,18 @@ class HiveToDynamoDBOperator(BaseOperator):
def __init__( # pylint: disable=too-many-arguments
self,
*,
sql,
table_name,
table_keys,
pre_process=None,
pre_process_args=None,
pre_process_kwargs=None,
region_name=None,
schema='default',
hiveserver2_conn_id='hiveserver2_default',
aws_conn_id='aws_default',
sql: str,
table_name: str,
table_keys: list,
pre_process: Optional[Callable] = None,
pre_process_args: Optional[list] = None,
pre_process_kwargs: Optional[list] = None,
region_name: Optional[str] = None,
schema: str = 'default',
hiveserver2_conn_id: str = 'hiveserver2_default',
aws_conn_id: str = 'aws_default',
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.sql = sql
self.table_name = table_name
20 changes: 10 additions & 10 deletions airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
@@ -57,16 +57,16 @@ class ImapAttachmentToS3Operator(BaseOperator):
def __init__(
self,
*,
imap_attachment_name,
s3_key,
imap_check_regex=False,
imap_mail_folder='INBOX',
imap_mail_filter='All',
s3_overwrite=False,
imap_conn_id='imap_default',
s3_conn_id='aws_default',
imap_attachment_name: str,
s3_key: str,
imap_check_regex: bool = False,
imap_mail_folder: str = 'INBOX',
imap_mail_filter: str = 'All',
s3_overwrite: bool = False,
imap_conn_id: str = 'imap_default',
s3_conn_id: str = 'aws_default',
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.imap_attachment_name = imap_attachment_name
self.s3_key = s3_key
@@ -77,7 +77,7 @@ def __init__(
self.imap_conn_id = imap_conn_id
self.s3_conn_id = s3_conn_id

def execute(self, context):
def execute(self, context) -> None:
"""
This function executes the transfer from the email server (via imap) into s3.