Skip to content

Commit

Permalink
Fix mypy errors in Google Cloud provider (#20611)
Browse files Browse the repository at this point in the history
Part of #19891

Another attempt to clean up all mypy errors in the Google provider.
  • Loading branch information
potiuk committed Dec 31, 2021
1 parent 2d09202 commit a22d5bd
Show file tree
Hide file tree
Showing 28 changed files with 220 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pathlib import Path
from typing import Any, Dict

import yaml
from future.backports.urllib.parse import urlparse

from airflow import models
Expand Down Expand Up @@ -190,7 +191,7 @@
create_build_from_file = CloudBuildCreateBuildOperator(
task_id="create_build_from_file",
project_id=GCP_PROJECT_ID,
build=str(CURRENT_FOLDER.joinpath('example_cloud_build.yaml')),
build=yaml.safe_load((Path(CURRENT_FOLDER) / 'example_cloud_build.yaml').read_text()),
params={'name': 'Airflow'},
)
# [END howto_operator_gcp_create_build_from_yaml_body]
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/example_dags/example_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from google.api_core.retry import Retry
from google.cloud.tasks_v2.types import Queue
from google.protobuf import timestamp_pb2
from google.protobuf.field_mask_pb2 import FieldMask

from airflow import models
from airflow.models.baseoperator import chain
Expand Down Expand Up @@ -136,7 +137,7 @@
task_queue=Queue(stackdriver_logging_config=dict(sampling_ratio=1)),
location=LOCATION,
queue_name=QUEUE_ID,
update_mask={"paths": ["stackdriver_logging_config.sampling_ratio"]},
update_mask=FieldMask(paths=["stackdriver_logging_config.sampling_ratio"]),
task_id="update_queue",
)
# [END update_queue]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import os
from datetime import datetime

from google.protobuf.field_mask_pb2 import FieldMask

from airflow import DAG
from airflow.providers.google.cloud.operators.workflows import (
WorkflowsCancelExecutionOperator,
Expand Down Expand Up @@ -102,7 +104,7 @@
location=LOCATION,
project_id=PROJECT_ID,
workflow_id=WORKFLOW_ID,
update_mask={"paths": ["name", "description"]},
update_mask=FieldMask(paths=["name", "description"]),
)
# [END how_to_update_workflow]

Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/google/cloud/hooks/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,21 @@ def delete_operation(self, name: str) -> dict:

return resp

def poll_operation_until_done(self, name: str, polling_interval_in_seconds: int) -> Dict:
def poll_operation_until_done(self, name: str, polling_interval_in_seconds: float) -> Dict:
"""
Poll backup operation state until it's completed.
:param name: the name of the operation resource
:type name: str
:param polling_interval_in_seconds: The number of seconds to wait before calling another request.
:type polling_interval_in_seconds: int
:type polling_interval_in_seconds: float
:return: a resource operation instance.
:rtype: dict
"""
while True:
result = self.get_operation(name) # type: Dict
result: Dict = self.get_operation(name)

