Add Parquet data type to BaseSQLToGCSOperator (#13359)
tuanchris authored Dec 31, 2020
1 parent 10be375 commit 406181d
Showing 2 changed files with 141 additions and 9 deletions.
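
With this change, any concrete subclass of BaseSQLToGCSOperator can write its query results to GCS as Parquet by passing export_format='parquet'. A minimal usage sketch (not part of the commit; the query, bucket, and connection IDs are placeholders, and the MySQL transfer operator is used only as an example of a concrete subclass, assuming the MySQL provider is installed for its hook):

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

with DAG(dag_id="example_sql_to_gcs_parquet", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    export_orders = MySQLToGCSOperator(
        task_id="export_orders",
        sql="SELECT * FROM orders",            # placeholder query
        bucket="my-bucket",                    # placeholder GCS bucket
        filename="exports/orders_{}.parquet",  # {} is replaced with a file counter when output is split
        export_format="parquet",               # the new format added by this commit
        mysql_conn_id="mysql_default",
        gcp_conn_id="google_cloud_default",
    )

The other SQL-to-GCS transfer operators in the Google provider take the same export_format argument, since the format handling lives in the base class shown in the diff below.
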
40 changes: 39 additions & 1 deletion airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -22,6 +22,8 @@
from tempfile import NamedTemporaryFile
from typing import Optional, Sequence, Union

import pyarrow as pa
import pyarrow.parquet as pq
import unicodecsv as csv

from airflow.models import BaseOperator
@@ -185,6 +187,8 @@ def _write_local_data_files(self, cursor):
        tmp_file_handle = NamedTemporaryFile(delete=True)
        if self.export_format == 'csv':
            file_mime_type = 'text/csv'
        elif self.export_format == 'parquet':
            file_mime_type = 'application/octet-stream'
        else:
            file_mime_type = 'application/json'
        files_to_upload = [
@@ -198,6 +202,9 @@ def _write_local_data_files(self, cursor):

        if self.export_format == 'csv':
            csv_writer = self._configure_csv_file(tmp_file_handle, schema)
        if self.export_format == 'parquet':
            parquet_schema = self._convert_parquet_schema(cursor)
            parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

        for row in cursor:
            # Convert datetime objects to utc seconds, and decimals to floats.
@@ -208,6 +215,12 @@ def _write_local_data_files(self, cursor):
                if self.null_marker is not None:
                    row = [value if value is not None else self.null_marker for value in row]
                csv_writer.writerow(row)
            elif self.export_format == 'parquet':
                if self.null_marker is not None:
                    row = [value if value is not None else self.null_marker for value in row]
                row_pydic = {col: [value] for col, value in zip(schema, row)}
                tbl = pa.Table.from_pydict(row_pydic, parquet_schema)
                parquet_writer.write_table(tbl)
            else:
                row_dict = dict(zip(schema, row))

@@ -232,7 +245,8 @@ def _write_local_data_files(self, cursor):
                self.log.info("Current file count: %d", len(files_to_upload))
                if self.export_format == 'csv':
                    csv_writer = self._configure_csv_file(tmp_file_handle, schema)

                if self.export_format == 'parquet':
                    parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
        return files_to_upload

    def _configure_csv_file(self, file_handle, schema):
@@ -243,6 +257,30 @@ def _configure_csv_file(self, file_handle, schema):
        csv_writer.writerow(schema)
        return csv_writer

    def _configure_parquet_file(self, file_handle, parquet_schema):
        parquet_writer = pq.ParquetWriter(file_handle.name, parquet_schema)
        return parquet_writer

    def _convert_parquet_schema(self, cursor):
        type_map = {
            'INTERGER': pa.int64(),
            'FLOAT': pa.float64(),
            'NUMERIC': pa.float64(),
            'BIGNUMERIC': pa.float64(),
            'BOOL': pa.bool_(),
            'STRING': pa.string(),
            'BYTES': pa.binary(),
            'DATE': pa.date32(),
            'DATETIME': pa.date64(),
            'TIMESTAMP': pa.timestamp('s'),
        }

        columns = [field[0] for field in cursor.description]
        bq_types = [self.field_to_bigquery(field) for field in cursor.description]
        pq_types = [type_map.get(bq_type, pa.string()) for bq_type in bq_types]
        parquet_schema = pa.schema(zip(columns, pq_types))
        return parquet_schema

    @abc.abstractmethod
    def query(self):
        """Execute DBAPI query."""
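
The new helpers reduce to three pyarrow calls: pa.schema() over the column names taken from cursor.description, pq.ParquetWriter() opened against the temporary file, and pa.Table.from_pydict() for each row. A standalone sketch of the same pattern (not part of the commit; the type map is abridged and the column names, BigQuery type names, and rows are made up):

import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

# Illustrative type map; the committed _convert_parquet_schema uses a longer one.
TYPE_MAP = {
    'INTEGER': pa.int64(),  # the committed map spells this key 'INTERGER'
    'FLOAT': pa.float64(),
    'STRING': pa.string(),
    'TIMESTAMP': pa.timestamp('s'),
}

# Made-up stand-ins for cursor.description names and field_to_bigquery() results.
columns = ['id', 'name', 'created_at']
bq_types = ['INTEGER', 'STRING', 'TIMESTAMP']
parquet_schema = pa.schema(zip(columns, [TYPE_MAP.get(t, pa.string()) for t in bq_types]))

rows = [
    (1, 'alpha', 1609372800),  # epoch seconds, matching pa.timestamp('s')
    (2, 'beta', 1609459200),
]

with tempfile.NamedTemporaryFile(suffix='.parquet') as tmp:
    writer = pq.ParquetWriter(tmp.name, parquet_schema)
    for row in rows:
        # One single-row Table per source row, as in the new parquet branch.
        tbl = pa.Table.from_pydict({col: [val] for col, val in zip(columns, row)}, schema=parquet_schema)
        writer.write_table(tbl)
    writer.close()
    print(pq.read_table(tmp.name).to_pydict())

Two details of the committed code worth noting: every write_table() call produces its own row group, so the operator emits one row group per source row, and the type map spells its integer key 'INTERGER', so a field reported as 'INTEGER' falls through to the pa.string() default.
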
110 changes: 102 additions & 8 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -18,8 +18,9 @@
import json
import unittest
from unittest import mock
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock

import pandas as pd
import unicodecsv as csv

from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -36,6 +37,11 @@
]
COLUMNS = ["column_a", "column_b", "column_c"]
ROW = ["convert_type_return_value", "convert_type_return_value", "convert_type_return_value"]
CURSOR_DESCRIPTION = [
    ("column_a", "3", 0, 0, 0, 0, False),
    ("column_b", "253", 0, 0, 0, 0, False),
    ("column_c", "10", 0, 0, 0, 0, False),
]
TMP_FILE_NAME = "temp-file"
INPUT_DATA = [
    ["101", "school", "2015-01-01"],
@@ -52,13 +58,15 @@
SCHEMA_FILE = "schema_file.json"
APP_JSON = "application/json"

OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)


class DummySQLToGCSOperator(BaseSQLToGCSOperator):
    def field_to_bigquery(self, field):
        pass

    def convert_type(self, value, schema_type):
        pass
        return 'convert_type_return_value'

    def query(self):
        pass
@@ -69,13 +77,10 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
    @mock.patch.object(csv.writer, "writerow")
    @mock.patch.object(GCSHook, "upload")
    @mock.patch.object(DummySQLToGCSOperator, "query")
    @mock.patch.object(DummySQLToGCSOperator, "field_to_bigquery")
    @mock.patch.object(DummySQLToGCSOperator, "convert_type")
    def test_exec(
        self, mock_convert_type, mock_field_to_bigquery, mock_query, mock_upload, mock_writerow, mock_tempfile
    ):
    def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, mock_tempfile):
        cursor_mock = Mock()
        cursor_mock.description = [("column_a", "3"), ("column_b", "253"), ("column_c", "10")]
        cursor_mock.description = CURSOR_DESCRIPTION
        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
        mock_query.return_value = cursor_mock
        mock_convert_type.return_value = "convert_type_return_value"
@@ -99,6 +104,7 @@ def test_exec(

        mock_tempfile.return_value = mock_file

        # Test CSV
        operator = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
@@ -109,7 +115,7 @@ def test_exec(
export_format="csv",
gzip=True,
schema=SCHEMA,
google_cloud_storage_conn_id='google_cloud_default',
gcp_conn_id='google_cloud_default',
)
operator.execute(context=dict())

@@ -140,6 +146,7 @@ def test_exec(

        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))

        # Test JSON
        operator = DummySQLToGCSOperator(
            sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA
        )
@@ -160,6 +167,27 @@ def test_exec(
        mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
        mock_close.assert_called_once()

        mock_query.reset_mock()
        mock_flush.reset_mock()
        mock_upload.reset_mock()
        mock_close.reset_mock()
        cursor_mock.reset_mock()

        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))

        # Test parquet
        operator = DummySQLToGCSOperator(
            sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="parquet", schema=SCHEMA
        )
        operator.execute(context=dict())

        mock_query.assert_called_once()
        mock_flush.assert_called_once()
        mock_upload.assert_called_once_with(
            BUCKET, FILENAME, TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
        )
        mock_close.assert_called_once()

        # Test null marker
        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
        mock_convert_type.return_value = None
@@ -182,3 +210,69 @@ def test_exec(
mock.call(["NULL", "NULL", "NULL"]),
]
)

    def test__write_local_data_files_csv(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="csv",
            gzip=False,
            schema=SCHEMA,
            gcp_conn_id='google_cloud_default',
        )
        cursor = MagicMock()
        cursor.__iter__.return_value = INPUT_DATA
        cursor.description = CURSOR_DESCRIPTION

        files = op._write_local_data_files(cursor)
        file = files[0]['file_handle']
        file.flush()
        df = pd.read_csv(file.name)
        assert df.equals(OUTPUT_DF)

    def test__write_local_data_files_json(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="json",
            gzip=False,
            schema=SCHEMA,
            gcp_conn_id='google_cloud_default',
        )
        cursor = MagicMock()
        cursor.__iter__.return_value = INPUT_DATA
        cursor.description = CURSOR_DESCRIPTION

        files = op._write_local_data_files(cursor)
        file = files[0]['file_handle']
        file.flush()
        df = pd.read_json(file.name, orient='records', lines=True)
        assert df.equals(OUTPUT_DF)

    def test__write_local_data_files_parquet(self):
        op = DummySQLToGCSOperator(
            sql=SQL,
            bucket=BUCKET,
            filename=FILENAME,
            task_id=TASK_ID,
            schema_filename=SCHEMA_FILE,
            export_format="parquet",
            gzip=False,
            schema=SCHEMA,
            gcp_conn_id='google_cloud_default',
        )
        cursor = MagicMock()
        cursor.__iter__.return_value = INPUT_DATA
        cursor.description = CURSOR_DESCRIPTION

        files = op._write_local_data_files(cursor)
        file = files[0]['file_handle']
        file.flush()
        df = pd.read_parquet(file.name)
        assert df.equals(OUTPUT_DF)
