2 changes: 2 additions & 0 deletions .github/workflows/test.yml
@@ -29,3 +29,5 @@ jobs:
uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv
- name: Test with tox
run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}
- name: Test with tox for polars extra
run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}-polars
Comment on lines +32 to +33 (Collaborator Author):
add the test run with polars extra

12 changes: 12 additions & 0 deletions docs/for_polars.rst
@@ -0,0 +1,12 @@
For Polars
==========

Gokart also has features for Polars. They are enabled by installing the ``polars`` extra with the following command:

.. code:: sh

pip install gokart[polars]


Set the environment variable ``GOKART_DATAFRAME_FRAMEWORK`` to ``polars`` to use Polars for most of the file formats handled by the :func:`~gokart.task.TaskOnKart.load` and :func:`~gokart.task.TaskOnKart.dump` features.
If ``GOKART_DATAFRAME_FRAMEWORK`` is unset or set to ``pandas``, pandas is used instead.
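
For illustration, a minimal task sketch under this setting (hypothetical task name; it assumes ``GOKART_DATAFRAME_FRAMEWORK=polars`` is exported before the process starts and that the ``polars`` extra is installed):

.. code:: python

    import gokart
    import polars as pl

    class MakeData(gokart.TaskOnKart):
        def output(self):
            # a .csv target routes through the CSV file processor, here the polars variant
            return self.make_target('data.csv')

        def run(self):
            self.dump(pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}))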
113 changes: 108 additions & 5 deletions gokart/file_processor.py
@@ -12,7 +12,6 @@
import luigi.format
import numpy as np
import pandas as pd
import pandas.errors
from luigi.format import TextFormat

from gokart.object_storage import ObjectStorage
@@ -21,6 +20,15 @@
logger = getLogger(__name__)


DATAFRAME_FRAMEWORK = os.getenv('GOKART_DATAFRAME_FRAMEWORK', 'pandas')
if DATAFRAME_FRAMEWORK == 'polars':
try:
import polars as pl

except ImportError as e:
raise ValueError('Please install polars to use it as the dataframe framework for gokart') from e


class FileProcessor:
@abstractmethod
def format(self):
@@ -131,6 +139,24 @@ def __init__(self, sep=',', encoding: str = 'utf-8'):
def format(self):
return TextFormat(encoding=self._encoding)

def load(self, file): ...

def dump(self, obj, file): ...


class PolarsCsvFileProcessor(CsvFileProcessor):
def load(self, file):
try:
return pl.read_csv(file, separator=self._sep, encoding=self._encoding)
except pl.exceptions.NoDataError:
return pl.DataFrame()

def dump(self, obj, file):
assert isinstance(obj, (pl.DataFrame, pl.Series)), f'requires pl.DataFrame or pl.Series, but {type(obj)} is passed.'
obj.write_csv(file, separator=self._sep, include_header=True)


class PandasCsvFileProcessor(CsvFileProcessor):
def load(self, file):
try:
return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
@@ -164,6 +190,34 @@ def __init__(self, orient: str | None = None):
def format(self):
return luigi.format.Nop

def load(self, file): ...

def dump(self, obj, file): ...


class PolarsJsonFileProcessor(JsonFileProcessor):
def load(self, file):
try:
if self._orient == 'records':
return pl.read_ndjson(file)
return pl.read_json(file)
except pl.exceptions.ComputeError:
return pl.DataFrame()

def dump(self, obj, file):
assert isinstance(obj, (pl.DataFrame, pl.Series, dict)), f'requires pl.DataFrame or pl.Series or dict, but {type(obj)} is passed.'
if isinstance(obj, dict):
obj = pl.from_dict(obj)

if self._orient == 'records':
obj.write_ndjson(file)
else:
obj.write_json(file)


class PandasJsonFileProcessor(JsonFileProcessor):
def load(self, file):
try:
return pd.read_json(file, orient=self._orient, lines=True if self._orient == 'records' else False)
@@ -215,11 +269,27 @@ def __init__(self, engine='pyarrow', compression=None):
def format(self):
return luigi.format.Nop

def load(self, file): ...

def dump(self, obj, file): ...


class PolarsParquetFileProcessor(ParquetFileProcessor):
def load(self, file):
if ObjectStorage.is_buffered_reader(file):
return pl.read_parquet(file.name)
else:
return pl.read_parquet(BytesIO(file.read()))

def dump(self, obj, file):
assert isinstance(obj, pl.DataFrame), f'requires pl.DataFrame, but {type(obj)} is passed.'
use_pyarrow = self._engine == 'pyarrow'
compression = 'uncompressed' if self._compression is None else self._compression
obj.write_parquet(file, use_pyarrow=use_pyarrow, compression=compression)


class PandasParquetFileProcessor(ParquetFileProcessor):
def load(self, file):
# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_parquet accepts file-like object
# but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
# which is needed for pandas to read a file in chunks.
if ObjectStorage.is_buffered_reader(file):
return pd.read_parquet(file.name)
else:
@@ -240,6 +310,27 @@ def __init__(self, store_index_in_feather: bool):
def format(self):
return luigi.format.Nop

def load(self, file): ...

def dump(self, obj, file): ...


class PolarsFeatherFileProcessor(FeatherFileProcessor):
def load(self, file):
# Since polars' DataFrame doesn't have an index, just load the feather file.
# TODO: stop ignoring the store_index_in_feather variable.
# Currently PolarsFeatherFileProcessor ignores store_index_in_feather to avoid
# a breaking change to FeatherFileProcessor's default behavior.
if ObjectStorage.is_buffered_reader(file):
return pl.read_ipc(file.name)
return pl.read_ipc(BytesIO(file.read()))

def dump(self, obj, file):
assert isinstance(obj, pl.DataFrame), f'requires pl.DataFrame, but {type(obj)} is passed.'
obj.write_ipc(file.name)


class PandasFeatherFileProcessor(FeatherFileProcessor):
def load(self, file):
# FIXME(mamo3gr): enable streaming (chunked) read with S3.
# pandas.read_feather accepts file-like object
@@ -281,6 +372,18 @@ def dump(self, obj, file):
dump_obj.to_feather(file.name)


if DATAFRAME_FRAMEWORK == 'polars':
CsvFileProcessor = PolarsCsvFileProcessor # type: ignore
JsonFileProcessor = PolarsJsonFileProcessor # type: ignore
ParquetFileProcessor = PolarsParquetFileProcessor # type: ignore
FeatherFileProcessor = PolarsFeatherFileProcessor # type: ignore
else:
CsvFileProcessor = PandasCsvFileProcessor # type: ignore
JsonFileProcessor = PandasJsonFileProcessor # type: ignore
ParquetFileProcessor = PandasParquetFileProcessor # type: ignore
FeatherFileProcessor = PandasFeatherFileProcessor # type: ignore


def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
extension2processor = {
'.txt': TextFileProcessor(),
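
For reference, a rough round-trip sketch of the new polars CSV path, modeled on the existing pandas tests. This is a sketch under stated assumptions: the polars extra is installed, and the environment variable must be set before gokart.file_processor is imported, since the framework is chosen at import time.

import os
import tempfile

os.environ.setdefault('GOKART_DATAFRAME_FRAMEWORK', 'polars')  # must precede the gokart import below

import polars as pl
from luigi import LocalTarget

from gokart.file_processor import PolarsCsvFileProcessor

df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
processor = PolarsCsvFileProcessor()

with tempfile.TemporaryDirectory() as temp_dir:
    # dump through a LocalTarget, then load it back via the same processor
    target = LocalTarget(path=f'{temp_dir}/temp.csv', format=processor.format())
    with target.open('w') as f:
        processor.dump(df, f)
    with target.open('r') as f:
        loaded = processor.load(f)

assert df.equals(loaded)  # equals() requires a recent polars; older versions used frame_equal()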
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -45,6 +45,13 @@ Homepage = "https://github.com/m3dev/gokart"
Repository = "https://github.com/m3dev/gokart"
Documentation = "https://gokart.readthedocs.io/en/latest/"

[project.optional-dependencies]
polars = [
# polars doesn't run correctly on Apple silicon chips
"polars-lts-cpu; platform_system =='Darwin'",
"polars; platform_system != 'Darwin'",
]

[dependency-groups]
test = [
"fakeredis",
156 changes: 9 additions & 147 deletions test/test_file_processor.py
@@ -6,126 +6,26 @@
from typing import Callable

import boto3
import pandas as pd
import pytest
from luigi import LocalTarget
from moto import mock_aws

from gokart.file_processor import (
DATAFRAME_FRAMEWORK,
CsvFileProcessor,
FeatherFileProcessor,
GzipFileProcessor,
JsonFileProcessor,
NpzFileProcessor,
PandasCsvFileProcessor,
ParquetFileProcessor,
PickleFileProcessor,
PolarsCsvFileProcessor,
TextFileProcessor,
make_file_processor,
)
from gokart.object_storage import ObjectStorage


class TestCsvFileProcessor(unittest.TestCase):
def test_dump_csv_with_utf8(self):
df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
processor = CsvFileProcessor()

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.csv'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(df, f)

# read with utf-8 to check if the file is dumped with utf8
loaded_df = pd.read_csv(temp_path, encoding='utf-8')
pd.testing.assert_frame_equal(df, loaded_df)

def test_dump_csv_with_cp932(self):
df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
processor = CsvFileProcessor(encoding='cp932')

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.csv'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(df, f)

# read with cp932 to check if the file is dumped with cp932
loaded_df = pd.read_csv(temp_path, encoding='cp932')
pd.testing.assert_frame_equal(df, loaded_df)

def test_load_csv_with_utf8(self):
df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
processor = CsvFileProcessor()

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.csv'
df.to_csv(temp_path, encoding='utf-8', index=False)

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('r') as f:
# read with utf-8 to check if the file is dumped with utf8
loaded_df = processor.load(f)
pd.testing.assert_frame_equal(df, loaded_df)

def test_load_csv_with_cp932(self):
df = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
processor = CsvFileProcessor(encoding='cp932')

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.csv'
df.to_csv(temp_path, encoding='cp932', index=False)

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('r') as f:
# read with cp932 to check if the file is dumped with cp932
loaded_df = processor.load(f)
pd.testing.assert_frame_equal(df, loaded_df)


class TestJsonFileProcessor:
@pytest.mark.parametrize(
'orient,input_data,expected_json',
[
pytest.param(
None,
pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}),
'{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}',
id='With Default Orient for DataFrame',
),
pytest.param(
'records',
pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}),
'{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n',
id='With Records Orient for DataFrame',
),
pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', id='With Default Orient for Dict'),
pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'),
pytest.param(None, {}, '{}', id='With Default Orient for Empty Dict'),
pytest.param('records', {}, '\n', id='With Records Orient for Empty Dict'),
],
)
def test_dump_and_load_json(self, orient, input_data, expected_json):
processor = JsonFileProcessor(orient=orient)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.json'
local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(input_data, f)
with local_target.open('r') as f:
loaded_df = processor.load(f)
f.seek(0)
loaded_json = f.read().decode('utf-8')

assert loaded_json == expected_json

df_input = pd.DataFrame(input_data)
pd.testing.assert_frame_equal(df_input, loaded_df)


class TestPickleFileProcessor(unittest.TestCase):
def test_dump_and_load_normal_obj(self):
var = 'abc'
@@ -189,50 +89,12 @@ def test_dump_and_load_with_readables3file(self):
self.assertEqual(loaded, var)


class TestFeatherFileProcessor(unittest.TestCase):
def test_feather_should_return_same_dataframe(self):
df = pd.DataFrame({'a': [1]})
processor = FeatherFileProcessor(store_index_in_feather=True)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.feather'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(df, f)

with local_target.open('r') as f:
loaded_df = processor.load(f)

pd.testing.assert_frame_equal(df, loaded_df)

def test_feather_should_save_index_name(self):
df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name'))
processor = FeatherFileProcessor(store_index_in_feather=True)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.feather'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
processor.dump(df, f)

with local_target.open('r') as f:
loaded_df = processor.load(f)

pd.testing.assert_frame_equal(df, loaded_df)

def test_feather_should_raise_error_index_name_is_None(self):
df = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None'))
processor = FeatherFileProcessor(store_index_in_feather=True)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = f'{temp_dir}/temp.feather'

local_target = LocalTarget(path=temp_path, format=processor.format())
with local_target.open('w') as f:
with self.assertRaises(AssertionError):
processor.dump(df, f)
class TestFileProcessorClassSelection(unittest.TestCase):
def test_processor_selection(self):
if DATAFRAME_FRAMEWORK == 'polars':
self.assertTrue(issubclass(CsvFileProcessor, PolarsCsvFileProcessor))
else:
self.assertTrue(issubclass(CsvFileProcessor, PandasCsvFileProcessor))
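
A quick way to exercise the selection outside the suite, as a sketch (hypothetical one-off script; assumes the polars extra is installed):

import os

os.environ['GOKART_DATAFRAME_FRAMEWORK'] = 'polars'  # must be set before the import below

from gokart.file_processor import CsvFileProcessor, PolarsCsvFileProcessor

# the module-level alias is resolved at import time, so this holds
assert issubclass(CsvFileProcessor, PolarsCsvFileProcessor)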


class TestMakeFileProcessor(unittest.TestCase):