diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py index c1fd185a..03803d4b 100644 --- a/gokart/in_memory/target.py +++ b/gokart/in_memory/target.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, Optional from gokart.in_memory.repository import InMemoryCacheRepository from gokart.target import TargetOnKart, TaskLockParams @@ -41,5 +41,12 @@ def _path(self) -> str: return self._data_key -def make_in_memory_target(target_key: str, task_lock_params: TaskLockParams) -> InMemoryTarget: - return InMemoryTarget(target_key, task_lock_params) +def _make_data_key(data_key: str, unique_id: Optional[str] = None) -> str: + if not unique_id: + return data_key + return data_key + '_' + unique_id + + +def make_in_memory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None) -> InMemoryTarget: + _data_key = _make_data_key(data_key, unique_id) + return InMemoryTarget(_data_key, task_lock_params) diff --git a/gokart/task.py b/gokart/task.py index f577f64b..77080cb1 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -20,9 +20,10 @@ import gokart import gokart.target -from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run +from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock from gokart.file_processor import FileProcessor +from gokart.in_memory.target import InMemoryTarget, make_in_memory_target from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter from gokart.target import TargetOnKart @@ -105,6 +106,9 @@ class TaskOnKart(luigi.Task, Generic[T]): default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False ) should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.') + cache_in_memory_by_default: bool = ExplicitBoolParameter( + default=False, significant=False, description='If `True`, output is stored on a memory instead of files unless specified.' + ) @property def priority(self): @@ -134,11 +138,13 @@ def __init__(self, *args, **kwargs): task_lock_params = make_task_lock_params_for_run(task_self=self) self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore + self.make_default_target = self.make_target if not self.cache_in_memory_by_default else self.make_cache_target + def input(self) -> FlattenableItems[TargetOnKart]: return super().input() def output(self) -> FlattenableItems[TargetOnKart]: - return self.make_target() + return self.make_default_target() def requires(self) -> FlattenableItems['TaskOnKart']: tasks = self.make_task_instance_dictionary() @@ -229,6 +235,21 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) + def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True) -> InMemoryTarget: + _data_key = data_key if data_key else os.path.join(self.__module__.replace('.', '/'), type(self).__name__) + unique_id = self.make_unique_id() if use_unique_id else None + # TODO: combine with redis + task_lock_params = TaskLockParams( + redis_host=None, + redis_port=None, + redis_timeout=None, + redis_key='redis_key', + should_task_lock=False, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=-1, + ) + return make_in_memory_target(_data_key, task_lock_params, unique_id) + def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip') diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py index ae6ca11d..a6efccd0 100644 --- a/test/in_memory/test_in_memory_target.py +++ b/test/in_memory/test_in_memory_target.py @@ -22,7 +22,7 @@ def task_lock_params(self) -> TaskLockParams: @pytest.fixture def target(self, task_lock_params: TaskLockParams) -> InMemoryTarget: - return make_in_memory_target(target_key='dummy_key', task_lock_params=task_lock_params) + return make_in_memory_target(data_key='dummy_key', task_lock_params=task_lock_params) @pytest.fixture(autouse=True) def clear_repo(self) -> None: diff --git a/test/in_memory/test_task_cached_in_memory.py b/test/in_memory/test_task_cached_in_memory.py new file mode 100644 index 00000000..2d09a754 --- /dev/null +++ b/test/in_memory/test_task_cached_in_memory.py @@ -0,0 +1,118 @@ +from typing import Optional, Type, Union + +import luigi +import pytest + +import gokart +from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget +from gokart.target import SingleFileTarget + + +class DummyTask(gokart.TaskOnKart): + task_namespace = __name__ + param: str = luigi.Parameter() + + def run(self): + self.dump(self.param) + + +class DummyTaskWithDependencies(gokart.TaskOnKart): + task_namespace = __name__ + task: list[gokart.TaskOnKart[str]] = gokart.ListTaskInstanceParameter() + + def run(self): + result = ','.join(self.load()) + self.dump(result) + + +class DumpIntTask(gokart.TaskOnKart[int]): + task_namespace = __name__ + value: int = luigi.IntParameter() + + def run(self): + self.dump(self.value) + + +class AddTask(gokart.TaskOnKart[Union[int, float]]): + a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() + b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() + + def requires(self): + return dict(a=self.a, b=self.b) + + def run(self): + a = self.load(self.a) + b = self.load(self.b) + self.dump(a + b) + + +class TestTaskOnKartWithCache: + @pytest.fixture(autouse=True) + def clear_repository(self) -> None: + InMemoryCacheRepository().clear() + + @pytest.mark.parametrize('data_key', ['sample_key', None]) + @pytest.mark.parametrize('use_unique_id', [True, False]) + def test_key_identity(self, data_key: Optional[str], use_unique_id: bool): + task = DummyTask(param='param') + ext = '.pkl' + relative_file_path = data_key + ext if data_key else None + target = task.make_target(relative_file_path=relative_file_path, use_unique_id=use_unique_id) + cached_target = task.make_cache_target(data_key=data_key, use_unique_id=use_unique_id) + + target_path = target.path().removeprefix(task.workspace_directory).removesuffix(ext).strip('/') + assert cached_target.path() == target_path + + def test_make_cached_target(self): + task = DummyTask(param='param') + target = task.make_cache_target() + assert isinstance(target, InMemoryTarget) + + @pytest.mark.parametrize(['cache_in_memory_by_default', 'target_type'], [[True, InMemoryTarget], [False, SingleFileTarget]]) + def test_make_default_target(self, cache_in_memory_by_default: bool, target_type: Type[gokart.TaskOnKart]): + task = DummyTask(param='param', cache_in_memory_by_default=cache_in_memory_by_default) + target = task.output() + assert isinstance(target, target_type) + + def test_complete_with_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir) + assert not task.complete() + file_target = task.make_target() + file_target.dump('data') + assert not task.complete() + cache_target = task.make_cache_target() + cache_target.dump('data') + assert task.complete() + + def test_complete_without_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', workspace_directory=tmpdir) + assert not task.complete() + cache_target = task.make_cache_target() + cache_target.dump('data') + assert not task.complete() + file_target = task.make_target() + file_target.dump('data') + assert task.complete() + + def test_dump_with_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir) + file_target = task.make_target() + cache_target = task.make_cache_target() + task.dump('data') + assert not file_target.exists() + assert cache_target.exists() + + def test_dump_without_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', workspace_directory=tmpdir) + file_target = task.make_target() + cache_target = task.make_cache_target() + task.dump('data') + assert file_target.exists() + assert not cache_target.exists() + + def test_gokart_build(self): + task = AddTask( + a=DumpIntTask(value=2, cache_in_memory_by_default=True), b=DumpIntTask(value=3, cache_in_memory_by_default=True), cache_in_memory_by_default=True + ) + output = gokart.build(task, reset_register=False) + assert output == 5