From 45623d19489ade03ef4e50369561428acfe19fd7 Mon Sep 17 00:00:00 2001 From: hirosassa Date: Tue, 11 Mar 2025 06:59:22 +0900 Subject: [PATCH 01/13] temp implementation --- gokart/file_processor.py | 198 +++++++++++++++++++++++++++------------ pyproject.toml | 4 + 2 files changed, 143 insertions(+), 59 deletions(-) diff --git a/gokart/file_processor.py b/gokart/file_processor.py index 21177ff1..3430354c 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -9,8 +9,6 @@ import luigi.contrib.s3 import luigi.format import numpy as np -import pandas as pd -import pandas.errors from luigi.format import TextFormat from gokart.object_storage import ObjectStorage @@ -131,13 +129,31 @@ def format(self): def load(self, file): try: - return pd.read_csv(file, sep=self._sep, encoding=self._encoding) - except pd.errors.EmptyDataError: - return pd.DataFrame() + import pandas as pd + + try: + return pd.read_csv(file, sep=self._sep, encoding=self._encoding) + except pd.errors.EmptyDataError: + return pd.DataFrame() + except ImportError: + import polars as pl + + try: + return pl.read_csv(file, sep=self._sep, encoding=self._encoding) + except pl.exceptions.NoDataError: + return pd.DataFrame() def dump(self, obj, file): - assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.' - obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding) + try: + import pandas as pd + + assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.' + obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding) + except ImportError: + import polars as pl + + assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.' + obj.write_csv(file, separator=self._sep, include_header=True) class GzipFileProcessor(FileProcessor): @@ -161,17 +177,39 @@ def format(self): def load(self, file): try: - return pd.read_json(file) - except pd.errors.EmptyDataError: - return pd.DataFrame() + import pandas as pd + + try: + return self.read_json(file) + except pd.errors.EmptyDataError: + return pd.DataFrame() + except ImportError: + import polars as pl + + try: + return self.read_json(file) + except pl.exceptions.NoDataError: + return pl.DataFrame() def dump(self, obj, file): - assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), ( - f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' - ) - if isinstance(obj, dict): - obj = pd.DataFrame.from_dict(obj) - obj.to_json(file) + try: + import pandas as pd + + assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), ( + f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' + ) + if isinstance(obj, dict): + obj = pd.DataFrame.from_dict(obj) + obj.to_json(file) + except ImportError: + import polars as pl + + assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), ( + f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.' + ) + if isinstance(obj, dict): + obj = pl.from_dict(obj) + obj.write_json(file) class XmlFileProcessor(FileProcessor): @@ -211,19 +249,39 @@ def format(self): return luigi.format.Nop def load(self, file): - # FIXME(mamo3gr): enable streaming (chunked) read with S3. 
- # pandas.read_parquet accepts file-like object - # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, - # which is needed for pandas to read a file in chunks. - if ObjectStorage.is_buffered_reader(file): - return pd.read_parquet(file.name) - else: - return pd.read_parquet(BytesIO(file.read())) + try: + import pandas as pd + + # FIXME(mamo3gr): enable streaming (chunked) read with S3. + # pandas.read_parquet accepts file-like object + # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, + # which is needed for pandas to read a file in chunks. + if ObjectStorage.is_buffered_reader(file): + return pd.read_parquet(file.name) + else: + return pd.read_parquet(BytesIO(file.read())) + except ImportError: + import polars as pl + + if ObjectStorage.is_buffered_reader(file): + return pl.read_parquet(file.name) + else: + return pl.read_parquet(BytesIO(file.read())) def dump(self, obj, file): - assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' - # MEMO: to_parquet only supports a filepath as string (not a file handle) - obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression) + try: + import pandas as pd + + assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' + # MEMO: to_parquet only supports a filepath as string (not a file handle) + obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression) + except ImportError: + import polars as pl + + assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' + use_pyarrow = self._engine == 'pyarrow' + compression = 'uncompressed' if self._compression is None else self._compression + obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression) class FeatherFileProcessor(FileProcessor): @@ -236,44 +294,66 @@ def format(self): return luigi.format.Nop def load(self, file): - # FIXME(mamo3gr): enable streaming (chunked) read with S3. - # pandas.read_feather accepts file-like object - # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, - # which is needed for pandas to read a file in chunks. - if ObjectStorage.is_buffered_reader(file): - loaded_df = pd.read_feather(file.name) - else: - loaded_df = pd.read_feather(BytesIO(file.read())) - - if self._store_index_in_feather: - if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): - index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX] - index_column = index_columns[0] - index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :] - if index_name == 'None': - index_name = None - loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name) - loaded_df = loaded_df.drop(columns={index_column}) - - return loaded_df + try: + import pandas as pd + + # FIXME(mamo3gr): enable streaming (chunked) read with S3. + # pandas.read_feather accepts file-like object + # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, + # which is needed for pandas to read a file in chunks. 
+ if ObjectStorage.is_buffered_reader(file): + loaded_df = pd.read_feather(file.name) + else: + loaded_df = pd.read_feather(BytesIO(file.read())) + + if self._store_index_in_feather: + if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): + index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX] + index_column = index_columns[0] + index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :] + if index_name == 'None': + index_name = None + loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name) + loaded_df = loaded_df.drop(columns={index_column}) + + return loaded_df + except ImportError: + import polars as pl + + # Since polars' DataFrame doesn't have index, just load feather file + if ObjectStorage.is_buffered_reader(file): + loaded_df = pl.read_ipc(file.name) + else: + loaded_df = pl.read_ipc(BytesIO(file.read())) + + return loaded_df def dump(self, obj, file): - assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' - dump_obj = obj.copy() + try: + import pandas as pd - if self._store_index_in_feather: - index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' - assert index_column_name not in dump_obj.columns, ( - f'column name {index_column_name} already exists in dump_obj. \ + assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' + dump_obj = obj.copy() + + if self._store_index_in_feather: + index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' + assert index_column_name not in dump_obj.columns, ( + f'column name {index_column_name} already exists in dump_obj. \ Consider not saving index by setting store_index_in_feather=False.' - ) - assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' + ) + assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' + + dump_obj[index_column_name] = dump_obj.index + dump_obj = dump_obj.reset_index(drop=True) - dump_obj[index_column_name] = dump_obj.index - dump_obj = dump_obj.reset_index(drop=True) + # to_feather supports "binary" file-like object, but file variable is text + dump_obj.to_feather(file.name) + except ImportError: + import polars as pl - # to_feather supports "binary" file-like object, but file variable is text - dump_obj.to_feather(file.name) + assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' 
+ dump_obj = obj.copy() + dump_obj.write_ipc(file.name) def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor: diff --git a/pyproject.toml b/pyproject.toml index 8afd9235..b71ab103 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,10 @@ Homepage = "https://github.com/m3dev/gokart" Repository = "https://github.com/m3dev/gokart" Documentation = "https://gokart.readthedocs.io/en/latest/" +[project.optional-dependencies] +pandas = ["pandas"] +polars = ["polars"] + [dependency-groups] test = [ "fakeredis", From c62b9b04da87a46a03b7c7947cba58d8dfc4c0a7 Mon Sep 17 00:00:00 2001 From: hirosassa Date: Sun, 16 Mar 2025 08:22:49 +0900 Subject: [PATCH 02/13] add DATAFRAME_FRAMEWORK variable to branch pandas and polars --- gokart/file_processor.py | 285 +++++++++++++++++++++------------------ uv.lock | 24 ++++ 2 files changed, 177 insertions(+), 132 deletions(-) diff --git a/gokart/file_processor.py b/gokart/file_processor.py index 3430354c..602f65bd 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -17,6 +17,16 @@ logger = getLogger(__name__) +try: + import polars as pl + + DATAFRAME_FRAMEWORK = 'polars' +except ImportError: + import pandas as pd + + DATAFRAME_FRAMEWORK = 'pandas' + + class FileProcessor(object): @abstractmethod def format(self): @@ -122,38 +132,39 @@ class CsvFileProcessor(FileProcessor): def __init__(self, sep=',', encoding: str = 'utf-8'): self._sep = sep self._encoding = encoding - super(CsvFileProcessor, self).__init__() + super().__init__() def format(self): return TextFormat(encoding=self._encoding) def load(self, file): - try: - import pandas as pd + ... - try: - return pd.read_csv(file, sep=self._sep, encoding=self._encoding) - except pd.errors.EmptyDataError: - return pd.DataFrame() - except ImportError: - import polars as pl + def dump(self, obj, file): + ... - try: - return pl.read_csv(file, sep=self._sep, encoding=self._encoding) - except pl.exceptions.NoDataError: - return pd.DataFrame() +class PolarsCsvFileProcessor(CsvFileProcessor): + def load(self, file): + try: + return pl.read_csv(file, sep=self._sep, encoding=self._encoding) + except pl.exceptions.NoDataError: + return pl.DataFrame() def dump(self, obj, file): - try: - import pandas as pd + assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.' + obj.write_csv(file, separator=self._sep, include_header=True) + - assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.' - obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding) - except ImportError: - import polars as pl +class PandasCsvFileProcessor(CsvFileProcessor): + def load(self, file): + try: + return pd.read_csv(file, sep=self._sep, encoding=self._encoding) + except pd.errors.EmptyDataError: + return pd.DataFrame() - assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.' - obj.write_csv(file, separator=self._sep, include_header=True) + def dump(self, obj, file): + assert isinstance(obj, (pd.DataFrame, pd.Series)), f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.' + obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding) class GzipFileProcessor(FileProcessor): @@ -176,40 +187,42 @@ def format(self): return None def load(self, file): - try: - import pandas as pd + ... 
- try: - return self.read_json(file) - except pd.errors.EmptyDataError: - return pd.DataFrame() - except ImportError: - import polars as pl + def dump(self, obj, file): + ... - try: - return self.read_json(file) - except pl.exceptions.NoDataError: - return pl.DataFrame() + +class PolarsJsonFileProcessor(JsonFileProcessor): + def load(self, file): + try: + return self.read_json(file) + except pl.exceptions.NoDataError: + return pl.DataFrame() def dump(self, obj, file): + assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), ( + f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.' + ) + if isinstance(obj, dict): + obj = pl.from_dict(obj) + obj.write_json(file) + + +class PandasJsonFileProcessor(JsonFileProcessor): + def load(self, file): try: - import pandas as pd + return self.read_json(file) + except pd.errors.EmptyDataError: + return pd.DataFrame() - assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), ( - f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' - ) - if isinstance(obj, dict): - obj = pd.DataFrame.from_dict(obj) - obj.to_json(file) - except ImportError: - import polars as pl - - assert isinstance(obj, pl.DataFrame) or isinstance(obj, pl.Series) or isinstance(obj, dict), ( - f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.' - ) - if isinstance(obj, dict): - obj = pl.from_dict(obj) - obj.write_json(file) + def dump(self, obj, file): + assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), ( + f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' + ) + if isinstance(obj, dict): + obj = pd.DataFrame.from_dict(obj) + obj.to_json(file) class XmlFileProcessor(FileProcessor): @@ -243,50 +256,47 @@ class ParquetFileProcessor(FileProcessor): def __init__(self, engine='pyarrow', compression=None): self._engine = engine self._compression = compression - super(ParquetFileProcessor, self).__init__() + super().__init__() def format(self): return luigi.format.Nop def load(self, file): - try: - import pandas as pd - - # FIXME(mamo3gr): enable streaming (chunked) read with S3. - # pandas.read_parquet accepts file-like object - # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, - # which is needed for pandas to read a file in chunks. - if ObjectStorage.is_buffered_reader(file): - return pd.read_parquet(file.name) - else: - return pd.read_parquet(BytesIO(file.read())) - except ImportError: - import polars as pl - - if ObjectStorage.is_buffered_reader(file): - return pl.read_parquet(file.name) - else: - return pl.read_parquet(BytesIO(file.read())) + ... def dump(self, obj, file): - try: - import pandas as pd + ... - assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' - # MEMO: to_parquet only supports a filepath as string (not a file handle) - obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression) - except ImportError: - import polars as pl +class PolarsParquetFileProcessor(ParquetFileProcessor): + def load(self, file): + if ObjectStorage.is_buffered_reader(file): + return pl.read_parquet(file.name) + else: + return pl.read_parquet(BytesIO(file.read())) - assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' 
- use_pyarrow = self._engine == 'pyarrow' - compression = 'uncompressed' if self._compression is None else self._compression - obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression) + def dump(self, obj, file): + assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' + use_pyarrow = self._engine == 'pyarrow' + compression = 'uncompressed' if self._compression is None else self._compression + obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression) + + +class PandasParquetFileProcessor(ParquetFileProcessor): + def load(self, file): + if ObjectStorage.is_buffered_reader(file): + return pd.read_parquet(file.name) + else: + return pd.read_parquet(BytesIO(file.read())) + + def dump(self, obj, file): + assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' + # MEMO: to_parquet only supports a filepath as string (not a file handle) + obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression) class FeatherFileProcessor(FileProcessor): def __init__(self, store_index_in_feather: bool): - super(FeatherFileProcessor, self).__init__() + super().__init__() self._store_index_in_feather = store_index_in_feather self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__' @@ -294,67 +304,78 @@ def format(self): return luigi.format.Nop def load(self, file): - try: - import pandas as pd - - # FIXME(mamo3gr): enable streaming (chunked) read with S3. - # pandas.read_feather accepts file-like object - # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, - # which is needed for pandas to read a file in chunks. - if ObjectStorage.is_buffered_reader(file): - loaded_df = pd.read_feather(file.name) - else: - loaded_df = pd.read_feather(BytesIO(file.read())) - - if self._store_index_in_feather: - if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): - index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX] - index_column = index_columns[0] - index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :] - if index_name == 'None': - index_name = None - loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name) - loaded_df = loaded_df.drop(columns={index_column}) - - return loaded_df - except ImportError: - import polars as pl - - # Since polars' DataFrame doesn't have index, just load feather file - if ObjectStorage.is_buffered_reader(file): - loaded_df = pl.read_ipc(file.name) - else: - loaded_df = pl.read_ipc(BytesIO(file.read())) - - return loaded_df + ... def dump(self, obj, file): - try: - import pandas as pd + ... - assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' - dump_obj = obj.copy() - if self._store_index_in_feather: - index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' - assert index_column_name not in dump_obj.columns, ( - f'column name {index_column_name} already exists in dump_obj. \ +class PolarsFeatherFileProcessor(FeatherFileProcessor): + def load(self, file): + # Since polars' DataFrame doesn't have index, just load feather file + if ObjectStorage.is_buffered_reader(file): + loaded_df = pl.read_ipc(file.name) + else: + loaded_df = pl.read_ipc(BytesIO(file.read())) + + def dump(self, obj, file): + assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' 
+ dump_obj = obj.copy() + dump_obj.write_ipc(file.name) + + +class PandasFeatherFileProcessor(FeatherFileProcessor): + def load(self, file): + # FIXME(mamo3gr): enable streaming (chunked) read with S3. + # pandas.read_feather accepts file-like object + # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method, + # which is needed for pandas to read a file in chunks. + if ObjectStorage.is_buffered_reader(file): + loaded_df = pd.read_feather(file.name) + else: + loaded_df = pd.read_feather(BytesIO(file.read())) + + if self._store_index_in_feather: + if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): + index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX] + index_column = index_columns[0] + index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :] + if index_name == 'None': + index_name = None + loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name) + loaded_df = loaded_df.drop(columns={index_column}) + + return loaded_df + + def dump(self, obj, file): + assert isinstance(obj, (pd.DataFrame)), f'requires pd.DataFrame, but {type(obj)} is passed.' + dump_obj = obj.copy() + + if self._store_index_in_feather: + index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' + assert index_column_name not in dump_obj.columns, ( + f'column name {index_column_name} already exists in dump_obj. \ Consider not saving index by setting store_index_in_feather=False.' - ) - assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' + ) + assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' - dump_obj[index_column_name] = dump_obj.index - dump_obj = dump_obj.reset_index(drop=True) + dump_obj[index_column_name] = dump_obj.index + dump_obj = dump_obj.reset_index(drop=True) - # to_feather supports "binary" file-like object, but file variable is text - dump_obj.to_feather(file.name) - except ImportError: - import polars as pl + # to_feather supports "binary" file-like object, but file variable is text + dump_obj.to_feather(file.name) - assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' 
- dump_obj = obj.copy() - dump_obj.write_ipc(file.name) +if DATAFRAME_FRAMEWORK == 'polars': + CsvFileProcessor = PolarsCsvFileProcessor + JsonFileProcessor = PolarsJsonFileProcessor + ParquetFileProcessor = PolarsParquetFileProcessor + FeatherFileProcessor = PolarsFeatherFileProcessor +else: + CsvFileProcessor = PandasCsvFileProcessor + JsonFileProcessor = PandasJsonFileProcessor + ParquetFileProcessor = PandasParquetFileProcessor + FeatherFileProcessor = PandasFeatherFileProcessor def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor: extension2processor = { diff --git a/uv.lock b/uv.lock index 1da8c1a0..1ff3ae0e 100644 --- a/uv.lock +++ b/uv.lock @@ -616,6 +616,14 @@ dependencies = [ { name = "uritemplate" }, ] +[package.optional-dependencies] +pandas = [ + { name = "pandas" }, +] +polars = [ + { name = "polars" }, +] + [package.dev-dependencies] lint = [ { name = "mypy" }, @@ -648,6 +656,8 @@ requires-dist = [ { name = "luigi" }, { name = "numpy" }, { name = "pandas" }, + { name = "pandas", marker = "extra == 'pandas'" }, + { name = "polars", marker = "extra == 'polars'" }, { name = "pyarrow" }, { name = "redis" }, { name = "slack-sdk" }, @@ -1632,6 +1642,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "polars" +version = "1.25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/56/d8a13c3a1990c92cc2c4f1887e97ea15aabf5685b1e826f875ca3e4e6c9e/polars-1.25.2.tar.gz", hash = "sha256:c6bd9b1b17c86e49bcf8aac44d2238b77e414d7df890afc3924812a5c989a4fe", size = 4501858 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/ec/61ae653b7848769baa5c5aaa00f3b3eaedaec56c3f1203a90dafe893a368/polars-1.25.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59f2a34520ea4307a22e18b832310f8045a8a348606ca99ae785499b31eb4170", size = 34539929 }, + { url = "https://files.pythonhosted.org/packages/58/80/54f8cbb048558114ca519d7c40a994130c5a537246923ecce47cf269eaa6/polars-1.25.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:e9fe45bdc2327c2e2b64e8849a992b6d3bd4a7e7848b8a7a3a439cca9674dc87", size = 31326982 }, + { url = "https://files.pythonhosted.org/packages/cd/92/db411b7c83f694dca1b8348fa57a120c27c67cf622b85fa88c7ecf463adb/polars-1.25.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7fcbb4f476784384ccda48757fca4e8c2e2c5a0a3aef3717aaf56aee4e30e09", size = 35121263 }, + { url = "https://files.pythonhosted.org/packages/9f/a5/5ff200ce3bc643d5f12d91eddb9720fa083267c45fe395bcf0046e97cc2d/polars-1.25.2-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:9dd91885c9ee5ffad8725c8591f73fb7bd2632c740277ee641f0453176b3d4b8", size = 32254697 }, + { url = "https://files.pythonhosted.org/packages/70/d5/7a5458d05d5a0af816b1c7034aa1d026b7b8176a8de41e96dac70fcf29e2/polars-1.25.2-cp39-abi3-win_amd64.whl", hash = "sha256:a547796643b9a56cb2959be87d7cb87ff80a5c8ae9367f32fe1ad717039e9afc", size = 35318381 }, + { url = "https://files.pythonhosted.org/packages/24/df/60d35c4ae8ec357a5fb9914eb253bd1bad9e0f5332eda2bd2c6371dd3668/polars-1.25.2-cp39-abi3-win_arm64.whl", hash = "sha256:a2488e9d4b67bf47b18088f7264999180559e6ec2637ed11f9d0d4f98a74a37c", size = 31619833 }, +] + [[package]] name = "proto-plus" version = "1.26.0" From 
665489a91575e6c702d8b0ee0b8209b6c5fedc5d Mon Sep 17 00:00:00 2001 From: hirosassa Date: Sun, 16 Mar 2025 08:56:15 +0900 Subject: [PATCH 03/13] format --- gokart/file_processor.py | 56 ++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/gokart/file_processor.py b/gokart/file_processor.py index eea3c30b..6d46e9b1 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -29,7 +29,7 @@ DATAFRAME_FRAMEWORK = 'pandas' -class FileProcessor(object): +class FileProcessor: @abstractmethod def format(self): pass @@ -139,11 +139,10 @@ def __init__(self, sep=',', encoding: str = 'utf-8'): def format(self): return TextFormat(encoding=self._encoding) - def load(self, file): - ... + def load(self, file): ... + + def dump(self, obj, file): ... - def dump(self, obj, file): - ... class PolarsCsvFileProcessor(CsvFileProcessor): def load(self, file): @@ -191,19 +190,17 @@ def __init__(self, orient: str | None = None): def format(self): return luigi.format.Nop - def load(self, file): - ... + def load(self, file): ... - def dump(self, obj, file): - ... + def dump(self, obj, file): ... class PolarsJsonFileProcessor(JsonFileProcessor): def load(self, file): try: if self._orient == 'records': - return self.read_ndjson(file) - return self.read_json(file) + return pl.read_ndjson(file) + return pl.read_json(file) except pl.exceptions.NoDataError: return pl.DataFrame() @@ -215,7 +212,7 @@ def dump(self, obj, file): obj = pl.from_dict(obj) if self._orient == 'records': - obj_write_ndjson(file) + obj.write_ndjson(file) else: obj.write_json(file) @@ -272,11 +269,10 @@ def __init__(self, engine='pyarrow', compression=None): def format(self): return luigi.format.Nop - def load(self, file): - ... + def load(self, file): ... + + def dump(self, obj, file): ... - def dump(self, obj, file): - ... class PolarsParquetFileProcessor(ParquetFileProcessor): def load(self, file): @@ -314,20 +310,17 @@ def __init__(self, store_index_in_feather: bool): def format(self): return luigi.format.Nop - def load(self, file): - ... + def load(self, file): ... - def dump(self, obj, file): - ... + def dump(self, obj, file): ... class PolarsFeatherFileProcessor(FeatherFileProcessor): def load(self, file): # Since polars' DataFrame doesn't have index, just load feather file if ObjectStorage.is_buffered_reader(file): - loaded_df = pl.read_ipc(file.name) - else: - loaded_df = pl.read_ipc(BytesIO(file.read())) + return pl.read_ipc(file.name) + return pl.read_ipc(BytesIO(file.read())) def dump(self, obj, file): assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.' 
@@ -378,15 +371,16 @@
 if DATAFRAME_FRAMEWORK == 'polars':
-    CsvFileProcessor = PolarsCsvFileProcessor
-    JsonFileProcessor = PolarsJsonFileProcessor
-    ParquetFileProcessor = PolarsParquetFileProcessor
-    FeatherFileProcessor = PolarsFeatherFileProcessor
+    CsvFileProcessor = PolarsCsvFileProcessor  # type: ignore
+    JsonFileProcessor = PolarsJsonFileProcessor  # type: ignore
+    ParquetFileProcessor = PolarsParquetFileProcessor  # type: ignore
+    FeatherFileProcessor = PolarsFeatherFileProcessor  # type: ignore
 else:
-    CsvFileProcessor = PandasCsvFileProcessor
-    JsonFileProcessor = PandasJsonFileProcessor
-    ParquetFileProcessor = PandasParquetFileProcessor
-    FeatherFileProcessor = PandasFeatherFileProcessor
+    CsvFileProcessor = PandasCsvFileProcessor  # type: ignore
+    JsonFileProcessor = PandasJsonFileProcessor  # type: ignore
+    ParquetFileProcessor = PandasParquetFileProcessor  # type: ignore
+    FeatherFileProcessor = PandasFeatherFileProcessor  # type: ignore
+
 
 def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
     extension2processor = {

From 4ab9c351cdd16f0ceb00d331aeb5c8d465805b7a Mon Sep 17 00:00:00 2001
From: hirosassa
Date: Sat, 22 Mar 2025 17:13:12 +0900
Subject: [PATCH 04/13] fix tests

---
 gokart/file_processor.py           |  10 +-
 pyproject.toml                     |   3 +-
 test/test_file_processor.py        | 156 ++------------------------
 test/test_file_processor_pandas.py | 162 +++++++++++++++++++++++++++++
 test/test_file_processor_polars.py | 100 ++++++++++++++++++
 tox.ini                            |   4 +
 uv.lock                            |  24 ++---
 7 files changed, 292 insertions(+), 167 deletions(-)
 create mode 100644 test/test_file_processor_pandas.py
 create mode 100644 test/test_file_processor_polars.py

diff --git a/gokart/file_processor.py b/gokart/file_processor.py
index 6d46e9b1..35aaf1e5 100644
--- a/gokart/file_processor.py
+++ b/gokart/file_processor.py
@@ -147,7 +147,7 @@ def dump(self, obj, file): ...
 class PolarsCsvFileProcessor(CsvFileProcessor):
     def load(self, file):
         try:
-            return pl.read_csv(file, sep=self._sep, encoding=self._encoding)
+            return pl.read_csv(file, separator=self._sep, encoding=self._encoding)
         except pl.exceptions.NoDataError:
             return pl.DataFrame()
 
@@ -201,7 +201,7 @@ def load(self, file):
             if self._orient == 'records':
                 return pl.read_ndjson(file)
             return pl.read_json(file)
-        except pl.exceptions.NoDataError:
+        except pl.exceptions.ComputeError:
             return pl.DataFrame()
 
     def dump(self, obj, file):
@@ -318,14 +318,16 @@ def dump(self, obj, file): ...
 class PolarsFeatherFileProcessor(FeatherFileProcessor):
     def load(self, file):
         # Since polars' DataFrame doesn't have index, just load feather file
+        # TODO: Fix ignoring the store_index_in_feather variable
+        # Currently PolarsFeatherFileProcessor ignores the store_index_in_feather variable to avoid
+        # a breaking change to FeatherFileProcessor's default behavior.
         if ObjectStorage.is_buffered_reader(file):
             return pl.read_ipc(file.name)
         return pl.read_ipc(BytesIO(file.read()))
 
     def dump(self, obj, file):
         assert isinstance(obj, (pl.DataFrame)), f'requires pl.DataFrame, but {type(obj)} is passed.'
- dump_obj = obj.copy() - dump_obj.write_ipc(file.name) + obj.write_ipc(file.name) class PandasFeatherFileProcessor(FeatherFileProcessor): diff --git a/pyproject.toml b/pyproject.toml index 00f51439..9f5535e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,7 @@ Repository = "https://github.com/m3dev/gokart" Documentation = "https://gokart.readthedocs.io/en/latest/" [project.optional-dependencies] -pandas = ["pandas"] -polars = ["polars"] +polars = ["polars-lts-cpu"] [dependency-groups] test = [ diff --git a/test/test_file_processor.py b/test/test_file_processor.py index 7832dd6e..9a609f2f 100644 --- a/test/test_file_processor.py +++ b/test/test_file_processor.py @@ -6,126 +6,26 @@ from typing import Callable import boto3 -import pandas as pd -import pytest from luigi import LocalTarget from moto import mock_aws from gokart.file_processor import ( + DATAFRAME_FRAMEWORK, CsvFileProcessor, FeatherFileProcessor, GzipFileProcessor, JsonFileProcessor, NpzFileProcessor, + PandasCsvFileProcessor, ParquetFileProcessor, PickleFileProcessor, + PolarsCsvFileProcessor, TextFileProcessor, make_file_processor, ) from gokart.object_storage import ObjectStorage -class TestCsvFileProcessor(unittest.TestCase): - def test_dump_csv_with_utf8(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor() - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - # read with utf-8 to check if the file is dumped with utf8 - loaded_df = pd.read_csv(temp_path, encoding='utf-8') - pd.testing.assert_frame_equal(df, loaded_df) - - def test_dump_csv_with_cp932(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor(encoding='cp932') - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - # read with cp932 to check if the file is dumped with cp932 - loaded_df = pd.read_csv(temp_path, encoding='cp932') - pd.testing.assert_frame_equal(df, loaded_df) - - def test_load_csv_with_utf8(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor() - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - df.to_csv(temp_path, encoding='utf-8', index=False) - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('r') as f: - # read with utf-8 to check if the file is dumped with utf8 - loaded_df = processor.load(f) - pd.testing.assert_frame_equal(df, loaded_df) - - def test_load_csv_with_cp932(self): - df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) - processor = CsvFileProcessor(encoding='cp932') - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.csv' - df.to_csv(temp_path, encoding='cp932', index=False) - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('r') as f: - # read with cp932 to check if the file is dumped with cp932 - loaded_df = processor.load(f) - pd.testing.assert_frame_equal(df, loaded_df) - - -class TestJsonFileProcessor: - @pytest.mark.parametrize( - 'orient,input_data,expected_json', - [ - pytest.param( - None, - pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), - 
'{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', - id='With Default Orient for DataFrame', - ), - pytest.param( - 'records', - pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), - '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', - id='With Records Orient for DataFrame', - ), - pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', id='With Default Orient for Dict'), - pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'), - pytest.param(None, {}, '{}', id='With Default Orient for Empty Dict'), - pytest.param('records', {}, '\n', id='With Records Orient for Empty Dict'), - ], - ) - def test_dump_and_load_json(self, orient, input_data, expected_json): - processor = JsonFileProcessor(orient=orient) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.json' - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(input_data, f) - with local_target.open('r') as f: - loaded_df = processor.load(f) - f.seek(0) - loaded_json = f.read().decode('utf-8') - - assert loaded_json == expected_json - - df_input = pd.DataFrame(input_data) - pd.testing.assert_frame_equal(df_input, loaded_df) - - class TestPickleFileProcessor(unittest.TestCase): def test_dump_and_load_normal_obj(self): var = 'abc' @@ -189,50 +89,12 @@ def test_dump_and_load_with_readables3file(self): self.assertEqual(loaded, var) -class TestFeatherFileProcessor(unittest.TestCase): - def test_feather_should_return_same_dataframe(self): - df = pd.DataFrame({'a': [1]}) - processor = FeatherFileProcessor(store_index_in_feather=True) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.feather' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - with local_target.open('r') as f: - loaded_df = processor.load(f) - - pd.testing.assert_frame_equal(df, loaded_df) - - def test_feather_should_save_index_name(self): - df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name')) - processor = FeatherFileProcessor(store_index_in_feather=True) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.feather' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - processor.dump(df, f) - - with local_target.open('r') as f: - loaded_df = processor.load(f) - - pd.testing.assert_frame_equal(df, loaded_df) - - def test_feather_should_raise_error_index_name_is_None(self): - df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None')) - processor = FeatherFileProcessor(store_index_in_feather=True) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = f'{temp_dir}/temp.feather' - - local_target = LocalTarget(path=temp_path, format=processor.format()) - with local_target.open('w') as f: - with self.assertRaises(AssertionError): - processor.dump(df, f) +class TestFileProcessorClassSelection(unittest.TestCase): + def test_processor_selection(self): + if DATAFRAME_FRAMEWORK == 'polars': + self.assertTrue(issubclass(CsvFileProcessor, PolarsCsvFileProcessor)) + else: + self.assertTrue(issubclass(CsvFileProcessor, PandasCsvFileProcessor)) class TestMakeFileProcessor(unittest.TestCase): diff --git a/test/test_file_processor_pandas.py b/test/test_file_processor_pandas.py new file mode 100644 index 00000000..ad892ff6 --- 
/dev/null +++ b/test/test_file_processor_pandas.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import importlib +import tempfile + +import pandas as pd +import pytest +from luigi import LocalTarget + +from gokart.file_processor import PandasCsvFileProcessor, PandasFeatherFileProcessor, PandasJsonFileProcessor + +polars_installed = importlib.util.find_spec('polars') is not None +pytestmark = pytest.mark.skipif(polars_installed, reason='polars installed, skip pandas tests') + + +def test_dump_csv_with_utf8(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = pd.read_csv(temp_path, encoding='utf-8') + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_dump_csv_with_cp932(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor(encoding='cp932') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # read with cp932 to check if the file is dumped with cp932 + loaded_df = pd.read_csv(temp_path, encoding='cp932') + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_load_csv_with_utf8(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + df.to_csv(temp_path, encoding='utf-8', index=False) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = processor.load(f) + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_load_csv_with_cp932(): + df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PandasCsvFileProcessor(encoding='cp932') + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + df.to_csv(temp_path, encoding='cp932', index=False) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + # read with cp932 to check if the file is dumped with cp932 + loaded_df = processor.load(f) + pd.testing.assert_frame_equal(df, loaded_df) + + +@pytest.mark.parametrize( + 'orient,input_data,expected_json', + [ + pytest.param( + None, + pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', + id='With Default Orient for DataFrame', + ), + pytest.param( + 'records', + pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', + id='With Records Orient for DataFrame', + ), + pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', id='With Default Orient for Dict'), + pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'), + pytest.param(None, {}, '{}', id='With Default Orient for Empty Dict'), + pytest.param('records', {}, '\n', id='With Records Orient for Empty Dict'), + ], +) +def test_dump_and_load_json(orient, input_data, expected_json): + processor 
= PandasJsonFileProcessor(orient=orient) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(input_data, f) + with local_target.open('r') as f: + loaded_df = processor.load(f) + f.seek(0) + loaded_json = f.read().decode('utf-8') + + assert loaded_json == expected_json + + df_input = pd.DataFrame(input_data) + pd.testing.assert_frame_equal(df_input, loaded_df) + + +def test_feather_should_return_same_dataframe(): + df = pd.DataFrame({'a': [1]}) + processor = PandasFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_feather_should_save_index_name(): + df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name')) + processor = PandasFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + pd.testing.assert_frame_equal(df, loaded_df) + + +def test_feather_should_raise_error_index_name_is_None(): + df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None')) + processor = PandasFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + with pytest.raises(AssertionError): + processor.dump(df, f) diff --git a/test/test_file_processor_polars.py b/test/test_file_processor_polars.py new file mode 100644 index 00000000..8f3a70f6 --- /dev/null +++ b/test/test_file_processor_polars.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import tempfile + +import pytest +from luigi import LocalTarget + +from gokart.file_processor import PolarsCsvFileProcessor, PolarsFeatherFileProcessor, PolarsJsonFileProcessor + +pl = pytest.importorskip('polars', reason='polars required') +pl_testing = pytest.importorskip('polars.testing', reason='polars required') + + +def test_dump_csv_with(): + df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PolarsCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = pl.read_csv(temp_path) + pl_testing.assert_frame_equal(df, loaded_df) + + +def test_load_csv(): + df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]}) + processor = PolarsCsvFileProcessor() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.csv' + df.write_csv(temp_path, include_header=True) + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('r') as f: + # read with utf-8 to check if the file is dumped with utf8 + loaded_df = processor.load(f) + 
pl_testing.assert_frame_equal(df, loaded_df) + + +@pytest.mark.parametrize( + 'orient,input_data,expected_json', + [ + pytest.param( + None, + pl.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '[{"A":1,"B":4},{"A":2,"B":5},{"A":3,"B":6}]', + id='With Default Orient for DataFrame', + ), + pytest.param( + 'records', + pl.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}), + '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', + id='With Records Orient for DataFrame', + ), + pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '[{"A":1,"B":4},{"A":2,"B":5},{"A":3,"B":6}]', id='With Default Orient for Dict'), + pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'), + pytest.param(None, {}, '[]', id='With Default Orient for Empty Dict'), + pytest.param('records', {}, '', id='With Records Orient for Empty Dict'), + ], +) +def test_dump_and_load_json(orient, input_data, expected_json): + processor = PolarsJsonFileProcessor(orient=orient) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.json' + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(input_data, f) + with local_target.open('r') as f: + loaded_df = processor.load(f) + f.seek(0) + loaded_json = f.read().decode('utf-8') + + assert loaded_json == expected_json + + df_input = pl.DataFrame(input_data) + pl_testing.assert_frame_equal(df_input, loaded_df) + + +def test_feather_should_return_same_dataframe(): + df = pl.DataFrame({'a': [1]}) + # TODO: currently we set store_index_in_feather True but it is ignored + processor = PolarsFeatherFileProcessor(store_index_in_feather=True) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = f'{temp_dir}/temp.feather' + + local_target = LocalTarget(path=temp_path, format=processor.format()) + with local_target.open('w') as f: + processor.dump(df, f) + + with local_target.open('r') as f: + loaded_df = processor.load(f) + + pl_testing.assert_frame_equal(df, loaded_df) diff --git a/tox.ini b/tox.ini index 5fcd7251..585b3d8c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,7 @@ [tox] envlist = py{39,310,311,312,313},ruff,mypy +labels = + polars = py{39,310,311,312,313}-polars skipsdits = True [testenv] @@ -7,6 +9,8 @@ runner = uv-venv-lock-runner dependency_groups = test commands = {envpython} -m pytest --cov=gokart --cov-report=xml -vv {posargs:} +extras = + polars: polars [testenv:ruff] dependency_groups = lint diff --git a/uv.lock b/uv.lock index 1ff3ae0e..6e6769d6 100644 --- a/uv.lock +++ b/uv.lock @@ -617,11 +617,8 @@ dependencies = [ ] [package.optional-dependencies] -pandas = [ - { name = "pandas" }, -] polars = [ - { name = "polars" }, + { name = "polars-lts-cpu" }, ] [package.dev-dependencies] @@ -656,8 +653,7 @@ requires-dist = [ { name = "luigi" }, { name = "numpy" }, { name = "pandas" }, - { name = "pandas", marker = "extra == 'pandas'" }, - { name = "polars", marker = "extra == 'polars'" }, + { name = "polars-lts-cpu", marker = "extra == 'polars'" }, { name = "pyarrow" }, { name = "redis" }, { name = "slack-sdk" }, @@ -1643,17 +1639,17 @@ wheels = [ ] [[package]] -name = "polars" +name = "polars-lts-cpu" version = "1.25.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/57/56/d8a13c3a1990c92cc2c4f1887e97ea15aabf5685b1e826f875ca3e4e6c9e/polars-1.25.2.tar.gz", hash = "sha256:c6bd9b1b17c86e49bcf8aac44d2238b77e414d7df890afc3924812a5c989a4fe", size = 
4501858 } +sdist = { url = "https://files.pythonhosted.org/packages/ea/ac/399f210dd334e70eb4a0df25864a97512fa61bea352bdd8a6285c4fbbb63/polars_lts_cpu-1.25.2.tar.gz", hash = "sha256:caf4764ceb94457f96166af16d060311bd742cf20a850a8651eac0643a7f41dd", size = 4501478 } wheels = [ - { url = "https://files.pythonhosted.org/packages/bd/ec/61ae653b7848769baa5c5aaa00f3b3eaedaec56c3f1203a90dafe893a368/polars-1.25.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59f2a34520ea4307a22e18b832310f8045a8a348606ca99ae785499b31eb4170", size = 34539929 }, - { url = "https://files.pythonhosted.org/packages/58/80/54f8cbb048558114ca519d7c40a994130c5a537246923ecce47cf269eaa6/polars-1.25.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:e9fe45bdc2327c2e2b64e8849a992b6d3bd4a7e7848b8a7a3a439cca9674dc87", size = 31326982 }, - { url = "https://files.pythonhosted.org/packages/cd/92/db411b7c83f694dca1b8348fa57a120c27c67cf622b85fa88c7ecf463adb/polars-1.25.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7fcbb4f476784384ccda48757fca4e8c2e2c5a0a3aef3717aaf56aee4e30e09", size = 35121263 }, - { url = "https://files.pythonhosted.org/packages/9f/a5/5ff200ce3bc643d5f12d91eddb9720fa083267c45fe395bcf0046e97cc2d/polars-1.25.2-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:9dd91885c9ee5ffad8725c8591f73fb7bd2632c740277ee641f0453176b3d4b8", size = 32254697 }, - { url = "https://files.pythonhosted.org/packages/70/d5/7a5458d05d5a0af816b1c7034aa1d026b7b8176a8de41e96dac70fcf29e2/polars-1.25.2-cp39-abi3-win_amd64.whl", hash = "sha256:a547796643b9a56cb2959be87d7cb87ff80a5c8ae9367f32fe1ad717039e9afc", size = 35318381 }, - { url = "https://files.pythonhosted.org/packages/24/df/60d35c4ae8ec357a5fb9914eb253bd1bad9e0f5332eda2bd2c6371dd3668/polars-1.25.2-cp39-abi3-win_arm64.whl", hash = "sha256:a2488e9d4b67bf47b18088f7264999180559e6ec2637ed11f9d0d4f98a74a37c", size = 31619833 }, + { url = "https://files.pythonhosted.org/packages/e7/39/0cb03a21d38bc152e09046437990fbcd1973c326c24ce15db88f0a028760/polars_lts_cpu-1.25.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:09a49dbf9782f72ed5cbf3369f26349d3ba98b209688dd8748ac2a1ae14bcf73", size = 34139766 }, + { url = "https://files.pythonhosted.org/packages/29/41/4417553a00e3f0e696b747df34cd4546fe261da2bf99f7ffd88e0d5071c7/polars_lts_cpu-1.25.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:e230385b6b347465196343cd6f167451ca9bbc3a382e9db5f231977d5ab59f00", size = 31327093 }, + { url = "https://files.pythonhosted.org/packages/ee/92/341df9d602dabca91fe9c5e7c0ab04c23f3c2c77bc6324b9af12d2637990/polars_lts_cpu-1.25.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d86c7a09f9afb68b0a148ce0d747c47e0c7416ec8195fa37be9fa5cb253dfe8", size = 34720087 }, + { url = "https://files.pythonhosted.org/packages/ac/30/6bce9e661d65c884c8b6fa952b29324e7cb382d7e982d9b58608d7abfc4e/polars_lts_cpu-1.25.2-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:c4bf8ca23e6fe61d39dc2e1aec5d7de4def772c4e903d4b0b2c775f029b1aadb", size = 32254769 }, + { url = "https://files.pythonhosted.org/packages/0e/27/5ad74666b3a9053b6d93949854dfb686bccecfb9bdccad62d198d96e0f9c/polars_lts_cpu-1.25.2-cp39-abi3-win_amd64.whl", hash = "sha256:319e8321f8428f38cfdd1d7a1295d62dbbfcd2fa0837daf897e7c0fd34bcc472", size = 35179860 }, + { url = "https://files.pythonhosted.org/packages/82/0c/15b187a05bd7feb7f79fd36d6c917bd62b335fd10fb381d4aa9f70849767/polars_lts_cpu-1.25.2-cp39-abi3-win_arm64.whl", hash = "sha256:eb9aa9b6b9935224cb304576c955bd5e107d6b5251f53f5f55abfc51e42ce23f", size = 
31619912 }, ] [[package]] From c55c3abb9c6d4d7e207d7360271fbc04e91393f0 Mon Sep 17 00:00:00 2001 From: hirosassa Date: Sat, 22 Mar 2025 17:49:21 +0900 Subject: [PATCH 05/13] fix workflow --- .github/workflows/test.yml | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cd5d08ba..ec141f77 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,17 +11,9 @@ jobs: strategy: max-parallel: 7 matrix: + platform: ["ubuntu-latest"] + tox-env: ["py39", "py310", "py311", "py312", "py313"] include: - - platform: ubuntu-latest - tox-env: "py39" - - platform: ubuntu-latest - tox-env: "py310" - - platform: ubuntu-latest - tox-env: "py311" - - platform: ubuntu-latest - tox-env: "py312" - - platform: ubuntu-latest - tox-env: "py313" # test only on latest python for macos - platform: macos-13 tox-env: "py313" @@ -38,3 +30,5 @@ jobs: uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv - name: Test with tox run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }} + - name: Test with tox for polars extra + run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}-polars From 94e9fc8742012038dbdfc51aea1acc45d1a4018f Mon Sep 17 00:00:00 2001 From: hirosassa Date: Sat, 22 Mar 2025 20:08:08 +0900 Subject: [PATCH 06/13] fix tests --- test/test_file_processor_pandas.py | 2 +- test/test_target.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_file_processor_pandas.py b/test/test_file_processor_pandas.py index ad892ff6..c2a519e5 100644 --- a/test/test_file_processor_pandas.py +++ b/test/test_file_processor_pandas.py @@ -1,6 +1,6 @@ from __future__ import annotations -import importlib +import importlib.util import tempfile import pandas as pd diff --git a/test/test_target.py b/test/test_target.py index 2d82e76f..9d82f721 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -1,3 +1,4 @@ +import importlib.util import io import os import shutil @@ -8,6 +9,7 @@ import boto3 import numpy as np import pandas as pd +import pytest from matplotlib import pyplot from moto import mock_aws @@ -15,6 +17,9 @@ from gokart.target import make_model_target, make_target from test.util import _get_temporary_directory +polars_installed = importlib.util.find_spec('polars') is not None +pytestmark = pytest.mark.skipif(polars_installed, reason='polars installed, skip pandas tests') + class LocalTargetTest(unittest.TestCase): def setUp(self): From 1a344e175ac22c29bea06b3f9070e288d9687a42 Mon Sep 17 00:00:00 2001 From: hirosassa Date: Sat, 22 Mar 2025 21:02:28 +0900 Subject: [PATCH 07/13] fix dependencies --- pyproject.toml | 6 +++++- uv.lock | 22 ++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9f5535e5..9e60be7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,11 @@ Repository = "https://github.com/m3dev/gokart" Documentation = "https://gokart.readthedocs.io/en/latest/" [project.optional-dependencies] -polars = ["polars-lts-cpu"] +polars = [ + # polars doesn't run correctly on Apple silicon chips + "polars-lts-cpu; platform_system =='Darwin'", + "polars; platform_system != 'Darwin'", +] [dependency-groups] test = [ diff --git a/uv.lock b/uv.lock index 6e6769d6..da9f8cc0 100644 --- a/uv.lock +++ b/uv.lock @@ -618,7 +618,8 @@ dependencies = [ [package.optional-dependencies] polars = [ - { name = "polars-lts-cpu" }, + { name = "polars", marker = "sys_platform != 'darwin'" }, + 
{ name = "polars-lts-cpu", marker = "sys_platform == 'darwin'" }, ] [package.dev-dependencies] @@ -653,7 +654,8 @@ requires-dist = [ { name = "luigi" }, { name = "numpy" }, { name = "pandas" }, - { name = "polars-lts-cpu", marker = "extra == 'polars'" }, + { name = "polars", marker = "sys_platform != 'darwin' and extra == 'polars'" }, + { name = "polars-lts-cpu", marker = "sys_platform == 'darwin' and extra == 'polars'" }, { name = "pyarrow" }, { name = "redis" }, { name = "slack-sdk" }, @@ -1638,6 +1640,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "polars" +version = "1.25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/56/d8a13c3a1990c92cc2c4f1887e97ea15aabf5685b1e826f875ca3e4e6c9e/polars-1.25.2.tar.gz", hash = "sha256:c6bd9b1b17c86e49bcf8aac44d2238b77e414d7df890afc3924812a5c989a4fe", size = 4501858 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cd/92/db411b7c83f694dca1b8348fa57a120c27c67cf622b85fa88c7ecf463adb/polars-1.25.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7fcbb4f476784384ccda48757fca4e8c2e2c5a0a3aef3717aaf56aee4e30e09", size = 35121263 }, + { url = "https://files.pythonhosted.org/packages/9f/a5/5ff200ce3bc643d5f12d91eddb9720fa083267c45fe395bcf0046e97cc2d/polars-1.25.2-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:9dd91885c9ee5ffad8725c8591f73fb7bd2632c740277ee641f0453176b3d4b8", size = 32254697 }, + { url = "https://files.pythonhosted.org/packages/70/d5/7a5458d05d5a0af816b1c7034aa1d026b7b8176a8de41e96dac70fcf29e2/polars-1.25.2-cp39-abi3-win_amd64.whl", hash = "sha256:a547796643b9a56cb2959be87d7cb87ff80a5c8ae9367f32fe1ad717039e9afc", size = 35318381 }, + { url = "https://files.pythonhosted.org/packages/24/df/60d35c4ae8ec357a5fb9914eb253bd1bad9e0f5332eda2bd2c6371dd3668/polars-1.25.2-cp39-abi3-win_arm64.whl", hash = "sha256:a2488e9d4b67bf47b18088f7264999180559e6ec2637ed11f9d0d4f98a74a37c", size = 31619833 }, +] + [[package]] name = "polars-lts-cpu" version = "1.25.2" @@ -1646,10 +1660,6 @@ sdist = { url = "https://files.pythonhosted.org/packages/ea/ac/399f210dd334e70eb wheels = [ { url = "https://files.pythonhosted.org/packages/e7/39/0cb03a21d38bc152e09046437990fbcd1973c326c24ce15db88f0a028760/polars_lts_cpu-1.25.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:09a49dbf9782f72ed5cbf3369f26349d3ba98b209688dd8748ac2a1ae14bcf73", size = 34139766 }, { url = "https://files.pythonhosted.org/packages/29/41/4417553a00e3f0e696b747df34cd4546fe261da2bf99f7ffd88e0d5071c7/polars_lts_cpu-1.25.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:e230385b6b347465196343cd6f167451ca9bbc3a382e9db5f231977d5ab59f00", size = 31327093 }, - { url = "https://files.pythonhosted.org/packages/ee/92/341df9d602dabca91fe9c5e7c0ab04c23f3c2c77bc6324b9af12d2637990/polars_lts_cpu-1.25.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d86c7a09f9afb68b0a148ce0d747c47e0c7416ec8195fa37be9fa5cb253dfe8", size = 34720087 }, - { url = "https://files.pythonhosted.org/packages/ac/30/6bce9e661d65c884c8b6fa952b29324e7cb382d7e982d9b58608d7abfc4e/polars_lts_cpu-1.25.2-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:c4bf8ca23e6fe61d39dc2e1aec5d7de4def772c4e903d4b0b2c775f029b1aadb", size = 32254769 }, - { url = 
"https://files.pythonhosted.org/packages/0e/27/5ad74666b3a9053b6d93949854dfb686bccecfb9bdccad62d198d96e0f9c/polars_lts_cpu-1.25.2-cp39-abi3-win_amd64.whl", hash = "sha256:319e8321f8428f38cfdd1d7a1295d62dbbfcd2fa0837daf897e7c0fd34bcc472", size = 35179860 }, - { url = "https://files.pythonhosted.org/packages/82/0c/15b187a05bd7feb7f79fd36d6c917bd62b335fd10fb381d4aa9f70849767/polars_lts_cpu-1.25.2-cp39-abi3-win_arm64.whl", hash = "sha256:eb9aa9b6b9935224cb304576c955bd5e107d6b5251f53f5f55abfc51e42ce23f", size = 31619912 }, ] [[package]] From d70f28530734f42b71fec5b0501e68bbeffbd1d0 Mon Sep 17 00:00:00 2001 From: hirosassa Date: Sat, 22 Mar 2025 21:42:50 +0900 Subject: [PATCH 08/13] docs --- docs/for_polars.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 docs/for_polars.rst diff --git a/docs/for_polars.rst b/docs/for_polars.rst new file mode 100644 index 00000000..3582da7c --- /dev/null +++ b/docs/for_polars.rst @@ -0,0 +1,11 @@ +For Pandas +========== + +Gokart also has features for Polars. It is enabled by installing extra using following command: + +.. code:: sh + + pip install gokart[polars] + + +You can use Polars for the most of the file format used in :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` feature. From 8a9063fd551fd6db866bf12fe4ee4b1f6f1dfdcd Mon Sep 17 00:00:00 2001 From: hirosassa Date: Sat, 29 Mar 2025 14:27:45 +0900 Subject: [PATCH 09/13] use env var to switch frameworks --- docs/for_polars.rst | 2 +- gokart/file_processor.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/for_polars.rst b/docs/for_polars.rst index 3582da7c..91aa194f 100644 --- a/docs/for_polars.rst +++ b/docs/for_polars.rst @@ -8,4 +8,4 @@ Gokart also has features for Polars. It is enabled by installing extra using fol pip install gokart[polars] -You can use Polars for the most of the file format used in :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` feature. +You need to set the environment variable ``GOKART_DATAFRAME_FRAMWORK_POLARS_ENABLED`` as ``true`` and you can use Polars for the most of the file format used in :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` feature. diff --git a/gokart/file_processor.py b/gokart/file_processor.py index 35aaf1e5..a66bafad 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -10,6 +10,7 @@ import luigi import luigi.contrib.s3 import luigi.format +import pandas as pd import numpy as np from luigi.format import TextFormat @@ -22,10 +23,11 @@ try: import polars as pl - DATAFRAME_FRAMEWORK = 'polars' -except ImportError: - import pandas as pd - + if os.getenv('GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED') == 'true': + DATAFRAME_FRAMEWORK = 'polars' + else: + raise ValueError('GOKART_DATAFRAME_FRAMEWORK is not set. 
From 8a9063fd551fd6db866bf12fe4ee4b1f6f1dfdcd Mon Sep 17 00:00:00 2001
From: hirosassa
Date: Sat, 29 Mar 2025 14:27:45 +0900
Subject: [PATCH 09/13] use env var to switch frameworks

---
 docs/for_polars.rst      |  2 +-
 gokart/file_processor.py | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/docs/for_polars.rst b/docs/for_polars.rst
index 3582da7c..91aa194f 100644
--- a/docs/for_polars.rst
+++ b/docs/for_polars.rst
@@ -8,4 +8,4 @@ Gokart also has features for Polars. They are enabled by installing the extra
 
     pip install gokart[polars]
 
 
-You can use Polars for most of the file formats used in the :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` features.
+You need to set the environment variable ``GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED`` to ``true``; then you can use Polars for most of the file formats used in the :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` features.
diff --git a/gokart/file_processor.py b/gokart/file_processor.py
index 35aaf1e5..a66bafad 100644
--- a/gokart/file_processor.py
+++ b/gokart/file_processor.py
@@ -10,6 +10,7 @@
 import luigi
 import luigi.contrib.s3
 import luigi.format
+import pandas as pd
 import numpy as np
 from luigi.format import TextFormat
 
@@ -22,10 +23,11 @@
 
 try:
     import polars as pl
 
-    DATAFRAME_FRAMEWORK = 'polars'
-except ImportError:
-    import pandas as pd
-
+    if os.getenv('GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED') == 'true':
+        DATAFRAME_FRAMEWORK = 'polars'
+    else:
+        raise ValueError('GOKART_DATAFRAME_FRAMEWORK is not set. Use pandas as dataframe framework.')
+except (ImportError, ValueError):
     DATAFRAME_FRAMEWORK = 'pandas'

From 33fb5ef380c21d186020576e8cf40078f1208d78 Mon Sep 17 00:00:00 2001
From: hirosassa
Date: Sat, 29 Mar 2025 14:35:16 +0900
Subject: [PATCH 10/13] fix import order

---
 gokart/file_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gokart/file_processor.py b/gokart/file_processor.py
index a66bafad..447f2f6b 100644
--- a/gokart/file_processor.py
+++ b/gokart/file_processor.py
@@ -10,8 +10,8 @@
 import luigi
 import luigi.contrib.s3
 import luigi.format
-import pandas as pd
 import numpy as np
+import pandas as pd
 from luigi.format import TextFormat
 
 from gokart.object_storage import ObjectStorage

From 2a881892b4610e4dd47744e1d8527f9d43ffe0db Mon Sep 17 00:00:00 2001
From: hirosassa
Date: Sat, 29 Mar 2025 14:39:58 +0900
Subject: [PATCH 11/13] fix typo

---
 gokart/file_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gokart/file_processor.py b/gokart/file_processor.py
index 447f2f6b..463d4773 100644
--- a/gokart/file_processor.py
+++ b/gokart/file_processor.py
@@ -26,7 +26,7 @@
     if os.getenv('GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED') == 'true':
         DATAFRAME_FRAMEWORK = 'polars'
     else:
-        raise ValueError('GOKART_DATAFRAME_FRAMEWORK is not set. Use pandas as dataframe framework.')
+        raise ValueError('GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED is not set. Use pandas as dataframe framework.')
 except (ImportError, ValueError):
     DATAFRAME_FRAMEWORK = 'pandas'
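At this point in the series the switch is doubly guarded: polars must be importable and the flag must be set, otherwise the module falls back to pandas. A sketch of the observable behavior (illustrative only; the variable has to be set before gokart is first imported, since the choice is made at import time):

.. code:: python

    import os

    # Must run before the first gokart import -- the framework constant is
    # computed once, at module import time.
    os.environ['GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED'] = 'true'

    from gokart import file_processor  # noqa: E402

    # 'polars' when polars is importable, otherwise 'pandas' via the
    # except (ImportError, ValueError) fallback.
    print(file_processor.DATAFRAME_FRAMEWORK)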
From 79839e1ac4bb9facd1e8cc24b0ef9613082a43d0 Mon Sep 17 00:00:00 2001
From: hirosassa
Date: Sat, 29 Mar 2025 22:16:15 +0900
Subject: [PATCH 12/13] raise an exception if the user hasn't installed polars when the env var is set to polars

---
 docs/for_polars.rst      |  3 ++-
 gokart/file_processor.py | 16 +++++++---------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/docs/for_polars.rst b/docs/for_polars.rst
index 91aa194f..327d67fd 100644
--- a/docs/for_polars.rst
+++ b/docs/for_polars.rst
@@ -8,4 +8,5 @@ Gokart also has features for Polars. They are enabled by installing the extra
 
     pip install gokart[polars]
 
 
-You need to set the environment variable ``GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED`` to ``true``; then you can use Polars for most of the file formats used in the :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` features.
+You need to set the environment variable ``GOKART_DATAFRAME_FRAMEWORK`` to ``polars``; then you can use Polars for most of the file formats used in the :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` features.
+If you don't set ``GOKART_DATAFRAME_FRAMEWORK``, or set it to ``pandas``, pandas is used instead.
diff --git a/gokart/file_processor.py b/gokart/file_processor.py
index 463d4773..e0001e85 100644
--- a/gokart/file_processor.py
+++ b/gokart/file_processor.py
@@ -20,15 +20,13 @@
 
 logger = getLogger(__name__)
 
-try:
-    import polars as pl
-
-    if os.getenv('GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED') == 'true':
-        DATAFRAME_FRAMEWORK = 'polars'
-    else:
-        raise ValueError('GOKART_DATAFRAME_FRAMEWORK_POLARS_ENABLED is not set. Use pandas as dataframe framework.')
-except (ImportError, ValueError):
-    DATAFRAME_FRAMEWORK = 'pandas'
+DATAFRAME_FRAMEWORK = os.getenv('GOKART_DATAFRAME_FRAMEWORK', 'pandas')
+if DATAFRAME_FRAMEWORK == 'polars':
+    try:
+        import polars as pl
+
+    except ImportError as e:
+        raise ValueError('please install polars to use it as the dataframe framework for gokart') from e

From b200fd9d3911f941fc65f40b08016e6dc5b656a2 Mon Sep 17 00:00:00 2001
From: hirosassa
Date: Sun, 30 Mar 2025 06:17:35 +0900
Subject: [PATCH 13/13] add fixture for reloading the modules

---
 test/test_file_processor_polars.py | 31 +++++++++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

diff --git a/test/test_file_processor_polars.py b/test/test_file_processor_polars.py
index 8f3a70f6..5508b977 100644
--- a/test/test_file_processor_polars.py
+++ b/test/test_file_processor_polars.py
@@ -1,17 +1,37 @@
 from __future__ import annotations
 
+import importlib
 import tempfile
 
 import pytest
 from luigi import LocalTarget
 
+from gokart import file_processor
 from gokart.file_processor import PolarsCsvFileProcessor, PolarsFeatherFileProcessor, PolarsJsonFileProcessor
 
 pl = pytest.importorskip('polars', reason='polars required')
 pl_testing = pytest.importorskip('polars.testing', reason='polars required')
 
 
-def test_dump_csv_with():
+@pytest.fixture
+def reload_processor(monkeypatch):
+    """
+    A pytest fixture that reloads the `gokart.file_processor` module after setting
+    the environment variable `GOKART_DATAFRAME_FRAMEWORK`. This ensures that polars
+    is used when the module is reloaded.
+
+    Returns:
+        Tuple[Type[PolarsCsvFileProcessor], Type[PolarsFeatherFileProcessor], Type[PolarsJsonFileProcessor]]:
+            The reloaded classes from the `gokart.file_processor` module.
+    """
+    monkeypatch.setenv('GOKART_DATAFRAME_FRAMEWORK', 'polars')
+    importlib.reload(file_processor)
+
+    yield file_processor.PolarsCsvFileProcessor, file_processor.PolarsFeatherFileProcessor, file_processor.PolarsJsonFileProcessor
+
+
+def test_dump_csv(reload_processor):
+    PolarsCsvFileProcessor, _, _ = reload_processor
     df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
 
     processor = PolarsCsvFileProcessor()

     pl_testing.assert_frame_equal(df, loaded_df)
 
 
-def test_load_csv():
+def test_load_csv(reload_processor):
+    PolarsCsvFileProcessor, _, _ = reload_processor
     df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
 
     processor = PolarsCsvFileProcessor()

         pytest.param('records', {}, '', id='With Records Orient for Empty Dict'),
     ],
 )
-def test_dump_and_load_json(orient, input_data, expected_json):
+def test_dump_and_load_json(reload_processor, orient, input_data, expected_json):
+    _, _, PolarsJsonFileProcessor = reload_processor
     processor = PolarsJsonFileProcessor(orient=orient)
 
     with tempfile.TemporaryDirectory() as temp_dir:

     pl_testing.assert_frame_equal(df_input, loaded_df)
 
 
-def test_feather_should_return_same_dataframe():
+def test_feather_should_return_same_dataframe(reload_processor):
+    _, PolarsFeatherFileProcessor, _ = reload_processor
     df = pl.DataFrame({'a': [1]})
     # TODO: currently we set store_index_in_feather True but it is ignored
     processor = PolarsFeatherFileProcessor(store_index_in_feather=True)
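With the whole series applied, the end-to-end behavior can be exercised the same way the new fixture does; a closing sketch:

.. code:: python

    import importlib
    import os

    from gokart import file_processor

    # Select polars and re-evaluate the module-level switch, mirroring the
    # reload_processor fixture; this raises ValueError if polars is missing.
    os.environ['GOKART_DATAFRAME_FRAMEWORK'] = 'polars'
    importlib.reload(file_processor)
    print(file_processor.DATAFRAME_FRAMEWORK)  # -> 'polars'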