Skip to content

Commit

Permalink
Add "BOOLEAN" to type_map of MSSQLToGCSOperator, fix incorrect bit->int type conversion by specifying BIT fields explicitly (#29902)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed Mar 4, 2023
1 parent 5a632f7 commit 035ad26
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
20 changes: 18 additions & 2 deletions airflow/providers/google/cloud/transfers/mssql_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import datetime
import decimal
from typing import Sequence

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
Expand All @@ -29,6 +30,10 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
"""Copy data from Microsoft SQL Server to Google Cloud Storage
in JSON, CSV or Parquet format.
:param bit_fields: Sequence of fields names of MSSQL "BIT" data type,
to be interpreted in the schema as "BOOLEAN". "BIT" fields that won't
be included in this sequence, will be interpreted as "INTEGER" by
default.
:param mssql_conn_id: Reference to a specific MSSQL hook.
**Example**:
Expand All @@ -39,6 +44,7 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):
export_customers = MsSqlToGoogleCloudStorageOperator(
task_id='export_customers',
sql='SELECT * FROM dbo.Customers;',
bit_fields=['some_bit_field', 'another_bit_field'],
bucket='mssql-export',
filename='data/customers/export.json',
schema_filename='schemas/export.json',
Expand All @@ -55,11 +61,18 @@ class MSSQLToGCSOperator(BaseSQLToGCSOperator):

ui_color = "#e0a98c"

type_map = {3: "INTEGER", 4: "TIMESTAMP", 5: "NUMERIC"}
type_map = {2: "BOOLEAN", 3: "INTEGER", 4: "TIMESTAMP", 5: "NUMERIC"}

def __init__(self, *, mssql_conn_id="mssql_default", **kwargs):
def __init__(
    self,
    *,
    bit_fields: Sequence[str] | None = None,
    mssql_conn_id: str = "mssql_default",
    **kwargs,
) -> None:
    """Initialize the MSSQL-to-GCS transfer operator.

    :param bit_fields: Names of MSSQL "BIT" columns that should be mapped
        to "BOOLEAN" in the BigQuery schema. Per the class docstring, BIT
        columns *not* listed here keep the default mapping ("INTEGER").
    :param mssql_conn_id: Reference to a specific MSSQL hook/connection.
    :param kwargs: Forwarded unchanged to ``BaseSQLToGCSOperator``.
    """
    super().__init__(**kwargs)
    self.mssql_conn_id = mssql_conn_id
    # Normalize None to an empty list so later membership tests
    # (``field[0] in self.bit_fields``) need no None guard.
    self.bit_fields = bit_fields if bit_fields else []

def query(self):
"""
Expand All @@ -74,6 +87,9 @@ def query(self):
return cursor

def field_to_bigquery(self, field) -> dict[str, str]:
if field[0] in self.bit_fields:
field = (field[0], 2)

return {
"name": field[0].replace(" ", "_"),
"type": self.type_map.get(field[1], "STRING"),
Expand Down
37 changes: 30 additions & 7 deletions tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,35 @@
JSON_FILENAME = "test_{}.ndjson"
GZIP = False

ROWS = [("mock_row_content_1", 42), ("mock_row_content_2", 43), ("mock_row_content_3", 44)]
ROWS = [
("mock_row_content_1", 42, True, True),
("mock_row_content_2", 43, False, False),
("mock_row_content_3", 44, True, True),
]
CURSOR_DESCRIPTION = (
("some_str", 0, None, None, None, None, None),
("some_num", 3, None, None, None, None, None),
("some_binary", 2, None, None, None, None, None),
("some_bit", 3, None, None, None, None, None),
)
NDJSON_LINES = [
b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
b'{"some_num": 44, "some_str": "mock_row_content_3"}\n',
b'{"some_binary": true, "some_bit": true, "some_num": 42, "some_str": "mock_row_content_1"}\n',
b'{"some_binary": false, "some_bit": false, "some_num": 43, "some_str": "mock_row_content_2"}\n',
b'{"some_binary": true, "some_bit": true, "some_num": 44, "some_str": "mock_row_content_3"}\n',
]
SCHEMA_FILENAME = "schema_test.json"
SCHEMA_JSON = [
b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ',
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]',
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}, ',
b'{"mode": "NULLABLE", "name": "some_binary", "type": "BOOLEAN"}, ',
b'{"mode": "NULLABLE", "name": "some_bit", "type": "BOOLEAN"}]',
]

SCHEMA_JSON_BIT_FIELDS = [
b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ',
b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}, ',
b'{"mode": "NULLABLE", "name": "some_binary", "type": "BOOLEAN"}, ',
b'{"mode": "NULLABLE", "name": "some_bit", "type": "INTEGER"}]',
]


Expand Down Expand Up @@ -148,7 +163,10 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metada

@mock.patch("airflow.providers.google.cloud.transfers.mssql_to_gcs.MsSqlHook")
@mock.patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
def test_schema_file(self, gcs_hook_mock_class, mssql_hook_mock_class):
@pytest.mark.parametrize(
"bit_fields,schema_json", [(None, SCHEMA_JSON), (["bit_fields", SCHEMA_JSON_BIT_FIELDS])]
)
def test_schema_file(self, gcs_hook_mock_class, mssql_hook_mock_class, bit_fields, schema_json):
"""Test writing schema files."""
mssql_hook_mock = mssql_hook_mock_class.return_value
mssql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
Expand All @@ -164,7 +182,12 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
gcs_hook_mock.upload.side_effect = _assert_upload

op = MSSQLToGCSOperator(
task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=JSON_FILENAME, schema_filename=SCHEMA_FILENAME
task_id=TASK_ID,
sql=SQL,
bucket=BUCKET,
filename=JSON_FILENAME,
schema_filename=SCHEMA_FILENAME,
bit_fields=["some_bit"],
)
op.execute(None)

Expand Down

0 comments on commit 035ad26

Please sign in to comment.