Skip to content

Commit

Permalink
Fix credentials error for S3ToGCSOperator trigger (#37518)
Browse files Browse the repository at this point in the history
* Fix credentials error for S3ToGCSOperator trigger

* fix: safe create StorageTransferServiceAsyncClient()

* fix: test, style, some bugs
  • Loading branch information
korolkevich committed Apr 1, 2024
1 parent 39b684d commit 13e9a0d
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -508,14 +509,18 @@ def __init__(self, project_id: str | None = None, **kwargs: Any) -> None:
self.project_id = project_id
self._client: StorageTransferServiceAsyncClient | None = None

def get_conn(self) -> StorageTransferServiceAsyncClient:
async def get_conn(self) -> StorageTransferServiceAsyncClient:
"""
Return async connection to the Storage Transfer Service.
:return: Google Storage Transfer asynchronous client.
"""
if not self._client:
self._client = StorageTransferServiceAsyncClient()
credentials = (await self.get_sync_hook()).get_credentials()
self._client = StorageTransferServiceAsyncClient(
credentials=credentials,
client_info=CLIENT_INFO,
)
return self._client

async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
Expand All @@ -525,7 +530,7 @@ async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
:param job_names: (Required) List of names of the jobs to be fetched.
:return: Object that yields Transfer jobs.
"""
client = self.get_conn()
client = await self.get_conn()
jobs_list_request = ListTransferJobsRequest(
filter=json.dumps({"project_id": self.project_id, "job_names": job_names})
)
Expand All @@ -540,7 +545,7 @@ async def get_latest_operation(self, job: TransferJob) -> Message | None:
"""
latest_operation_name = job.latest_operation_name
if latest_operation_name:
client = self.get_conn()
client = await self.get_conn()
response_operation = await client.transport.operations_client.get_operation(latest_operation_name)
operation = TransferOperation.deserialize(response_operation.metadata.value)
return operation
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/transfers/s3_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def transfer_files_async(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3H
self.defer(
trigger=CloudStorageTransferServiceCreateJobsTrigger(
project_id=gcs_hook.project_id,
gcp_conn_id=self.gcp_conn_id,
job_names=job_names,
poll_interval=self.poll_interval,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,19 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger):
:param job_names: List of transfer jobs names.
:param project_id: GCP project id.
:param poll_interval: Interval in seconds between polls.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
"""

def __init__(self, job_names: list[str], project_id: str | None = None, poll_interval: int = 10) -> None:
def __init__(
self,
job_names: list[str],
project_id: str | None = None,
poll_interval: int = 10,
gcp_conn_id: str = "google_cloud_default",
) -> None:
super().__init__()
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.job_names = job_names
self.poll_interval = poll_interval

Expand All @@ -53,6 +61,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"project_id": self.project_id,
"job_names": self.job_names,
"poll_interval": self.poll_interval,
"gcp_conn_id": self.gcp_conn_id,
},
)

Expand Down Expand Up @@ -117,4 +126,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
await asyncio.sleep(self.poll_interval)

def get_async_hook(self) -> CloudDataTransferServiceAsyncHook:
return CloudDataTransferServiceAsyncHook(project_id=self.project_id)
return CloudDataTransferServiceAsyncHook(
project_id=self.project_id,
gcp_conn_id=self.gcp_conn_id,
)
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,20 @@ def hook_async():


class TestCloudDataTransferServiceAsyncHook:
@pytest.mark.asyncio
@mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
@mock.patch(f"{TRANSFER_HOOK_PATH}.StorageTransferServiceAsyncClient")
def test_get_conn(self, mock_async_client):
async def test_get_conn(self, mock_async_client, mock_get_conn):
expected_value = "Async Hook"
mock_async_client.return_value = expected_value
mock_get_conn.return_value = expected_value

hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
conn_0 = hook.get_conn()

conn_0 = await hook.get_conn()
assert conn_0 == expected_value

conn_1 = hook.get_conn()
conn_1 = await hook.get_conn()
assert conn_1 == expected_value
assert id(conn_0) == id(conn_1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.triggers.base import TriggerEvent

PROJECT_ID = "test-project"
GCP_CONN_ID = "google-cloud-default-id"
JOB_0 = "test-job-0"
JOB_1 = "test-job-1"
JOB_NAMES = [JOB_0, JOB_1]
Expand All @@ -51,7 +52,10 @@
@pytest.fixture(scope="session")
def trigger():
return CloudStorageTransferServiceCreateJobsTrigger(
project_id=PROJECT_ID, job_names=JOB_NAMES, poll_interval=POLL_INTERVAL
project_id=PROJECT_ID,
job_names=JOB_NAMES,
poll_interval=POLL_INTERVAL,
gcp_conn_id=GCP_CONN_ID,
)


Expand Down Expand Up @@ -80,6 +84,7 @@ def test_serialize(self, trigger):
"project_id": PROJECT_ID,
"job_names": JOB_NAMES,
"poll_interval": POLL_INTERVAL,
"gcp_conn_id": GCP_CONN_ID,
}

def test_get_async_hook(self, trigger):
Expand Down

0 comments on commit 13e9a0d

Please sign in to comment.
  翻译: