Upgrade daft and add two more tests for the schema evolution use case (#337)

* Upgrade daft and add two more tests for the schema evolution use case (sketched below, after the changed-files summary)

* Depend on ray instead of ray[default] in runtime requirements, since the default extra pulls in memray, which is blocked internally

* Bump version to 1.1.12

* Add ray[default] to dev-requirements
raghumdani authored Jul 19, 2024
1 parent 56f0aa6 commit a2aadb9
Showing 6 changed files with 107 additions and 5 deletions.
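
The tests below exercise one schema-evolution behavior: a reader is handed a target schema that names a column the underlying file lacks, and the result comes back with that column padded as all-null values of the declared type. A minimal sketch of that contract, assuming plain pyarrow (pad_missing_columns is a hypothetical helper, not deltacat's API):

import pyarrow as pa

def pad_missing_columns(table: pa.Table, target_schema: pa.Schema) -> pa.Table:
    # Append an all-null column for every field the target schema declares
    # but the table lacks; existing columns are left untouched.
    for field in target_schema:
        if field.name not in table.schema.names:
            table = table.append_column(
                field, pa.nulls(table.num_rows, type=field.type)
            )
    return table

table = pa.table({"a": pa.array([1, 2, 3], type=pa.int8())})
target = pa.schema([("a", pa.int8()), ("MISSING", pa.string())])
padded = pad_missing_columns(table, target)
assert padded.schema.names == ["a", "MISSING"]
assert padded.column("MISSING").null_count == 3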
2 changes: 1 addition & 1 deletion deltacat/__init__.py
@@ -44,7 +44,7 @@

deltacat.logs.configure_deltacat_logger(logging.getLogger(__name__))

__version__ = "1.1.11"
__version__ = "1.1.12"


__all__ = [
38 changes: 38 additions & 0 deletions deltacat/tests/utils/test_daft.py
@@ -104,6 +104,44 @@ def test_read_from_s3_single_column_with_schema_extra_cols(self):
         self.assertEqual(table.schema.field("MISSING").type, pa.string())
         self.assertEqual(table.num_rows, 100)
 
+    def test_read_from_s3_single_column_with_schema_extra_cols_column_names(self):
+        schema = pa.schema([("a", pa.int8()), ("MISSING", pa.string())])
+        pa_read_func_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(
+            schema=schema
+        )
+        table = daft_s3_file_to_table(
+            self.MVP_PATH,
+            content_encoding=ContentEncoding.IDENTITY.value,
+            content_type=ContentType.PARQUET.value,
+            column_names=["a", "MISSING"],
+            pa_read_func_kwargs_provider=pa_read_func_kwargs_provider,
+        )
+        self.assertEqual(
+            table.schema.names, ["a", "MISSING"]
+        )  # NOTE: "MISSING" is padded as a null array
+        self.assertEqual(table.schema.field("a").type, pa.int8())
+        self.assertEqual(table.schema.field("MISSING").type, pa.string())
+        self.assertEqual(table.num_rows, 100)
+
+    def test_read_from_s3_single_column_with_schema_only_missing_col(self):
+        schema = pa.schema([("a", pa.int8()), ("MISSING", pa.string())])
+        pa_read_func_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(
+            schema=schema
+        )
+        table = daft_s3_file_to_table(
+            self.MVP_PATH,
+            content_encoding=ContentEncoding.IDENTITY.value,
+            content_type=ContentType.PARQUET.value,
+            include_columns=["MISSING"],
+            column_names=["a", "MISSING"],
+            pa_read_func_kwargs_provider=pa_read_func_kwargs_provider,
+        )
+        self.assertEqual(
+            table.schema.names, ["MISSING"]
+        )  # NOTE: "MISSING" is padded as a null array
+        self.assertEqual(table.schema.field("MISSING").type, pa.string())
+        self.assertEqual(table.num_rows, 0)
+
     def test_read_from_s3_single_column_with_row_groups(self):
 
         metadata = pq.read_metadata(self.MVP_PATH)
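
A note on the second daft test's final assertion: when include_columns selects only the column that is absent from the file, no physical data is read, so the padded null column has nothing to align to and the table comes back empty. A small pyarrow illustration of that resulting shape (this reading of why num_rows is 0 is inferred from the assertions above, not from the reader's internals):

import pyarrow as pa

# Zero rows, but the schema still carries the padded "MISSING" column.
empty = pa.table({"MISSING": pa.nulls(0, type=pa.string())})
assert empty.num_rows == 0
assert empty.schema.field("MISSING").type == pa.string()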
63 changes: 63 additions & 0 deletions deltacat/tests/utils/test_pyarrow.py
@@ -85,6 +85,43 @@ def test_s3_partial_parquet_file_to_table_when_schema_passed(self):
         self.assertEqual(result_schema.field(2).type, "int64")
         self.assertEqual(result_schema.field(2).name, "MISSING")
 
+    def test_s3_partial_parquet_file_to_table_when_schema_missing_columns(self):
+
+        pq_file = ParquetFile(PARQUET_FILE_PATH)
+        partial_parquet_params = PartialParquetParameters.of(
+            pq_metadata=pq_file.metadata
+        )
+        # only first row group to be downloaded
+        partial_parquet_params.row_groups_to_download.pop()
+
+        schema = pa.schema(
+            [
+                pa.field("n_legs", pa.string()),
+                pa.field("animal", pa.string()),
+                # NOTE: This field is not in the parquet file, but will be added on as an all-null column
+                pa.field("MISSING", pa.int64()),
+            ]
+        )
+
+        pa_kwargs_provider = lambda content_type, kwargs: {"schema": schema}
+
+        result = s3_partial_parquet_file_to_table(
+            PARQUET_FILE_PATH,
+            ContentType.PARQUET.value,
+            ContentEncoding.IDENTITY.value,
+            pa_read_func_kwargs_provider=pa_kwargs_provider,
+            partial_file_download_params=partial_parquet_params,
+            column_names=["n_legs", "animal", "MISSING"],
+            include_columns=["MISSING"],
+        )
+
+        self.assertEqual(len(result), 0)
+        self.assertEqual(len(result.column_names), 1)
+
+        result_schema = result.schema
+        self.assertEqual(result_schema.field(0).type, "int64")
+        self.assertEqual(result_schema.field(0).name, "MISSING")
+
     def test_s3_partial_parquet_file_to_table_when_schema_passed_with_include_columns(
         self,
     ):
@@ -234,6 +271,32 @@ def test_read_csv_when_column_names_partial(self):
             lambda: pyarrow_read_csv(NON_EMPTY_VALID_UTSV_PATH, **kwargs),
         )
 
+    def test_read_csv_when_excess_columns_included(self):
+
+        schema = pa.schema(
+            [
+                ("is_active", pa.string()),
+                ("ship_datetime_utc", pa.timestamp("us")),
+                ("MISSING", pa.string()),
+            ]
+        )
+        kwargs = content_type_to_reader_kwargs(ContentType.UNESCAPED_TSV.value)
+        _add_column_kwargs(
+            ContentType.UNESCAPED_TSV.value,
+            ["is_active", "ship_datetime_utc", "MISSING"],
+            ["is_active", "ship_datetime_utc", "MISSING"],
+            kwargs,
+        )
+
+        read_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(schema=schema)
+
+        kwargs = read_kwargs_provider(ContentType.UNESCAPED_TSV.value, kwargs)
+
+        self.assertRaises(
+            pa.lib.ArrowInvalid,
+            lambda: pyarrow_read_csv(NON_EMPTY_VALID_UTSV_PATH, **kwargs),
+        )
+
     def test_read_csv_when_empty_csv_sanity(self):
 
         schema = pa.schema(
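
The new CSV test expects failure rather than padding: pyarrow's CSV reader raises ArrowInvalid when column_names promises more columns than the delimited file contains. A self-contained sketch of that behavior using pyarrow.csv directly (the in-memory two-column TSV is a hypothetical stand-in for NON_EMPTY_VALID_UTSV_PATH, which is assumed to hold two columns):

import io

import pyarrow as pa
from pyarrow import csv

data = io.BytesIO(b"True\t2024-07-19T00:00:00\n")  # two columns, not three
try:
    csv.read_csv(
        data,
        read_options=csv.ReadOptions(
            column_names=["is_active", "ship_datetime_utc", "MISSING"]
        ),
        parse_options=csv.ParseOptions(delimiter="\t"),
    )
except pa.lib.ArrowInvalid as err:
    print(err)  # e.g. "CSV parse error: Expected 3 columns, got 2"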
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -8,4 +8,5 @@ pre-commit == 2.20.0
pytest == 7.2.0
pytest-cov == 4.0.0
pytest-mock == 3.14.0
+ray[default] >= 2.20.0
requests-mock == 1.11.0
4 changes: 2 additions & 2 deletions requirements.txt
@@ -2,13 +2,13 @@
# any changes here should also be reflected in setup.py "install_requires"
aws-embedded-metrics == 3.2.0
boto3 ~= 1.34
-getdaft==0.2.29
+getdaft==0.2.31
numpy == 1.21.5
pandas == 1.3.5
pyarrow == 12.0.1
pydantic == 1.10.4
pymemcache == 4.0.0
-ray[default] >= 2.20.0
+ray >= 2.20.0
redis == 4.6.0
s3fs == 2024.5.0
schedule == 1.2.0
4 changes: 2 additions & 2 deletions setup.py
@@ -41,13 +41,13 @@ def find_version(*paths):
"pandas == 1.3.5",
"pyarrow == 12.0.1",
"pydantic == 1.10.4",
"ray[default] >= 2.20.0",
"ray >= 2.20.0",
"s3fs == 2024.5.0",
"tenacity == 8.1.0",
"typing-extensions == 4.4.0",
"pymemcache == 4.0.0",
"redis == 4.6.0",
"getdaft == 0.2.29",
"getdaft == 0.2.31",
"schedule == 1.2.0",
],
setup_requires=["wheel"],
