Skip to content

Commit b200fd9

Browse files
committed
add fixture for reloading the modules
1 parent 79839e1 commit b200fd9

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

test/test_file_processor_polars.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,37 @@
11
from __future__ import annotations
22

3+
import importlib
34
import tempfile
45

56
import pytest
67
from luigi import LocalTarget
78

9+
from gokart import file_processor
810
from gokart.file_processor import PolarsCsvFileProcessor, PolarsFeatherFileProcessor, PolarsJsonFileProcessor
911

1012
pl = pytest.importorskip('polars', reason='polars required')
1113
pl_testing = pytest.importorskip('polars.testing', reason='polars required')
1214

1315

14-
def test_dump_csv_with():
16+
@pytest.fixture
17+
def reload_processor(monkeypatch):
18+
"""
19+
A pytest fixture that reloads the `gokart.file_processor` module after modifying
20+
the environment variable `GOKART_DATAFRAME_FRAMEWORK`. This ensures that polars
21+
is used when reloading the module.
22+
23+
Returns:
24+
Tuple[Type[PolarsCsvFileProcessor], Type[PolarsFeatherFileProcessor], Type[PolarsJsonFileProcessor]]:
25+
The reloaded classes from the `gokart.file_processor` module.
26+
"""
27+
monkeypatch.setenv('GOKART_DATAFRAME_FRAMEWORK', 'polars')
28+
importlib.reload(file_processor)
29+
30+
yield PolarsCsvFileProcessor, PolarsFeatherFileProcessor, PolarsJsonFileProcessor
31+
32+
33+
def test_dump_csv(reload_processor):
34+
PolarsCsvFileProcessor, _, _ = reload_processor
1535
df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
1636
processor = PolarsCsvFileProcessor()
1737

@@ -27,7 +47,8 @@ def test_dump_csv_with():
2747
pl_testing.assert_frame_equal(df, loaded_df)
2848

2949

30-
def test_load_csv():
50+
def test_load_csv(reload_processor):
51+
PolarsCsvFileProcessor, _, _ = reload_processor
3152
df = pl.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
3253
processor = PolarsCsvFileProcessor()
3354

@@ -63,7 +84,8 @@ def test_load_csv():
6384
pytest.param('records', {}, '', id='With Records Orient for Empty Dict'),
6485
],
6586
)
66-
def test_dump_and_load_json(orient, input_data, expected_json):
87+
def test_dump_and_load_json(reload_processor, orient, input_data, expected_json):
88+
_, _, PolarsJsonFileProcessor = reload_processor
6789
processor = PolarsJsonFileProcessor(orient=orient)
6890

6991
with tempfile.TemporaryDirectory() as temp_dir:
@@ -82,7 +104,8 @@ def test_dump_and_load_json(orient, input_data, expected_json):
82104
pl_testing.assert_frame_equal(df_input, loaded_df)
83105

84106

85-
def test_feather_should_return_same_dataframe():
107+
def test_feather_should_return_same_dataframe(reload_processor):
108+
_, PolarsFeatherFileProcessor, _ = reload_processor
86109
df = pl.DataFrame({'a': [1]})
87110
# TODO: currently we set store_index_in_feather True but it is ignored
88111
processor = PolarsFeatherFileProcessor(store_index_in_feather=True)

0 commit comments

Comments
 (0)