Add a new param to set parquet row group size in `BaseSQLToGCSOperator` (#31831)

* Add parquet_row_group_size to BaseSQLToGCSOperator operator

Signed-off-by: Hussein Awala <hussein@awala.fr>

* add a unit test

Signed-off-by: Hussein Awala <hussein@awala.fr>

* Improve docstring

Signed-off-by: Hussein Awala <hussein@awala.fr>

---------

Signed-off-by: Hussein Awala <hussein@awala.fr>
hussein-awala authored Jun 14, 2023
1 parent ee83a2f commit b502e66
Showing 2 changed files with 70 additions and 5 deletions.
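
Before the diff itself, here is a minimal, hypothetical usage sketch of the new parameter. It assumes a concrete subclass such as `MySQLToGCSOperator` (not part of this diff) that forwards keyword arguments to `BaseSQLToGCSOperator`; the DAG id, connection, bucket, filename, and SQL are placeholders.

```python
# Hypothetical usage sketch; bucket, filename, and SQL are illustrative only.
import pendulum
from airflow import DAG
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

with DAG(
    dag_id="example_sql_to_gcs_parquet",
    start_date=pendulum.datetime(2023, 6, 1, tz="UTC"),
    schedule=None,
):
    export_orders = MySQLToGCSOperator(
        task_id="export_orders",
        sql="SELECT * FROM orders",
        bucket="my-export-bucket",
        filename="exports/orders_{}.parquet",
        export_format="parquet",
        # Buffer this many rows per Parquet row group instead of the default of 1.
        parquet_row_group_size=100_000,
    )
```

Larger values trade operator memory for fewer, larger row groups in the uploaded files.
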
39 changes: 34 additions & 5 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -23,7 +23,7 @@
import json
import os
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

import pyarrow as pa
import pyarrow.parquet as pq
@@ -82,6 +82,10 @@ class BaseSQLToGCSOperator(BaseOperator):
:param write_on_empty: Optional parameter to specify whether to write a file if the
export does not return any rows. Default is False so we will not write a file
if the export returns no rows.
:param parquet_row_group_size: The approximate number of rows in each row group
when using the parquet format. Using a large row group size can reduce the file size
and improve the performance of reading the data, but it requires more memory to
execute the operator. (default: 1)
"""

template_fields: Sequence[str] = (
@@ -119,6 +123,7 @@ def __init__(
exclude_columns: set | None = None,
partition_columns: list | None = None,
write_on_empty: bool = False,
parquet_row_group_size: int = 1,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -143,6 +148,7 @@ def __init__(
self.exclude_columns = exclude_columns
self.partition_columns = partition_columns
self.write_on_empty = write_on_empty
self.parquet_row_group_size = parquet_row_group_size

def execute(self, context: Context):
if self.partition_columns:
@@ -212,6 +218,15 @@ def convert_types(self, schema, col_type_dict, row) -> list:
for name, value in zip(schema, row)
]

@staticmethod
def _write_rows_to_parquet(parquet_writer: pq.ParquetWriter, rows):
rows_pydic: dict[str, list[Any]] = {col: [] for col in parquet_writer.schema.names}
for row in rows:
for ind, col in enumerate(parquet_writer.schema.names):
rows_pydic[col].append(row[ind])
tbl = pa.Table.from_pydict(rows_pydic, parquet_writer.schema)
parquet_writer.write_table(tbl)

def _write_local_data_files(self, cursor):
"""
Takes a cursor, and writes results to a local file.
@@ -233,6 +248,7 @@ def _write_local_data_files(self, cursor):
if self.export_format == "parquet":
parquet_schema = self._convert_parquet_schema(cursor)
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
rows_buffer = []

prev_partition_values = None
curr_partition_values = None
@@ -253,6 +269,10 @@ def _write_local_data_files(self, cursor):
file_no += 1

if self.export_format == "parquet":
# Write out the remaining rows in the buffer
if rows_buffer:
self._write_rows_to_parquet(parquet_writer, rows_buffer)
rows_buffer = []
parquet_writer.close()

file_to_upload["partition_values"] = prev_partition_values
@@ -279,9 +299,10 @@ def _write_local_data_files(self, cursor):
row = self.convert_types(schema, col_type_dict, row)
if self.null_marker is not None:
row = [value if value is not None else self.null_marker for value in row]
row_pydic = {col: [value] for col, value in zip(schema, row)}
tbl = pa.Table.from_pydict(row_pydic, parquet_schema)
parquet_writer.write_table(tbl)
rows_buffer.append(row)
if len(rows_buffer) >= self.parquet_row_group_size:
self._write_rows_to_parquet(parquet_writer, rows_buffer)
rows_buffer = []
else:
row = self.convert_types(schema, col_type_dict, row)
row_dict = dict(zip(schema, row))
@@ -301,6 +322,10 @@ def _write_local_data_files(self, cursor):
file_no += 1

if self.export_format == "parquet":
# Write out the remaining rows in the buffer
if rows_buffer:
self._write_rows_to_parquet(parquet_writer, rows_buffer)
rows_buffer = []
parquet_writer.close()

file_to_upload["partition_values"] = curr_partition_values
@@ -312,6 +337,10 @@ def _write_local_data_files(self, cursor):
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

if self.export_format == "parquet":
# Write out the remaining rows in the buffer
if rows_buffer:
self._write_rows_to_parquet(parquet_writer, rows_buffer)
rows_buffer = []
parquet_writer.close()
# Last file may have 0 rows, don't yield if empty
# However, if it is the first file and self.write_on_empty is True, then yield to write an empty file
@@ -349,7 +378,7 @@ def _configure_csv_file(self, file_handle, schema):
csv_writer.writerow(schema)
return csv_writer

def _configure_parquet_file(self, file_handle, parquet_schema):
def _configure_parquet_file(self, file_handle, parquet_schema) -> pq.ParquetWriter:
parquet_writer = pq.ParquetWriter(file_handle.name, parquet_schema)
return parquet_writer

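
As an aside for reviewers, the buffering pattern introduced above can be exercised outside the operator. A minimal sketch follows; the schema, rows, output path, and the `flush()` helper are illustrative only (the helper mirrors `_write_rows_to_parquet` but is not part of this change).

```python
# Standalone sketch of the row-group buffering used in the diff above.
import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([("id", pa.int64()), ("name", pa.string())])
rows = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]
row_group_size = 2  # analogous to parquet_row_group_size


def flush(writer: pq.ParquetWriter, buffered_rows: list) -> None:
    # Pivot the buffered row tuples into a column-oriented dict, then write
    # them out as a single row group, mirroring _write_rows_to_parquet.
    columns = {
        name: [row[i] for row in buffered_rows]
        for i, name in enumerate(writer.schema.names)
    }
    writer.write_table(pa.Table.from_pydict(columns, writer.schema))


with pq.ParquetWriter("/tmp/example.parquet", schema) as writer:
    buffer = []
    for row in rows:
        buffer.append(row)
        if len(buffer) >= row_group_size:
            flush(writer, buffer)
            buffer = []
    if buffer:
        # Any remaining rows become the final, smaller row group.
        flush(writer, buffer)
# /tmp/example.parquet now holds three row groups of 2, 2, and 1 rows.
```
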
36 changes: 36 additions & 0 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -449,6 +449,42 @@ def test__write_local_data_files_parquet(self):
df = pd.read_parquet(file.name)
assert df.equals(OUTPUT_DF)

def test__write_local_data_files_parquet_with_row_size(self):
import math

import pyarrow.parquet as pq

op = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
filename=FILENAME,
task_id=TASK_ID,
schema_filename=SCHEMA_FILE,
export_format="parquet",
gzip=False,
schema=SCHEMA,
gcp_conn_id="google_cloud_default",
parquet_row_group_size=8,
)
input_data = INPUT_DATA * 10
output_df = pd.DataFrame([["convert_type_return_value"] * 3] * 30, columns=COLUMNS)

cursor = MagicMock()
cursor.__iter__.return_value = input_data
cursor.description = CURSOR_DESCRIPTION

files = op._write_local_data_files(cursor)
file = next(files)["file_handle"]
file.flush()
df = pd.read_parquet(file.name)
assert df.equals(output_df)
parquet_file = pq.ParquetFile(file.name)
assert parquet_file.num_row_groups == math.ceil((len(INPUT_DATA) * 10) / op.parquet_row_group_size)
tolerance = 1
for i in range(parquet_file.num_row_groups):
row_group_size = parquet_file.metadata.row_group(i).num_rows
assert row_group_size == op.parquet_row_group_size or (tolerance := tolerance - 1) >= 0

def test__write_local_data_files_json_with_exclude_columns(self):
op = DummySQLToGCSOperator(
sql=SQL,
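
To reproduce locally what the new test asserts about row-group counts and sizes, the Parquet footer can be inspected directly. A small sketch, reusing the illustrative file written by the earlier snippet:

```python
import pyarrow.parquet as pq

pf = pq.ParquetFile("/tmp/example.parquet")  # illustrative path from the earlier sketch
print(pf.num_row_groups)  # 3
for i in range(pf.num_row_groups):
    # Every row group except possibly the last should match the configured size.
    print(i, pf.metadata.row_group(i).num_rows)
```
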