state = result['metadata']['common']['state'] # type: str
state: str = result['metadata']['common']['state']
if state == 'PROCESSING':
self.log.info(
'Operation is processing. Re-polling state in %s seconds', polling_interval_in_seconds
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/dlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def create_inspect_template(
self,
organization_id: Optional[str] = None,
project_id: Optional[str] = None,
inspect_template: Optional[Union[dict, InspectTemplate]] = None,
inspect_template: Optional[InspectTemplate] = None,
template_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
Expand Down
41 changes: 31 additions & 10 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from io import BytesIO
from os import path
from tempfile import NamedTemporaryFile
from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast
from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast, overload
from urllib.parse import urlparse

from google.api_core.exceptions import NotFound
Expand Down Expand Up @@ -273,6 +273,30 @@ def rewrite(
destination_bucket.name, # type: ignore[attr-defined]
)

@overload
def download(
self,
bucket_name: str,
object_name: str,
filename: None = None,
chunk_size: Optional[int] = None,
timeout: Optional[int] = DEFAULT_TIMEOUT,
num_max_attempts: Optional[int] = 1,
) -> bytes:
...

@overload
def download(
self,
bucket_name: str,
object_name: str,
filename: str,
chunk_size: Optional[int] = None,
timeout: Optional[int] = DEFAULT_TIMEOUT,
num_max_attempts: Optional[int] = 1,
) -> str:
...

def download(
self,
bucket_name: str,
Expand Down Expand Up @@ -366,15 +390,12 @@ def download_as_byte_array(
:type num_max_attempts: int
"""
# We do not pass filename, so will never receive string as response
return cast(
bytes,
self.download(
bucket_name=bucket_name,
object_name=object_name,
chunk_size=chunk_size,
timeout=timeout,
num_max_attempts=num_max_attempts,
),
return self.download(
bucket_name=bucket_name,
object_name=object_name,
chunk_size=chunk_size,
timeout=timeout,
num_max_attempts=num_max_attempts,
)

@_fallback_object_url_to_object_name_and_bucket_name()
Expand Down
24 changes: 12 additions & 12 deletions airflow/providers/google/cloud/hooks/stackdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def list_alert_policies(
order_by: Optional[str] = None,
page_size: Optional[int] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[float] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> Any:
"""
Expand Down Expand Up @@ -135,7 +135,7 @@ def _toggle_policy_status(
project_id: str = PROVIDE_PROJECT_ID,
filter_: Optional[str] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[float] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
):
client = self._get_policy_client()
Expand All @@ -157,7 +157,7 @@ def enable_alert_policies(
project_id: str = PROVIDE_PROJECT_ID,
filter_: Optional[str] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[float] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
"""
Expand Down Expand Up @@ -195,7 +195,7 @@ def disable_alert_policies(
project_id: str = PROVIDE_PROJECT_ID,
filter_: Optional[str] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[float] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
"""
Expand Down Expand Up @@ -233,7 +233,7 @@ def upsert_alert(
alerts: str,
project_id: str = PROVIDE_PROJECT_ID,
retry: Optional[str] = DEFAULT,
timeout: Optional[float] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
"""
Expand Down Expand Up @@ -334,7 +334,7 @@ def delete_alert_policy(
self,
name: str,
retry: Optional[str] = DEFAULT,
timeout: Optional[float] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
"""
Expand Down Expand Up @@ -370,7 +370,7 @@ def list_notification_channels(
order_by: Optional[str] = None,
page_size: Optional[int] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[str] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> Any:
"""
Expand Down Expand Up @@ -437,7 +437,7 @@ def _toggle_channel_status(
project_id: str = PROVIDE_PROJECT_ID,
filter_: Optional[str] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[str] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
client = self._get_channel_client()
Expand All @@ -461,7 +461,7 @@ def enable_notification_channels(
project_id: str = PROVIDE_PROJECT_ID,
filter_: Optional[str] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[str] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
"""
Expand Down Expand Up @@ -499,7 +499,7 @@ def disable_notification_channels(
project_id: str,
filter_: Optional[str] = None,
retry: Optional[str] = DEFAULT,
timeout: Optional[str] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
"""
Expand Down Expand Up @@ -537,7 +537,7 @@ def upsert_channel(
channels: str,
project_id: str,
retry: Optional[str] = DEFAULT,
timeout: Optional[float] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> dict:
"""
Expand Down Expand Up @@ -603,7 +603,7 @@ def delete_notification_channel(
self,
name: str,
retry: Optional[str] = DEFAULT,
timeout: Optional[str] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> None:
"""
Expand Down
17 changes: 7 additions & 10 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import uuid
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union

import attr
from google.api_core.exceptions import Conflict
Expand Down Expand Up @@ -620,7 +620,7 @@ def __init__(
sql: Union[str, Iterable],
destination_dataset_table: Optional[str] = None,
write_disposition: str = 'WRITE_EMPTY',
allow_large_results: Optional[bool] = False,
allow_large_results: bool = False,
flatten_results: Optional[bool] = None,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
Expand Down Expand Up @@ -694,7 +694,7 @@ def execute(self, context: 'Context'):
impersonation_chain=self.impersonation_chain,
)
if isinstance(self.sql, str):
job_id = self.hook.run_query(
job_id: Union[str, List[str]] = self.hook.run_query(
sql=self.sql,
destination_dataset_table=self.destination_dataset_table,
write_disposition=self.write_disposition,
Expand Down Expand Up @@ -1211,10 +1211,7 @@ def execute(self, context: 'Context') -> None:
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
schema_fields_bytes_or_string = gcs_hook.download(self.bucket, self.schema_object)
if hasattr(schema_fields_bytes_or_string, 'decode'):
schema_fields_bytes_or_string = cast(bytes, schema_fields_bytes_or_string).decode("utf-8")
schema_fields = json.loads(schema_fields_bytes_or_string)
schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
else:
schema_fields = self.schema_fields

Expand Down Expand Up @@ -2114,9 +2111,9 @@ def __init__(
self,
*,
schema_fields_updates: List[Dict[str, Any]],
include_policy_tags: Optional[bool] = False,
dataset_id: Optional[str] = None,
table_id: Optional[str] = None,
dataset_id: str,
table_id: str,
include_policy_tags: bool = False,
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
Expand Down
24 changes: 14 additions & 10 deletions airflow/providers/google/cloud/operators/cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class CloudBuildCreateBuildOperator(BaseOperator):
def __init__(
self,
*,
build: Optional[Union[Dict, Build, str]] = None,
build: Optional[Union[Dict, Build]] = None,
body: Optional[Dict] = None,
project_id: Optional[str] = None,
wait: bool = True,
Expand All @@ -171,28 +171,32 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
self.build = build
# Not template fields to keep original value
self.build_raw = build
self.body = body
self.project_id = project_id
self.wait = wait
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.body = body

if self.body and self.build:
raise AirflowException("Either build or body should be passed.")

if self.body:
if body and build:
raise AirflowException("You should not pass both build or body parameters. Both are set.")
if body is not None:
warnings.warn(
"The body parameter has been deprecated. You should pass body using the build parameter.",
DeprecationWarning,
stacklevel=4,
)
self.build = self.build_raw = self.body
actual_build = body
else:
if build is None:
raise AirflowException("You should pass one of the build or body parameters. Both are None")
actual_build = build

self.build = actual_build
# Not template fields to keep original value
self.build_raw = actual_build

def prepare_template(self) -> None:
# if no file is specified, skip
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.models import BaseOperator, Connection
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
from airflow.providers.mysql.hooks.mysql import MySqlHook
Expand Down Expand Up @@ -1044,7 +1044,7 @@ def __init__(
self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
self.autocommit = autocommit
self.parameters = parameters
self.gcp_connection = None
self.gcp_connection: Optional[Connection] = None

def _execute_query(
self, hook: CloudSQLDatabaseHook, database_hook: Union[PostgresHook, MySqlHook]
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def __init__(
poll_sleep: int = 10,
job_class: Optional[str] = None,
check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun,
multiple_jobs: Optional[bool] = None,
multiple_jobs: bool = False,
cancel_timeout: Optional[int] = 10 * 60,
wait_until_finished: Optional[bool] = None,
**kwargs,
Expand Down
Loading

0 comments on commit a22d5bd

Please sign in to comment.
  翻译: