Strict type checking for provider google cloud (#11548)
mlgruby committed Oct 16, 2020
1 parent df75610 commit 8865d14
Showing 18 changed files with 60 additions and 54 deletions.
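For context: "strict type checking" here means the provider's functions gain explicit parameter and return annotations so mypy can check them. By default, mypy skips the body of any unannotated function, and --disallow-untyped-defs rejects such functions outright; adding annotations (even just a trailing -> None on __init__) opts the function into checking. A minimal sketch of the effect, with hypothetical toy functions (the exact mypy flags used for this provider live in Airflow's repo config and are not part of this diff):

    def before(x):  # unannotated: mypy skips this body, so the str/int bug goes unreported
        return x + "1"

    def after(x: int) -> int:  # annotated: the body is now type-checked
        return x + 1  # mypy would flag `return x + "1"` here as an error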
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/secrets/secret_manager.py
@@ -85,7 +85,7 @@ def __init__(
         project_id: Optional[str] = None,
         sep: str = "-",
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.connections_prefix = connections_prefix
         self.variables_prefix = variables_prefix
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/sensors/bigquery.py
@@ -85,7 +85,7 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         table_uri = '{0}:{1}.{2}'.format(self.project_id, self.dataset_id, self.table_id)
         self.log.info('Sensor checks existence of table: %s', table_uri)
         hook = BigQueryHook(
@@ -163,7 +163,7 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         table_uri = '{0}:{1}.{2}'.format(self.project_id, self.dataset_id, self.table_id)
         self.log.info('Sensor checks existence of partition: "%s" in table: %s', self.partition_id, table_uri)
         hook = BigQueryHook(
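The recurring change across the sensor files annotates poke as taking the task-instance context dict and returning a bool; the base sensor polls poke until it reports success. A simplified sketch of that contract (hypothetical stand-in; the real BaseSensorOperator also handles timeout, soft_fail, and reschedule mode):

    import time

    class MiniSensor:
        """Simplified stand-in for BaseSensorOperator's poke loop."""

        poke_interval: float = 60.0

        def poke(self, context: dict) -> bool:
            # subclasses override this; True means the condition is met
            raise NotImplementedError

        def execute(self, context: dict) -> None:
            # keep probing until the condition holds
            while not self.poke(context):
                time.sleep(self.poke_interval)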
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/sensors/bigquery_dts.py
@@ -91,7 +91,7 @@ def __init__(
         metadata: Optional[Sequence[Tuple[str, str]]] = None,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.run_id = run_id
         self.transfer_config_id = transfer_config_id
@@ -105,7 +105,7 @@ def __init__(
         self.gcp_cloud_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         hook = BiqQueryDataTransferServiceHook(
             gcp_conn_id=self.gcp_cloud_conn_id,
             impersonation_chain=self.impersonation_chain,
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/sensors/bigtable.py
@@ -86,7 +86,7 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         super().__init__(**kwargs)

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         hook = BigtableHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py
@@ -84,7 +84,7 @@ def __init__(
         self.gcp_cloud_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         hook = CloudDataTransferServiceHook(
             gcp_conn_id=self.gcp_cloud_conn_id,
             impersonation_chain=self.impersonation_chain,
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/sensors/dataproc.py
@@ -62,7 +62,7 @@ def __init__(
         self.dataproc_job_id = dataproc_job_id
         self.location = location

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
         job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
         state = job.status.state
20 changes: 11 additions & 9 deletions airflow/providers/google/cloud/sensors/gcs.py
@@ -82,7 +82,7 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         self.log.info('Sensor checks existence of : %s, %s', self.bucket, self.object)
         hook = GCSHook(
             google_cloud_storage_conn_id=self.google_cloud_conn_id,
@@ -159,7 +159,7 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         self.log.info('Sensor checks existence of : %s, %s', self.bucket, self.object)
         hook = GCSHook(
             google_cloud_storage_conn_id=self.google_cloud_conn_id,
@@ -222,10 +222,10 @@ def __init__(
         self.prefix = prefix
         self.google_cloud_conn_id = google_cloud_conn_id
         self.delegate_to = delegate_to
-        self._matches = []  # type: List[str]
+        self._matches: List[str] = []
         self.impersonation_chain = impersonation_chain

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         self.log.info('Sensor checks existence of objects: %s, %s', self.bucket, self.prefix)
         hook = GCSHook(
             google_cloud_storage_conn_id=self.google_cloud_conn_id,
@@ -235,7 +235,7 @@ def poke(self, context):
         self._matches = hook.list(self.bucket, prefix=self.prefix)
         return bool(self._matches)

-    def execute(self, context):
+    def execute(self, context: dict) -> List[str]:
         """Overridden to allow matches to be passed"""
         super().execute(context)
         return self._matches
@@ -332,9 +332,9 @@ def __init__(
         self.delegate_to = delegate_to
         self.last_activity_time = None
         self.impersonation_chain = impersonation_chain
-        self.hook = None
+        self.hook: Optional[GCSHook] = None

-    def _get_gcs_hook(self):
+    def _get_gcs_hook(self) -> Optional[GCSHook]:
         if not self.hook:
             self.hook = GCSHook(
                 gcp_conn_id=self.google_cloud_conn_id,
@@ -416,5 +416,7 @@ def is_bucket_updated(self, current_objects: Set[str]) -> bool:
                 return False
         return False

-    def poke(self, context):
-        return self.is_bucket_updated(set(self._get_gcs_hook().list(self.bucket, prefix=self.prefix)))
+    def poke(self, context: dict) -> bool:
+        return self.is_bucket_updated(
+            set(self._get_gcs_hook().list(self.bucket, prefix=self.prefix))  # type: ignore[union-attr]
+        )
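The # type: ignore[union-attr] above is a direct consequence of annotating self.hook as Optional[GCSHook]: _get_gcs_hook() is declared to return Optional[GCSHook], so mypy cannot prove that .list(...) is safe even though the lazy initializer always produces a hook. A sketch of the pattern with a hypothetical stand-in class, including the alternative that avoids the ignore:

    from typing import List, Optional

    class Client:  # hypothetical stand-in for GCSHook
        def list(self, bucket: str, prefix: str = '') -> List[str]:
            return []

    class LazySensor:
        def __init__(self) -> None:
            self.hook: Optional[Client] = None

        def _get_hook(self) -> Client:
            # declaring the narrowed return type Client (not Optional[Client])
            # lets callers use the result without a type: ignore comment
            if self.hook is None:
                self.hook = Client()
            return self.hook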
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/sensors/pubsub.py
@@ -156,12 +156,12 @@ def __init__(

         self._return_value = None

-    def execute(self, context):
+    def execute(self, context: dict):
         """Overridden to allow messages to be passed"""
         super().execute(context)
         return self._return_value

-    def poke(self, context):
+    def poke(self, context: dict) -> bool:
         hook = PubSubHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
@@ -127,7 +127,7 @@ def __init__(
         self.location = location
         self.impersonation_chain = impersonation_chain

-    def execute(self, context):
+    def execute(self, context) -> None:
         self.log.info(
             'Executing copy of %s into: %s',
             self.source_project_dataset_tables,
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
@@ -110,7 +110,7 @@ def __init__(
         self.gzip = gzip
         self.impersonation_chain = impersonation_chain

-    def execute(self, context: Dict):
+    def execute(self, context: dict):
         service = FacebookAdsReportingHook(
             facebook_conn_id=self.facebook_conn_id, api_version=self.api_version
         )
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/transfers/mssql_to_gcs.py
@@ -20,6 +20,7 @@
 """

 import decimal
+from typing import Dict

 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
@@ -71,7 +72,7 @@ def query(self):
         cursor.execute(self.sql)
         return cursor

-    def field_to_bigquery(self, field):
+    def field_to_bigquery(self, field) -> Dict[str, str]:
         return {
             'name': field[0].replace(" ", "_"),
             'type': self.type_map.get(field[1], "STRING"),
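The Dict[str, str] return type documents that each field_to_bigquery call yields one BigQuery schema field; BaseSQLToGCSOperator collects these into the schema file it uploads alongside the data. A hedged example of the shape, using made-up values (the 'mode' key is elided by the truncated hunk above and is assumed here):

    field = ('customer name', 1)  # hypothetical (column name, DB-API type code) pair
    schema_entry = {
        'name': field[0].replace(' ', '_'),  # -> 'customer_name'
        'type': 'STRING',  # i.e. type_map.get(field[1], "STRING")
        'mode': 'NULLABLE',  # assumption: not visible in the hunk above
    }
    print(schema_entry)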
5 changes: 3 additions & 2 deletions airflow/providers/google/cloud/transfers/mysql_to_gcs.py
@@ -23,6 +23,7 @@
 import calendar
 from datetime import date, datetime, timedelta
 from decimal import Decimal
+from typing import Dict

 from MySQLdb.constants import FIELD_TYPE

@@ -84,7 +85,7 @@ def query(self):
         cursor.execute(self.sql)
         return cursor

-    def field_to_bigquery(self, field):
+    def field_to_bigquery(self, field) -> Dict[str, str]:
         field_type = self.type_map.get(field[1], "STRING")
         # Always allow TIMESTAMP to be nullable. MySQLdb returns None types
         # for required fields because some MySQL timestamps can't be
@@ -96,7 +97,7 @@ def field_to_bigquery(self, field):
             'mode': field_mode,
         }

-    def convert_type(self, value, schema_type):
+    def convert_type(self, value, schema_type: str):
         """
         Takes a value from MySQLdb, and converts it to a value that's safe for
         JSON/Google Cloud Storage/BigQuery.
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/transfers/postgres_to_gcs.py
@@ -23,6 +23,7 @@
 import json
 import time
 from decimal import Decimal
+from typing import Dict

 import pendulum

@@ -73,7 +74,7 @@ def query(self):
         cursor.execute(self.sql, self.parameters)
         return cursor

-    def field_to_bigquery(self, field):
+    def field_to_bigquery(self, field) -> Dict[str, str]:
         return {
             'name': field[0],
             'type': self.type_map.get(field[1], "STRING"),
9 changes: 5 additions & 4 deletions airflow/providers/google/cloud/transfers/presto_to_gcs.py
@@ -15,8 +15,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, List, Tuple
+from typing import Any, List, Tuple, Dict

+from prestodb.client import PrestoResult
 from prestodb.dbapi import Cursor as PrestoCursor

 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
@@ -76,7 +77,7 @@ def close(self) -> None:
         """Close the cursor now"""
         self.cursor.close()

-    def execute(self, *args, **kwargs):
+    def execute(self, *args, **kwargs) -> PrestoResult:
         """Prepare and execute a database operation (query or command)."""
         self.initialized = False
         self.rows = []
@@ -109,7 +110,7 @@ def fetchone(self) -> Any:
             return self.rows.pop(0)
         return self.cursor.fetchone()

-    def fetchmany(self, size=None) -> List[Any]:
+    def fetchmany(self, size=None) -> list:
         """
         Fetch the next set of rows of a query result, returning a sequence of sequences
         (e.g. a list of tuples). An empty sequence is returned when no more rows are available.
@@ -194,7 +195,7 @@ def query(self):
         cursor.execute(self.sql)
         return _PrestoToGCSPrestoCursorAdapter(cursor)

-    def field_to_bigquery(self, field):
+    def field_to_bigquery(self, field) -> Dict[str, str]:
         """Convert presto field type to BigQuery field type."""
         clear_field_type = field[1].upper()
         # remove type argument e.g. DECIMAL(2, 10) => DECIMAL
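The adapter's fetchmany (its return type loosened from List[Any] to plain list here) follows the DB-API 2.0 (PEP 249) contract: return up to size rows, and an empty sequence once the result set is exhausted. A self-contained sketch of how such a method can be layered on fetchone, simplified from the adapter in this file:

    from typing import Any, List, Optional

    class CursorAdapter:
        """Minimal sketch; assumes fetchone() returns None when no rows remain."""

        def __init__(self, rows: List[Any]) -> None:
            self._rows = list(rows)

        def fetchone(self) -> Optional[Any]:
            return self._rows.pop(0) if self._rows else None

        def fetchmany(self, size: Optional[int] = None) -> list:
            # PEP 249: size defaults to the cursor's arraysize (assumed 1 here)
            batch: List[Any] = []
            for _ in range(size if size is not None else 1):
                row = self.fetchone()
                if row is None:
                    break
                batch.append(row)
            return batch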
32 changes: 16 additions & 16 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -104,23 +104,23 @@ class BaseSQLToGCSOperator(BaseOperator):
     def __init__(
         self,
         *,  # pylint: disable=too-many-arguments
-        sql,
-        bucket,
-        filename,
-        schema_filename=None,
-        approx_max_file_size_bytes=1900000000,
-        export_format='json',
-        field_delimiter=',',
-        null_marker=None,
-        gzip=False,
-        schema=None,
-        parameters=None,
-        gcp_conn_id='google_cloud_default',
-        google_cloud_storage_conn_id=None,
-        delegate_to=None,
+        sql: str,
+        bucket: str,
+        filename: str,
+        schema_filename: Optional[str] = None,
+        approx_max_file_size_bytes: int = 1900000000,
+        export_format: str = 'json',
+        field_delimiter: str = ',',
+        null_marker: Optional[str] = None,
+        gzip: bool = False,
+        schema: Optional[Union[str, list]] = None,
+        parameters: Optional[dict] = None,
+        gcp_conn_id: str = 'google_cloud_default',
+        google_cloud_storage_conn_id: Optional[str] = None,
+        delegate_to: Optional[str] = None,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)

         if google_cloud_storage_conn_id:
@@ -170,7 +170,7 @@ def execute(self, context):
         for tmp_file in files_to_upload:
             tmp_file['file_handle'].close()

-    def convert_types(self, schema, col_type_dict, row):
+    def convert_types(self, schema, col_type_dict, row) -> list:
         """Convert values from DBAPI to output-friendly formats."""
         return [self.convert_type(value, col_type_dict.get(name)) for name, value in zip(schema, row)]

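Annotating the constructor is what gives this change teeth: with -> None on __init__ and typed parameters, mypy checks both the constructor body and every call site, so passing a string where approx_max_file_size_bytes: int is expected becomes a reported error instead of a latent bug. A toy sketch of the same pattern (hypothetical class, not the Airflow operator):

    from typing import Optional

    class ExportJob:  # hypothetical stand-in for BaseSQLToGCSOperator
        def __init__(
            self,
            sql: str,
            bucket: str,
            schema_filename: Optional[str] = None,
            gzip: bool = False,
        ) -> None:
            self.sql = sql
            self.bucket = bucket
            self.schema_filename = schema_filename
            self.gzip = gzip

    job = ExportJob(sql='SELECT 1', bucket='my-bucket', gzip=True)  # ok
    # ExportJob(sql='SELECT 1', bucket='my-bucket', gzip='yes')  # mypy: bool expected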
14 changes: 7 additions & 7 deletions airflow/providers/google/cloud/utils/credentials_provider.py
@@ -23,7 +23,7 @@
 import logging
 import tempfile
 from contextlib import ExitStack, contextmanager
-from typing import Collection, Dict, Optional, Sequence, Tuple, Union
+from typing import Collection, Dict, Optional, Sequence, Tuple, Union, Generator
 from urllib.parse import urlencode

 import google.auth
@@ -113,7 +113,7 @@ def provide_gcp_connection(
     key_file_path: Optional[str] = None,
     scopes: Optional[Sequence] = None,
     project_id: Optional[str] = None,
-):
+) -> Generator:
     """
     Context manager that provides a temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT`
     connection. It builds a new connection that includes the path to the provided service json,
@@ -140,7 +140,7 @@ def provide_gcp_conn_and_credentials(
     key_file_path: Optional[str] = None,
     scopes: Optional[Sequence] = None,
     project_id: Optional[str] = None,
-):
+) -> Generator:
     """
     Context manager that provides both:

@@ -212,7 +212,7 @@ def __init__(
         disable_logging: bool = False,
         target_principal: Optional[str] = None,
         delegates: Optional[Sequence[str]] = None,
-    ):
+    ) -> None:
         super().__init__()
         if key_path and keyfile_dict:
             raise AirflowException(
@@ -227,7 +227,7 @@ def __init__(
         self.target_principal = target_principal
         self.delegates = delegates

-    def get_credentials_and_project(self):
+    def get_credentials_and_project(self) -> Tuple[google.auth.credentials.Credentials, str]:
         """
         Get current credentials and project ID.

@@ -295,11 +295,11 @@ def _get_credentials_using_adc(self):
         credentials, project_id = google.auth.default(scopes=self.scopes)
         return credentials, project_id

-    def _log_info(self, *args, **kwargs):
+    def _log_info(self, *args, **kwargs) -> None:
         if not self.disable_logging:
             self.log.info(*args, **kwargs)

-    def _log_debug(self, *args, **kwargs):
+    def _log_debug(self, *args, **kwargs) -> None:
         if not self.disable_logging:
             self.log.debug(*args, **kwargs)

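The two provide_gcp_* functions are @contextmanager generators, which is why their return annotation is Generator. The commit uses the bare form; the fully parameterized spelling is Generator[YieldType, SendType, ReturnType]. A minimal sketch of the idiom (hypothetical function, not from this file):

    import os
    from contextlib import contextmanager
    from typing import Generator

    @contextmanager
    def provide_env(name: str, value: str) -> Generator[None, None, None]:
        # set a temporary environment variable for the duration of the with-block
        previous = os.environ.get(name)
        os.environ[name] = value
        try:
            yield
        finally:
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous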
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/utils/field_sanitizer.py
@@ -160,7 +160,7 @@ def _sanitize(self, dictionary, remaining_field_spec, current_path):
                 child,
             )

-    def sanitize(self, body):
+    def sanitize(self, body) -> None:
         """
         Sanitizes the body according to specification.
         """
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/utils/field_validator.py
@@ -421,7 +421,7 @@ def _validate_field(self, validation_spec, dictionary_to_validate, parent=None,
             )
         return True

-    def validate(self, body_to_validate):
+    def validate(self, body_to_validate: dict) -> None:
         """
         Validates if the body (dictionary) follows specification that the validator was
         instantiated with. Raises ValidationSpecificationException or
