Skip to content

Commit

Permalink
🐛 (BigQueryHook) fix compatibility with sqlalchemy engine (#19508)
Browse files Browse the repository at this point in the history
  • Loading branch information
david30907d authored Feb 6, 2022
1 parent d8c4449 commit 1a77bc6
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
37 changes: 37 additions & 0 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
_check_google_client_version as gbq_check_google_client_version,
_test_google_api_imports as gbq_test_google_api_imports,
)
from sqlalchemy import create_engine

from airflow.exceptions import AirflowException
from airflow.hooks.dbapi import DbApiHook
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
self.running_job_id = None # type: Optional[str]
self.api_resource_configs = api_resource_configs if api_resource_configs else {} # type Dict
self.labels = labels
self.credentials_path = "bigquery_hook_credentials.json"

def get_conn(self) -> "BigQueryConnection":
"""Returns a BigQuery PEP 249 connection object."""
Expand Down Expand Up @@ -150,6 +152,41 @@ def get_client(self, project_id: Optional[str] = None, location: Optional[str] =
credentials=self._get_credentials(),
)

def get_uri(self) -> str:
    """
    Build the SQLAlchemy URI for this hook's project.

    Overrides the ``DbApiHook`` ``get_uri`` method so that
    ``get_sqlalchemy_engine()`` receives a BigQuery URI.

    :return: a URI of the form ``bigquery://<project_id>``.
    """
    scheme = "bigquery"
    project = self.project_id
    return f"{scheme}://{project}"

def get_sqlalchemy_engine(self, engine_kwargs=None):
    """
    Get an sqlalchemy_engine object.

    Credentials are taken from the connection extras in this order:
    ``extra__google_cloud_platform__key_path``, then
    ``extra__google_cloud_platform__keyfile_dict``. When neither is set, no
    explicit credentials are passed and the engine is created relying on
    Application Default Credentials (ADC).

    :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
        ``None`` (the default) is treated as "no extra kwargs".
    :return: the created engine.
    :raises AirflowException: if creating the engine without explicit
        credentials fails.
    """
    # ``**None`` raises TypeError, so normalise the default to an empty mapping
    # before it is unpacked into any ``create_engine`` call below.
    engine_kwargs = {} if engine_kwargs is None else engine_kwargs
    extras = self.get_connection(self.gcp_conn_id).extra_dejson

    credentials_path = extras.get("extra__google_cloud_platform__key_path")
    if credentials_path:
        return create_engine(self.get_uri(), credentials_path=credentials_path, **engine_kwargs)

    keyfile = extras.get("extra__google_cloud_platform__keyfile_dict")
    if keyfile:
        # The keyfile is read as a JSON string; an already-decoded mapping is
        # accepted as-is instead of failing inside ``json.loads``.
        credentials_info = keyfile if isinstance(keyfile, dict) else json.loads(keyfile)
        return create_engine(self.get_uri(), credentials_info=credentials_info, **engine_kwargs)

    try:
        # 1. If the environment variable GOOGLE_APPLICATION_CREDENTIALS is set
        # ADC uses the service account key or configuration file that the variable points to.
        # 2. If the environment variable GOOGLE_APPLICATION_CREDENTIALS isn't set
        # ADC uses the service account that is attached to the resource that is running your code.
        return create_engine(self.get_uri(), **engine_kwargs)
    except Exception as e:
        self.log.error(e)
        # The message text is kept verbatim because the hook's unit test matches on it.
        raise AirflowException(
            "For now, we only support instantiating SQLAlchemy engine by"
            " using ADC"
            ", extra__google_cloud_platform__key_path"
            "and extra__google_cloud_platform__keyfile_dict"
        ) from e

@staticmethod
def _resolve_table_reference(
table_resource: Dict[str, Any],
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
# _check_google_client_version (airflow/providers/google/cloud/hooks/bigquery.py:49)
'pandas-gbq<0.15.0',
pandas_requirement,
'sqlalchemy-bigquery>=1.2.1',
]
grpc = [
'google-auth>=1.0.0, <3.0.0',
Expand Down
13 changes: 13 additions & 0 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,19 @@ def test_insert_job(self, mock_client, mock_query_job):
)
mock_query_job.from_api_repr.return_value.result.assert_called_once_with()

def test_dbapi_get_uri(self):
assert self.hook.get_uri().startswith('bigquery://')

def test_dbapi_get_sqlalchemy_engine(self):
    """Calling ``get_sqlalchemy_engine()`` on the test hook must raise ``AirflowException``."""
    # Adjacent literals are concatenated verbatim; the text has to stay in
    # sync with the message raised by the hook.
    expected_message = (
        "For now, we only support instantiating SQLAlchemy engine by"
        " using ADC"
        ", extra__google_cloud_platform__key_path"
        "and extra__google_cloud_platform__keyfile_dict"
    )
    with pytest.raises(AirflowException, match=expected_message):
        self.hook.get_sqlalchemy_engine()


class TestBigQueryTableSplitter(unittest.TestCase):
def test_internal_need_default_project(self):
Expand Down

0 comments on commit 1a77bc6

Please sign in to comment.
  翻译: