Upgrade Mypy to 1.0 (#29468)
* Upgrade mypy to 0.991

Most of the changes are related to `x = None` attributes that now need explicit optional annotations.
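
For illustration only (the class and names below are hypothetical, not taken
from this diff): mypy treats a bare `x = None` class attribute as having type
`None`, so the fix applied throughout is an explicit optional annotation.

from __future__ import annotations


class Client:
    """Hypothetical stand-in for a real connection object."""


class SomeHook:
    # Before: `_conn = None` made mypy infer the attribute type as plain
    # `None`, so a later `self._conn = Client()` was rejected.
    # After: an explicit optional annotation keeps the attribute assignable.
    _conn: Client | None = None

    def get_conn(self) -> Client:
        if self._conn is None:
            self._conn = Client()
        return self._conn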

* Update airflow/dag_processing/processor.py

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

* Upgrade Mypy to 1.0

* Resolve newly found typing issues

* Work around empty function in protocol

Mypy does not seem to detect an empty default implementation in a
protocol. Supplying an explicit 'return None' is enough to tell Mypy the
method is present. This is better than needing to implement an empty
function everywhere.
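
A minimal sketch of the workaround (the protocol below is hypothetical, not
one of the classes touched by this commit):

from typing import Protocol


class Validatable(Protocol):
    def validate(self) -> None:
        """Default implementation: accept everything.

        With only a docstring in the body, mypy treats this method as
        abstract and expects every implementer to override it; the explicit
        `return` marks the default implementation as present.
        """
        return


class AlwaysValid(Validatable):
    pass  # no override needed, the protocol's default body is used


AlwaysValid().validate()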

* Improve typing declarations

* Fix path reference in update-common-sql-api-stubs

---------

Co-authored-by: Ash Berlin-Taylor <ash@apache.org>
Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>
3 people committed Feb 13, 2023
1 parent 5e6f8eb commit 41fade2
Showing 39 changed files with 103 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -160,7 +160,7 @@ repos:
entry: ./scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py
language: python
files: ^scripts/ci/pre_commit/pre_commit_update_common_sql_api\.py|^airflow/providers/common/sql/.*\.pyi?$
additional_dependencies: ['rich>=12.4.4', 'mypy==0.971', 'black==22.12.0', 'jinja2']
additional_dependencies: ['rich>=12.4.4', 'mypy==1.0.0', 'black==22.12.0', 'jinja2']
pass_filenames: false
require_serial: true
- id: update-black-version
3 changes: 1 addition & 2 deletions airflow/dag_processing/manager.py
@@ -141,8 +141,7 @@ def __init__(

def start(self) -> None:
"""Launch DagFileProcessorManager processor and start DAG parsing loop in manager."""
mp_start_method = self._get_multiprocessing_start_method()
context = multiprocessing.get_context(mp_start_method)
context = self._get_multiprocessing_context()
self._last_parsing_stat_received_at = time.monotonic()

self._parent_signal_conn, child_signal_conn = context.Pipe()
3 changes: 1 addition & 2 deletions airflow/dag_processing/processor.py
@@ -187,8 +187,7 @@ def _handle_dag_file_processing():

def start(self) -> None:
"""Launch the process and start processing the DAG."""
start_method = self._get_multiprocessing_start_method()
context = multiprocessing.get_context(start_method)
context = self._get_multiprocessing_context()

_parent_channel, _child_channel = context.Pipe(duplex=False)
process = context.Process(
2 changes: 1 addition & 1 deletion airflow/decorators/base.py
@@ -507,7 +507,7 @@ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -
return super()._get_unmap_kwargs(kwargs, strict=False)


class Task(Generic[FParams, FReturn]):
class Task(Protocol, Generic[FParams, FReturn]):
"""Declaration of a @task-decorated callable for type-checking.
An instance of this type inherits the call signature of the decorated
2 changes: 1 addition & 1 deletion airflow/executors/kubernetes_executor.py
@@ -69,7 +69,7 @@
class ResourceVersion:
"""Singleton for tracking resourceVersion from Kubernetes."""

_instance = None
_instance: ResourceVersion | None = None
resource_version: dict[str, str] = {}

def __new__(cls):
4 changes: 2 additions & 2 deletions airflow/hooks/base.py
@@ -90,11 +90,11 @@ def get_conn(self) -> Any:

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
...
return {}

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
...
return {}


class DiscoverableHook(Protocol):
2 changes: 1 addition & 1 deletion airflow/listeners/listener.py
@@ -30,7 +30,7 @@
log = logging.getLogger(__name__)


_listener_manager = None
_listener_manager: ListenerManager | None = None


class ListenerManager:
4 changes: 2 additions & 2 deletions airflow/models/baseoperator.py
@@ -30,7 +30,7 @@
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from inspect import signature
from types import FunctionType
from types import ClassMethodDescriptorType, FunctionType
from typing import (
TYPE_CHECKING,
Any,
@@ -169,7 +169,7 @@ def get_merged_defaults(
class _PartialDescriptor:
"""A descriptor that guards against ``.partial`` being called on Task objects."""

class_method = None
class_method: ClassMethodDescriptorType | None = None

def __get__(
self, obj: BaseOperator, cls: type[BaseOperator] | None = None
2 changes: 0 additions & 2 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -702,8 +702,6 @@ class CloudSQLDatabaseHook(BaseHook):
conn_type = "gcpcloudsqldb"
hook_name = "Google Cloud SQL Database"

_conn = None

def __init__(
self,
gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/functions.py
@@ -40,7 +40,7 @@ class CloudFunctionsHook(GoogleBaseHook):
keyword arguments rather than positional.
"""

_conn = None
_conn: build | None = None

def __init__(
self,
18 changes: 14 additions & 4 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -51,8 +51,18 @@
from airflow.utils import timezone
from airflow.version import version

try:
# Airflow 2.3 doesn't have this yet
from airflow.typing_compat import ParamSpec
except ImportError:
try:
from typing import ParamSpec # type: ignore[no-redef, attr-defined]
except ImportError:
from typing_extensions import ParamSpec

RT = TypeVar("RT")
T = TypeVar("T", bound=Callable)
FParams = ParamSpec("FParams")

# GCSHook has a method named 'list' (to junior devs: please don't do this), so
# we need to create an alias to prevent Mypy being confused.
@@ -76,9 +86,9 @@ def _fallback_object_url_to_object_name_and_bucket_name(
:return: Decorator
"""

def _wrapper(func: T):
def _wrapper(func: Callable[FParams, RT]) -> Callable[FParams, RT]:
@functools.wraps(func)
def _inner_wrapper(self: GCSHook, *args, **kwargs) -> RT:
def _inner_wrapper(self, *args, **kwargs) -> RT:
if args:
raise AirflowException(
"You must use keyword arguments in this methods rather than positional"
@@ -119,9 +129,9 @@ def _inner_wrapper(self: GCSHook, *args, **kwargs) -> RT:

return func(self, *args, **kwargs)

return cast(T, _inner_wrapper)
return cast(Callable[FParams, RT], _inner_wrapper)

return _wrapper
return cast(Callable[[T], T], _wrapper)


# A fake bucket to use in functions decorated by _fallback_object_url_to_object_name_and_bucket_name.
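
Aside, not part of the diff: a self-contained sketch of the ParamSpec pattern
used in the hunk above (the decorator and function names are illustrative).
ParamSpec lets the wrapper keep the wrapped callable's exact parameter list
instead of erasing it to a bare Callable.

from __future__ import annotations

import functools
from typing import Callable, TypeVar

try:
    from typing import ParamSpec  # Python 3.10+
except ImportError:
    from typing_extensions import ParamSpec

RT = TypeVar("RT")
FParams = ParamSpec("FParams")


def log_calls(func: Callable[FParams, RT]) -> Callable[FParams, RT]:
    """Illustrative decorator: callers keep the original signature."""

    @functools.wraps(func)
    def _inner(*args: FParams.args, **kwargs: FParams.kwargs) -> RT:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return _inner


@log_calls
def add(a: int, b: int) -> int:
    return a + b


print(add(1, 2))  # mypy still sees (int, int) -> int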
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/life_sciences.py
@@ -54,7 +54,7 @@ class LifeSciencesHook(GoogleBaseHook):
account from the list granting this role to the originating account.
"""

_conn = None
_conn: build | None = None

def __init__(
self,
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/utils/field_validator.py
@@ -206,7 +206,7 @@ def _sanity_checks(
full_field_path: str,
regexp: str,
allow_empty: bool,
custom_validation: Callable,
custom_validation: Callable | None,
value,
) -> None:
if value is None and field_type != "union":
2 changes: 1 addition & 1 deletion airflow/providers/google/firebase/hooks/firestore.py
@@ -53,7 +53,7 @@ class CloudFirestoreHook(GoogleBaseHook):
account from the list granting this role to the originating account.
"""

_conn = None
_conn: build | None = None

def __init__(
self,
@@ -28,7 +28,7 @@
class GoogleSearchAdsHook(GoogleBaseHook):
"""Hook for Google Search Ads 360."""

_conn = None
_conn: build | None = None

def __init__(
self,
3 changes: 1 addition & 2 deletions airflow/providers/microsoft/psrp/hooks/psrp.py
@@ -75,8 +75,7 @@ class PsrpHook(BaseHook):
or by setting this key as the extra fields of your connection.
"""

_conn = None
_configuration_name = None
_conn: RunspacePool | None = None
_wsman_ref: WeakKeyDictionary[RunspacePool, WSMan] = WeakKeyDictionary()

def __init__(
19 changes: 9 additions & 10 deletions airflow/serialization/serializers/bignum.py
@@ -33,16 +33,15 @@


def serialize(o: object) -> tuple[U, str, int, bool]:
if isinstance(o, Decimal):
name = qualname(o)
_, _, exponent = o.as_tuple()
if exponent >= 0: # No digits after the decimal point.
return int(o), name, __version__, True
# Technically lossy due to floating point errors, but the best we
# can do without implementing a custom encode function.
return float(o), name, __version__, True

return "", "", 0, False
if not isinstance(o, Decimal):
return "", "", 0, False
name = qualname(o)
_, _, exponent = o.as_tuple()
if isinstance(exponent, int) and exponent >= 0: # No digits after the decimal point.
return int(o), name, __version__, True
# Technically lossy due to floating point errors, but the best we
# can do without implementing a custom encode function.
return float(o), name, __version__, True


def deserialize(classname: str, version: int, data: object) -> Decimal:
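
Background on the `isinstance(exponent, int)` guard added above: recent
typeshed stubs type `Decimal.as_tuple().exponent` as `int | Literal["n", "N", "F"]`
(the letters encode NaN and infinity), so comparing it with `>= 0` requires
narrowing first. A small standalone sketch:

from decimal import Decimal


def is_integral(value: Decimal) -> bool:
    """Return True if the Decimal has no digits after the decimal point."""
    _, _, exponent = value.as_tuple()
    # exponent is "n"/"N"/"F" for NaN and infinity, so narrow to int first.
    return isinstance(exponent, int) and exponent >= 0


print(is_integral(Decimal("42")))   # True
print(is_integral(Decimal("4.2")))  # False
print(is_integral(Decimal("NaN")))  # False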
2 changes: 1 addition & 1 deletion airflow/settings.py
@@ -351,7 +351,7 @@ def dispose_orm():
global engine
global Session

if Session:
if Session is not None:
Session.remove()
Session = None
if engine:
7 changes: 4 additions & 3 deletions airflow/stats.py
@@ -74,9 +74,10 @@ def timing(cls, stat: str, dt: float | datetime.timedelta, tags: dict[str, str]
@classmethod
def timer(cls, *args, **kwargs) -> TimerProtocol:
"""Timer metric that can be cancelled."""
raise NotImplementedError()


class Timer:
class Timer(TimerProtocol):
"""
Timer that records duration, and optional sends to StatsD backend.
@@ -360,7 +361,7 @@ def timer(self, stat=None, *args, tags=None, **kwargs):


class _Stats(type):
factory = None
factory: Callable[[], StatsLogger]
instance: StatsLogger | None = None

def __getattr__(cls, name):
@@ -374,7 +375,7 @@ def __getattr__(cls, name):

def __init__(cls, *args, **kwargs):
super().__init__(cls)
if cls.__class__.factory is None:
if not hasattr(cls.__class__, "factory"):
is_datadog_enabled_defined = conf.has_option("metrics", "statsd_datadog_enabled")
if is_datadog_enabled_defined and conf.getboolean("metrics", "statsd_datadog_enabled"):
cls.__class__.factory = cls.get_dogstatsd_logger
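
A condensed sketch of the pattern applied to `_Stats.factory` above (the
metaclass and factory here are simplified stand-ins): a bare annotation
declares the attribute's type for mypy without creating it at runtime, so
presence is checked with `hasattr` instead of an `is None` sentinel.

from __future__ import annotations

from typing import Callable


def _default_factory() -> object:
    """Hypothetical stand-in for get_statsd_logger / get_dogstatsd_logger."""
    return object()


class _Meta(type):
    # Declared for mypy but deliberately not assigned here.
    factory: Callable[[], object]

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A bare annotation creates no attribute, so test presence with
        # hasattr() rather than comparing against a None default.
        if not hasattr(cls.__class__, "factory"):
            cls.__class__.factory = _default_factory


class Stats(metaclass=_Meta):
    pass


print(_Meta.factory is _default_factory)  # True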
1 change: 1 addition & 0 deletions airflow/timetables/base.py
@@ -163,6 +163,7 @@ def validate(self) -> None:
:raises: AirflowTimetableInvalid on validation failure.
"""
return

@property
def summary(self) -> str:
6 changes: 6 additions & 0 deletions airflow/utils/mixins.py
@@ -25,6 +25,8 @@
from airflow.utils.context import Context

if typing.TYPE_CHECKING:
import multiprocessing.context

from airflow.models.operator import Operator


@@ -44,6 +46,10 @@ def _get_multiprocessing_start_method(self) -> str:
raise ValueError("Failed to determine start method")
return method

def _get_multiprocessing_context(self) -> multiprocessing.context.DefaultContext:
mp_start_method = self._get_multiprocessing_start_method()
return multiprocessing.get_context(mp_start_method) # type: ignore


class ResolveMixin:
"""A runtime-resolved value."""
5 changes: 3 additions & 2 deletions airflow/utils/session.py
@@ -28,9 +28,10 @@
@contextlib.contextmanager
def create_session() -> Generator[settings.SASession, None, None]:
"""Contextmanager that will create and teardown a session."""
if not settings.Session:
Session = getattr(settings, "Session", None)
if Session is None:
raise RuntimeError("Session must be set before!")
session = settings.Session()
session = Session()
try:
yield session
session.commit()
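
A hypothetical sketch of the local-binding pattern used in `create_session`
above (module and names simplified): pulling the factory into a local gives
mypy a single name to narrow with the `is None` check before it is called.

from __future__ import annotations

import contextlib
from typing import Callable, Generator

# Hypothetical module-level factory, standing in for settings.Session.
Session: Callable[[], object] | None = None


@contextlib.contextmanager
def create_session() -> Generator[object, None, None]:
    session_factory = Session  # bind the module global to a local
    if session_factory is None:
        raise RuntimeError("Session must be set before!")
    yield session_factory()


def configure() -> None:
    global Session
    Session = object  # hypothetical factory: builds a bare object


configure()
with create_session() as session:
    print(type(session))  # <class 'object'>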
4 changes: 2 additions & 2 deletions airflow/utils/yaml.py
@@ -41,7 +41,7 @@ def safe_load(stream: bytes | str | BinaryIO | TextIO) -> Any:
try:
from yaml import CSafeLoader as SafeLoader
except ImportError:
from yaml import SafeLoader # type: ignore[no-redef]
from yaml import SafeLoader # type: ignore[assignment, no-redef]

return orig(stream, SafeLoader)

@@ -54,7 +54,7 @@ def dump(data: Any, **kwargs) -> str:
try:
from yaml import CSafeDumper as SafeDumper
except ImportError:
from yaml import SafeDumper # type: ignore[no-redef]
from yaml import SafeDumper # type: ignore[assignment, no-redef]

return cast(str, orig(data, Dumper=SafeDumper, **kwargs))

2 changes: 2 additions & 0 deletions airflow/www/extensions/init_appbuilder.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# mypy: disable-error-code=var-annotated
from __future__ import annotations

import logging
2 changes: 2 additions & 0 deletions airflow/www/fab_security/manager.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# mypy: disable-error-code=var-annotated
from __future__ import annotations

import base64
2 changes: 1 addition & 1 deletion dev/breeze/src/airflow_breeze/utils/click_utils.py
@@ -19,4 +19,4 @@
try:
from rich_click import RichGroup as BreezeGroup
except ImportError:
from click import Group as BreezeGroup # type: ignore[misc] # noqa
from click import Group as BreezeGroup # type: ignore[assignment] # noqa
@@ -28,7 +28,7 @@

from rich import print

errors = []
errors: list[str] = []

MY_DIR_PATH = Path(__file__).parent.resolve()

2 changes: 1 addition & 1 deletion scripts/ci/pre_commit/pre_commit_check_order_setup.py
@@ -30,7 +30,7 @@

from rich import print

errors = []
errors: list[str] = []

MY_DIR_PATH = os.path.dirname(__file__)
SOURCE_DIR_PATH = os.path.abspath(os.path.join(MY_DIR_PATH, os.pardir, os.pardir, os.pardir))
@@ -19,6 +19,7 @@

import sys
from pathlib import Path
from typing import Any

from rich.console import Console

@@ -37,7 +38,7 @@
PREFIX = "apache-airflow-providers-"


errors = []
errors: list[Any] = []


def check_system_test_entry_hidden(provider_index: Path):
@@ -39,7 +39,7 @@
from common_precommit_utils import AIRFLOW_SOURCES_ROOT_PATH # isort: skip # noqa E402
from common_precommit_black_utils import black_format # isort: skip # noqa E402

PROVIDERS_ROOT = AIRFLOW_SOURCES_ROOT_PATH / "airflow" / "providers"
PROVIDERS_ROOT = AIRFLOW_SOURCES_ROOT_PATH / "providers"
COMMON_SQL_ROOT = PROVIDERS_ROOT / "common" / "sql"
OUT_DIR = AIRFLOW_SOURCES_ROOT_PATH / "out"

2 changes: 1 addition & 1 deletion scripts/in_container/run_provider_yaml_files_check.py
@@ -56,7 +56,7 @@
)
CORE_INTEGRATIONS = ["SQL", "Local"]

errors = []
errors: list[str] = []

console = Console(width=400, color_system="standard")
