diff --git a/gokart/target.py b/gokart/target.py index 88b3c942..7d5263e8 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -5,7 +5,7 @@ from datetime import datetime from glob import glob from logging import getLogger -from typing import Any, Optional +from typing import Any, Callable, Optional import luigi import numpy as np @@ -78,10 +78,12 @@ def __init__( target: luigi.target.FileSystemTarget, processor: FileProcessor, task_lock_params: TaskLockParams, + validator: Callable[[Any], bool] = lambda x: True, ) -> None: self._target = target self._processor = processor self._task_lock_params = task_lock_params + self._validator = validator def _exists(self) -> bool: return self._target.exists() @@ -91,9 +93,16 @@ def _get_task_lock_params(self) -> TaskLockParams: def _load(self) -> Any: with self._target.open('r') as f: - return self._processor.load(f) + obj = self._processor.load(f) + if not self._validator(obj): + raise ValueError(f'Validator error: Loaded object is invalid: {obj}') + + return obj def _dump(self, obj) -> None: + if not self._validator(obj): + raise ValueError(f'Validator error: Dumped object is invalid: {obj}') + with self._target.open('w') as f: self._processor.dump(obj, f) @@ -216,12 +225,13 @@ def make_target( processor: Optional[FileProcessor] = None, task_lock_params: Optional[TaskLockParams] = None, store_index_in_feather: bool = True, + validator: Callable[[Any], bool] = lambda x: True, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather) - return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params) + return 
SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, validator=validator) def make_model_target( diff --git a/gokart/task.py b/gokart/task.py index 0db1f875..060ecc1b 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -193,7 +193,13 @@ def clone(self, cls=None, **kwargs): return cls(**new_k) - def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart: + def make_target( + self, + relative_file_path: Optional[str] = None, + use_unique_id: bool = True, + processor: Optional[FileProcessor] = None, + validator: Callable[[Any], bool] = lambda x: True, + ) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) @@ -210,7 +216,12 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b ) return gokart.target.make_target( - file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather + file_path=file_path, + unique_id=unique_id, + processor=processor, + task_lock_params=task_lock_params, + store_index_in_feather=self.store_index_in_feather, + validator=validator, ) def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: diff --git a/poetry.lock b/poetry.lock index 4a68ec0b..13f9f4cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1339,6 +1339,7 @@ optional = false python-versions = ">=3.9" files = [ {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, {file = 
"pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, @@ -1359,6 +1360,7 @@ files = [ {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, @@ -1597,19 +1599,6 @@ files = [ {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, - {file = 
"pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, - {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, - {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, - {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] diff --git a/test/test_target.py 
b/test/test_target.py
index 8090641f..fb53668d 100644
--- a/test/test_target.py
+++ b/test/test_target.py
@@ -1,18 +1,20 @@
 import io
 import os
 import shutil
+import tempfile
 import unittest
 from datetime import datetime
 from unittest.mock import patch
 
 import boto3
+import luigi
 import numpy as np
 import pandas as pd
 from matplotlib import pyplot
 from moto import mock_aws
 
-from gokart.file_processor import _ChunkedLargeFileReader
-from gokart.target import make_model_target, make_target
+from gokart.file_processor import _ChunkedLargeFileReader, make_file_processor
+from gokart.target import SingleFileTarget, make_model_target, make_target
 
 
 def _get_temporary_directory():
@@ -280,5 +282,51 @@ def test_model_target_on_s3(self):
         self.assertEqual(loaded, obj)
 
 
+class SingleFileTargetTest(unittest.TestCase):
+    """Checks that SingleFileTarget runs its ``validator`` callable when dumping and loading."""
+
+    # NOTE(review): both tests pass task_lock_params=None; this assumes dump()/load() tolerate a missing lock config -- confirm against TargetOnKart.
+
+    def test_typed_target(self):
+        """An object accepted by the validator survives a dump/load round trip."""
+
+        def validate_dataframe(x):
+            return isinstance(x, pd.DataFrame)
+
+        test_case = pd.DataFrame(dict(a=[1, 2]))
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            _task_lock_params = None
+            file_path = os.path.join(temp_dir, 'test.pkl')
+            processor = make_file_processor(file_path, store_index_in_feather=False)
+            file_system_target = luigi.LocalTarget(file_path, format=processor.format())
+            file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, validator=validate_dataframe)
+
+            file_target.dump(test_case)
+            dumped_data = file_target.load()
+            self.assertIsInstance(dumped_data, pd.DataFrame)
+
+    def test_invalid_typed_target(self):
+        """Dumping an object rejected by the validator raises ValueError."""
+
+        def validate_int(x):
+            return isinstance(x, int)
+
+        test_case = pd.DataFrame(dict(a=['1', '2']))
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            _task_lock_params = None
+            file_path = os.path.join(temp_dir, 'test.csv')
+            processor = make_file_processor(file_path, store_index_in_feather=False)
+            file_system_target = luigi.LocalTarget(file_path, format=processor.format())
+            file_target = SingleFileTarget(
+                target=file_system_target, processor=processor, task_lock_params=_task_lock_params, validator=validate_int
+            )
+
+            # SingleFileTarget._dump raises ValueError when the validator returns False.
+            with self.assertRaises(ValueError):
+                file_target.dump(test_case)
+
+
 if __name__ == '__main__':
     unittest.main()