diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 462374d6..3fc11d91 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,3 +29,5 @@ jobs: uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv - name: Test with tox run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }} + - name: Test with tox for polars extra + run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}-polars diff --git a/docs/for_polars.rst b/docs/for_polars.rst new file mode 100644 index 00000000..327d67fd --- /dev/null +++ b/docs/for_polars.rst @@ -0,0 +1,12 @@ +For Polars +========== + +Gokart also has features for Polars. It is enabled by installing the extra using the following command: + +.. code:: sh + + pip install gokart[polars] + + +You need to set the environment variable ``GOKART_DATAFRAME_FRAMEWORK`` as ``polars`` and you can use Polars for most of the file formats used in :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` feature. +If you don't set ``GOKART_DATAFRAME_FRAMEWORK`` or set it as ``pandas``, you can use pandas for it. 
diff --git a/gokart/file_processor.py b/gokart/file_processor.py index ba468000..e0001e85 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -12,7 +12,6 @@ import luigi.format import numpy as np import pandas as pd -import pandas.errors from luigi.format import TextFormat from gokart.object_storage import ObjectStorage @@ -21,6 +20,15 @@ logger = getLogger(__name__) +DATAFRAME_FRAMEWORK = os.getenv('GOKART_DATAFRAME_FRAMEWORK', 'pandas') +if DATAFRAME_FRAMEWORK == 'polars': + try: + import polars as pl + + except ImportError as e: + raise ValueError('please install polars to use polars as a framework of dataframe for gokart') from e + + class FileProcessor: @abstractmethod def format(self): @@ -131,6 +139,24 @@ def __init__(self, sep=',', encoding: str = 'utf-8'): def format(self): return TextFormat(encoding=self._encoding) + def load(self, file): ... + + def dump(self, obj, file): ... + + +class PolarsCsvFileProcessor(CsvFileProcessor): + def load(self, file): + try: + return pl.read_csv(file, separator=self._sep, encoding=self._encoding) + except pl.exceptions.NoDataError: + return pl.DataFrame() + + def dump(self, obj, file): + assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.' + obj.write_csv(file, separator=self._sep, include_header=True) + + +class PandasCsvFileProcessor(CsvFileProcessor): def load(self, file): try: return pd.read_csv(file, sep=self._sep, encoding=self._encoding) @@ -164,6 +190,34 @@ def __init__(self, orient: str | None = None): def format(self): return luigi.format.Nop + def load(self, file): ... + + def dump(self, obj, file): ... 
+ + +class PolarsJsonFileProcessor(JsonFileProcessor): + def load(self, file): + try: + if self._orient == 'records': + return pl.read_ndjson(file) + return pl.read_json(file) + except pl.exceptions.ComputeError: + return pl.DataFrame() + + def dump(self, obj, file): + assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), ( + f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.' + ) + if isinstance(obj, dict): + obj = pl.from_dict(obj) + + if self._orient == 'records': + obj.write_ndjson(file) + else: + obj.write_json(file) + + +class PandasJsonFileProcessor(JsonFileProcessor): def load(self, file): try: return pd.read_json(file, orient=self._orient, lines=True if self._orient == 'records' else False) @@ -215,11 +269,27 @@ def __init__(self, engine='pyarrow', compression=None): def format(self): return luigi.format.Nop + def load(self, file): ... + + def dump(self, obj, file): ... + + +class PolarsParquetFileProcessor(ParquetFileProcessor): + def load(self, file): + if ObjectStorage.is_buffered_reader(file): + return pl.read_parquet(file.name) + else: + return pl.read_parquet(BytesIO(file.read())) + + def dump(self, obj, file): + assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' + use_pyarrow = self._engine == 'pyarrow' + compression = 'uncompressed' if self._compression is None else self._compression + obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression) + + +class PandasParquetFileProcessor(ParquetFileProcessor): def load(self, file): - # FIXME(mamo3gr): enable streaming (chunked) read with S3. - # pandas.read_parquet accepts file-like object - # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, - # which is needed for pandas to read a file in chunks. 
if ObjectStorage.is_buffered_reader(file): return pd.read_parquet(file.name) else: @@ -240,6 +310,27 @@ def __init__(self, store_index_in_feather: bool): def format(self): return luigi.format.Nop + def load(self, file): ... + + def dump(self, obj, file): ... + + +class PolarsFeatherFileProcessor(FeatherFileProcessor): + def load(self, file): + # Since polars' DataFrame doesn't have index, just load feather file + # TODO: Fix ignoring store_index_in_feather variable + # Currently in PolarsFeatherFileProcessor, we ignore the store_index_in_feather variable to avoid + # a breaking change of FeatherFileProcessor's default behavior. + if ObjectStorage.is_buffered_reader(file): + return pl.read_ipc(file.name) + return pl.read_ipc(BytesIO(file.read())) + + def dump(self, obj, file): + assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' + obj.write_ipc(file.name) + + +class PandasFeatherFileProcessor(FeatherFileProcessor): def load(self, file): # FIXME(mamo3gr): enable streaming (chunked) read with S3. 
# pandas.read_feather accepts file-like object @@ -281,6 +372,18 @@ def dump(self, obj, file): dump_obj.to_feather(file.name) +if DATAFRAME_FRAMEWORK == 'polars': + CsvFileProcessor = PolarsCsvFileProcessor # type: ignore + JsonFileProcessor = PolarsJsonFileProcessor # type: ignore + ParquetFileProcessor = PolarsParquetFileProcessor # type: ignore + FeatherFileProcessor = PolarsFeatherFileProcessor # type: ignore +else: + CsvFileProcessor = PandasCsvFileProcessor # type: ignore + JsonFileProcessor = PandasJsonFileProcessor # type: ignore + ParquetFileProcessor = PandasParquetFileProcessor # type: ignore + FeatherFileProcessor = PandasFeatherFileProcessor # type: ignore + + def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor: extension2processor = { '.txt': TextFileProcessor(), diff --git a/pyproject.toml b/pyproject.toml index 92c6f335..9e60be7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,13 @@ Homepage = "https://github.com/m3dev/gokart" Repository = "https://github.com/m3dev/gokart" Documentation = "https://gokart.readthedocs.io/en/latest/" +[project.optional-dependencies] +polars = [ + # polars doesn't run correctly on Apple silicon chips + "polars-lts-cpu; platform_system =='Darwin'", + "polars; platform_system != 'Darwin'", +] + [dependency-groups] test = [ "fakeredis", diff --git a/test/test_file_processor.py b/test/test_file_processor.py index 7832dd6e..9a609f2f 100644 --- a/test/test_file_processor.py +++ b/test/test_file_processor.py @@ -6,126 +6,26 @@ from typing import Callable import boto3 -import pandas as pd -import pytest from luigi import LocalTarget from moto import mock_aws from gokart.file_processor import ( + DATAFRAME_FRAMEWORK, CsvFileProcessor, FeatherFileProcessor, GzipFileProcessor, JsonFileProcessor, NpzFileProcessor, + PandasCsvFileProcessor, ParquetFileProcessor, PickleFileProcessor, + PolarsCsvFileProcessor, TextFileProcessor, make_file_processor, ) from gokart.object_storage import 
ObjectStorage -class TestCsvFileProcessor(unittest.TestCase): - def test_dump_csv_with_utf8(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor() - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - # read with utf-8 to check if the file is dumped with utf8 - loaded_df = pd.read_csv(temp_path, encoding='utf-8') - pd.testing.assert_frame_equal(df, loaded_df) - - def test_dump_csv_with_cp932(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor(encoding='cp932') - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - # read with cp932 to check if the file is dumped with cp932 - loaded_df = pd.read_csv(temp_path, encoding='cp932') - pd.testing.assert_frame_equal(df, loaded_df) - - def test_load_csv_with_utf8(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor() - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - df.to_csv(temp_path, encoding='utf-8', index=False) - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('r') as f: - # read with utf-8 to check if the file is dumped with utf8 - loaded_df = processor.load(f) - pd.testing.assert_frame_equal(df, loaded_df) - - def test_load_csv_with_cp932(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor(encoding='cp932') - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - df.to_csv(temp_path, encoding='cp932', index=False) - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with 
local_target.open('r') as f: - # read with cp932 to check if the file is dumped with cp932 - loaded_df = processor.load(f) - pd.testing.assert_frame_equal(df, loaded_df) - - -class TestJsonFileProcessor: - @pytest.mark.parametrize( - 'orient,input_data,expected_json', - [ - pytest.param( - None, - pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), - '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', - id='With Default Orient for DataFrame', - ), - pytest.param( - 'records', - pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), - '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', - id='With Records Orient for DataFrame', - ), - pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', id='With Default Orient for Dict'), - pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'), - pytest.param(None, {}, '{}', id='With Default Orient for Empty Dict'), - pytest.param('records', {}, '\n', id='With Records Orient for Empty Dict'), - ], - ) - def test_dump_and_load_json(self, orient, input_data, expected_json): - processor = JsonFileProcessor(orient=orient) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.json' - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(input_data, f) - with local_target.open('r') as f: - loaded_df = processor.load(f) - f.seek(0) - loaded_json = f.read().decode('utf-8') - - assert loaded_json == expected_json - - df_input = pd.DataFrame(input_data) - pd.testing.assert_frame_equal(df_input, loaded_df) - - class TestPickleFileProcessor(unittest.TestCase): def test_dump_and_load_normal_obj(self): var = 'abc' @@ -189,50 +89,12 @@ def test_dump_and_load_with_readables3file(self): self.assertEqual(loaded, var) -class TestFeatherFileProcessor(unittest.TestCase): - def test_feather_should_return_same_dataframe(self): - df = 
pd.DataFrame({'a': [1]}) - processor = FeatherFileProcessor(store_index_in_feather=True) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.feather' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - with local_target.open('r') as f: - loaded_df = processor.load(f) - - pd.testing.assert_frame_equal(df, loaded_df) - - def test_feather_should_save_index_name(self): - df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name')) - processor = FeatherFileProcessor(store_index_in_feather=True) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.feather' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - with local_target.open('r') as f: - loaded_df = processor.load(f) - - pd.testing.assert_frame_equal(df, loaded_df) - - def test_feather_should_raise_error_index_name_is_None(self): - df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None')) - processor = FeatherFileProcessor(store_index_in_feather=True) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.feather' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - with self.assertRaises(AssertionError): - processor.dump(df, f) +class TestFileProcessorClassSelection(unittest.TestCase): + def test_processor_selection(self): + if DATAFRAME_FRAMEWORK == 'polars': + self.assertTrue(issubclass(CsvFileProcessor, PolarsCsvFileProcessor)) + else: + self.assertTrue(issubclass(CsvFileProcessor, PandasCsvFileProcessor)) class TestMakeFileProcessor(unittest.TestCase): diff --git a/test/test_file_processor_pandas.py b/test/test_file_processor_pandas.py new file mode 100644 index 00000000..c2a519e5 --- /dev/null +++ b/test/test_file_processor_pandas.py @@ -0,0 +1,162 @@ +from __future__ import 
annotations + +import importlib.util +import tempfile + +import pandas as pd +import pytest +from luigi import LocalTarget + +from gokart.file_processor import PandasCsvFileProcessor, PandasFeatherFileProcessor, PandasJsonFileProcessor + +polars_installed = importlib.util.find_spec('polars') is not None +pytestmark = pytest.mark.skipif(polars_installed, reason='polars installed, skip pandas tests') + + +def test_dump_csv_with_utf8(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = pd.read_csv(temp_path, encoding='utf-8') + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_dump_csv_with_cp932(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor(encoding='cp932') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # read with cp932 to check if the file is dumped with cp932 + loaded_df = pd.read_csv(temp_path, encoding='cp932') + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_load_csv_with_utf8(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + df.to_csv(temp_path, encoding='utf-8', index=False) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = processor.load(f) + pd.testing.assert_frame_equal(df, loaded_df) + + +def 
test_load_csv_with_cp932(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor(encoding='cp932') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + df.to_csv(temp_path, encoding='cp932', index=False) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + # read with cp932 to check if the file is dumped with cp932 + loaded_df = processor.load(f) + pd.testing.assert_frame_equal(df, loaded_df) + + +@pytest.mark.parametrize( + 'orient,input_data,expected_json', + [ + pytest.param( + None, + pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', + id='With Default Orient for DataFrame', + ), + pytest.param( + 'records', + pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', + id='With Records Orient for DataFrame', + ), + pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', id='With Default Orient for Dict'), + pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'), + pytest.param(None, {}, '{}', id='With Default Orient for Empty Dict'), + pytest.param('records', {}, '\n', id='With Records Orient for Empty Dict'), + ], +) +def test_dump_and_load_json(orient, input_data, expected_json): + processor = PandasJsonFileProcessor(orient=orient) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(input_data, f) + with local_target.open('r') as f: + loaded_df = processor.load(f) + f.seek(0) + loaded_json = f.read().decode('utf-8') + + assert loaded_json == expected_json + + df_input = pd.DataFrame(input_data) + pd.testing.assert_frame_equal(df_input, 
loaded_df) + + +def test_feather_should_return_same_dataframe(): + df = pd.DataFrame({'a': [1]}) + processor = PandasFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_feather_should_save_index_name(): + df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name')) + processor = PandasFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_feather_should_raise_error_index_name_is_None(): + df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None')) + processor = PandasFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + with pytest.raises(AssertionError): + processor.dump(df, f) diff --git a/test/test_file_processor_polars.py b/test/test_file_processor_polars.py new file mode 100644 index 00000000..5508b977 --- /dev/null +++ b/test/test_file_processor_polars.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import importlib +import tempfile + +import pytest +from luigi import LocalTarget + +from gokart import file_processor +from gokart.file_processor import PolarsCsvFileProcessor, PolarsFeatherFileProcessor, PolarsJsonFileProcessor + +pl = 
pytest.importorskip('polars', reason='polars required') +pl_testing = pytest.importorskip('polars.testing', reason='polars required') + + +@pytest.fixture +def reload_processor(monkeypatch): + """ + A pytest fixture that reloads the `gokart.file_processor` module after modifying + the environment variable `GOKART_DATAFRAME_FRAMEWORK`. This ensures that polars + is used when reloading the module. + + Returns: + Tuple[Type[PolarsCsvFileProcessor], Type[PolarsFeatherFileProcessor], Type[PolarsJsonFileProcessor]]: + The reloaded classes from the `gokart.file_processor` module. + """ + monkeypatch.setenv('GOKART_DATAFRAME_FRAMEWORK', 'polars') + importlib.reload(file_processor) + + yield PolarsCsvFileProcessor, PolarsFeatherFileProcessor, PolarsJsonFileProcessor + + +def test_dump_csv(reload_processor): + PolarsCsvFileProcessor, _, _ = reload_processor + df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PolarsCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = pl.read_csv(temp_path) + pl_testing.assert_frame_equal(df, loaded_df) + + +def test_load_csv(reload_processor): + PolarsCsvFileProcessor, _, _ = reload_processor + df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PolarsCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + df.write_csv(temp_path, include_header=True) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = processor.load(f) + pl_testing.assert_frame_equal(df, loaded_df) + + +@pytest.mark.parametrize( + 'orient,input_data,expected_json', + [ + pytest.param( + None, + 
pl.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '[{"A":1,"B":4},{"A":2,"B":5},{"A":3,"B":6}]', + id='With Default Orient for DataFrame', + ), + pytest.param( + 'records', + pl.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', + id='With Records Orient for DataFrame', + ), + pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '[{"A":1,"B":4},{"A":2,"B":5},{"A":3,"B":6}]', id='With Default Orient for Dict'), + pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'), + pytest.param(None, {}, '[]', id='With Default Orient for Empty Dict'), + pytest.param('records', {}, '', id='With Records Orient for Empty Dict'), + ], +) +def test_dump_and_load_json(reload_processor, orient, input_data, expected_json): + _, _, PolarsJsonFileProcessor = reload_processor + processor = PolarsJsonFileProcessor(orient=orient) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(input_data, f) + with local_target.open('r') as f: + loaded_df = processor.load(f) + f.seek(0) + loaded_json = f.read().decode('utf-8') + + assert loaded_json == expected_json + + df_input = pl.DataFrame(input_data) + pl_testing.assert_frame_equal(df_input, loaded_df) + + +def test_feather_should_return_same_dataframe(reload_processor): + _, PolarsFeatherFileProcessor, _ = reload_processor + df = pl.DataFrame({'a': [1]}) + # TODO: currently we set store_index_in_feather True but it is ignored + processor = PolarsFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + 
loaded_df = processor.load(f) + + pl_testing.assert_frame_equal(df, loaded_df) diff --git a/test/test_target.py b/test/test_target.py index 2d82e76f..9d82f721 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -1,3 +1,4 @@ +import importlib.util import io import os import shutil @@ -8,6 +9,7 @@ import boto3 import numpy as np import pandas as pd +import pytest from matplotlib import pyplot from moto import mock_aws @@ -15,6 +17,9 @@ from gokart.target import make_model_target, make_target from test.util import _get_temporary_directory +polars_installed = importlib.util.find_spec('polars') is not None +pytestmark = pytest.mark.skipif(polars_installed, reason='polars installed, skip pandas tests') + class LocalTargetTest(unittest.TestCase): def setUp(self): diff --git a/tox.ini b/tox.ini index 9fdb4765..604aba14 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,7 @@ [tox] envlist = py{39,310,311,312,313},ruff,mypy +labels = + polars = py{39,310,311,312,313}-polars skipsdist = True [testenv] @@ -7,6 +9,8 @@ runner = uv-venv-lock-runner dependency_groups = test commands = {envpython} -m pytest --cov=gokart --cov-report=xml -vv {posargs:} +extras = + polars: polars [testenv:ruff] dependency_groups = lint diff --git a/uv.lock b/uv.lock index 1da8c1a0..da9f8cc0 100644 --- a/uv.lock +++ b/uv.lock @@ -616,6 +616,12 @@ dependencies = [ { name = "uritemplate" }, ] +[package.optional-dependencies] +polars = [ + { name = "polars", marker = "sys_platform != 'darwin'" }, + { name = "polars-lts-cpu", marker = "sys_platform == 'darwin'" }, +] + [package.dev-dependencies] lint = [ { name = "mypy" }, @@ -648,6 +654,8 @@ requires-dist = [ { name = "luigi" }, { name = "numpy" }, { name = "pandas" }, + { name = "polars", marker = "sys_platform != 'darwin' and extra == 'polars'" }, + { name = "polars-lts-cpu", marker = "sys_platform == 'darwin' and extra == 'polars'" }, { name = "pyarrow" }, { name = "redis" }, { name = "slack-sdk" }, @@ -1632,6 +1640,28 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "polars" +version = "1.25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/56/d8a13c3a1990c92cc2c4f1887e97ea15aabf5685b1e826f875ca3e4e6c9e/polars-1.25.2.tar.gz", hash = "sha256:c6bd9b1b17c86e49bcf8aac44d2238b77e414d7df890afc3924812a5c989a4fe", size = 4501858 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cd/92/db411b7c83f694dca1b8348fa57a120c27c67cf622b85fa88c7ecf463adb/polars-1.25.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7fcbb4f476784384ccda48757fca4e8c2e2c5a0a3aef3717aaf56aee4e30e09", size = 35121263 }, + { url = "https://files.pythonhosted.org/packages/9f/a5/5ff200ce3bc643d5f12d91eddb9720fa083267c45fe395bcf0046e97cc2d/polars-1.25.2-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:9dd91885c9ee5ffad8725c8591f73fb7bd2632c740277ee641f0453176b3d4b8", size = 32254697 }, + { url = "https://files.pythonhosted.org/packages/70/d5/7a5458d05d5a0af816b1c7034aa1d026b7b8176a8de41e96dac70fcf29e2/polars-1.25.2-cp39-abi3-win_amd64.whl", hash = "sha256:a547796643b9a56cb2959be87d7cb87ff80a5c8ae9367f32fe1ad717039e9afc", size = 35318381 }, + { url = "https://files.pythonhosted.org/packages/24/df/60d35c4ae8ec357a5fb9914eb253bd1bad9e0f5332eda2bd2c6371dd3668/polars-1.25.2-cp39-abi3-win_arm64.whl", hash = "sha256:a2488e9d4b67bf47b18088f7264999180559e6ec2637ed11f9d0d4f98a74a37c", size = 31619833 }, +] + +[[package]] +name = "polars-lts-cpu" +version = "1.25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/ac/399f210dd334e70eb4a0df25864a97512fa61bea352bdd8a6285c4fbbb63/polars_lts_cpu-1.25.2.tar.gz", hash = 
"sha256:caf4764ceb94457f96166af16d060311bd742cf20a850a8651eac0643a7f41dd", size = 4501478 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/39/0cb03a21d38bc152e09046437990fbcd1973c326c24ce15db88f0a028760/polars_lts_cpu-1.25.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:09a49dbf9782f72ed5cbf3369f26349d3ba98b209688dd8748ac2a1ae14bcf73", size = 34139766 }, + { url = "https://files.pythonhosted.org/packages/29/41/4417553a00e3f0e696b747df34cd4546fe261da2bf99f7ffd88e0d5071c7/polars_lts_cpu-1.25.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:e230385b6b347465196343cd6f167451ca9bbc3a382e9db5f231977d5ab59f00", size = 31327093 }, +] + [[package]] name = "proto-plus" version = "1.26.0"