diff --git a/parsl/benchmark/perf.py b/parsl/benchmark/perf.py index 803476385e..6502989449 100644 --- a/parsl/benchmark/perf.py +++ b/parsl/benchmark/perf.py @@ -49,7 +49,7 @@ def performance(*, resources: dict, target_t: float, args_extra_size: int, itera iteration = 1 - args_extra_payload = "x" * args_extra_size + # args_extra_payload = "x" * args_extra_size if isinstance(iterate_mode, list): n = iterate_mode[0] @@ -65,7 +65,13 @@ def performance(*, resources: dict, target_t: float, args_extra_size: int, itera fs = [] print("Submitting tasks / invoking apps") - for _ in range(n): + for index in range(n): + # This gives a different argument for each iteration, which makes + # checkpointing/memoization behave differently, so it could be made + # switchable in the parsl-perf dev branch. + # args_extra_payload = index # always a new payload (except for run repeats) + + args_extra_payload = index % 10 fs.append(app(args_extra_payload, parsl_resource_specification=resources)) submitted_t = time.time()
diff --git a/parsl/config.py b/parsl/config.py index 1358e99d28..8cf3e69a6f 100644 --- a/parsl/config.py +++ b/parsl/config.py @@ -5,6 +5,7 @@ from typing_extensions import Literal from parsl.dataflow.dependency_resolvers import DependencyResolver +from parsl.dataflow.memoization import Memoizer from parsl.dataflow.taskrecord import TaskRecord from parsl.errors import ConfigurationError from parsl.executors.base import ParslExecutor @@ -27,17 +28,6 @@ class Config(RepresentationMixin, UsageInformation): executors : sequence of ParslExecutor, optional List (or other iterable) of `ParslExecutor` instances to use for executing tasks. Default is (:class:`~parsl.executors.threads.ThreadPoolExecutor()`,). - app_cache : bool, optional Enable app caching. Default is True. - checkpoint_files : sequence of str, optional List of paths to checkpoint files. See :func:`parsl.utils.get_all_checkpoints` and :func:`parsl.utils.get_last_checkpoint` for helpers. Default is None. - checkpoint_mode : str, optional Checkpoint mode to use, can be ``'dfk_exit'``, ``'task_exit'``, ``'periodic'`` or ``'manual'``. If set to `None`, checkpointing will be disabled. Default is None. - checkpoint_period : str, optional Time interval (in "HH:MM:SS") at which to checkpoint completed tasks. Only has an effect if ``checkpoint_mode='periodic'``. dependency_resolver: plugin point for custom dependency resolvers. Default: only resolve Futures, using the `SHALLOW_DEPENDENCY_RESOLVER`.
exit_mode: str, optional @@ -100,14 +90,7 @@ class Config(RepresentationMixin, UsageInformation): @typeguard.typechecked def __init__(self, executors: Optional[Iterable[ParslExecutor]] = None, - app_cache: bool = True, - checkpoint_files: Optional[Sequence[str]] = None, - checkpoint_mode: Union[None, - Literal['task_exit'], - Literal['periodic'], - Literal['dfk_exit'], - Literal['manual']] = None, - checkpoint_period: Optional[str] = None, + memoizer: Optional[Memoizer] = None, dependency_resolver: Optional[DependencyResolver] = None, exit_mode: Literal['cleanup', 'skip', 'wait'] = 'cleanup', garbage_collect: bool = True, @@ -131,21 +114,7 @@ def __init__(self, self._executors: Sequence[ParslExecutor] = executors self._validate_executors() - self.app_cache = app_cache - self.checkpoint_files = checkpoint_files - self.checkpoint_mode = checkpoint_mode - if checkpoint_period is not None: - if checkpoint_mode is None: - logger.debug('The requested `checkpoint_period={}` will have no effect because `checkpoint_mode=None`'.format( - checkpoint_period) - ) - elif checkpoint_mode != 'periodic': - logger.debug("Requested checkpoint period of {} only has an effect with checkpoint_mode='periodic'".format( - checkpoint_period) - ) - if checkpoint_mode == 'periodic' and checkpoint_period is None: - checkpoint_period = "00:30:00" - self.checkpoint_period = checkpoint_period + self.memoizer = memoizer self.dependency_resolver = dependency_resolver self.exit_mode = exit_mode self.garbage_collect = garbage_collect diff --git a/parsl/configs/ASPIRE1.py b/parsl/configs/ASPIRE1.py index 017e1061d7..84d131f210 100644 --- a/parsl/configs/ASPIRE1.py +++ b/parsl/configs/ASPIRE1.py @@ -1,5 +1,6 @@ from parsl.addresses import address_by_interface from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors import HighThroughputExecutor from parsl.launchers import MpiRunLauncher from parsl.monitoring.monitoring import MonitoringHub @@ -38,7 +39,6 @@ ), strategy='simple', retries=3, - app_cache=True, - checkpoint_mode='task_exit', + memoizer=BasicMemoizer(checkpoint_mode='task_exit'), usage_tracking=LEVEL_1, ) diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py index 7af511d5bc..dd462f4630 100644 --- a/parsl/dataflow/dflow.py +++ b/parsl/dataflow/dflow.py @@ -29,7 +29,7 @@ from parsl.dataflow.dependency_resolvers import SHALLOW_DEPENDENCY_RESOLVER from parsl.dataflow.errors import DependencyError, JoinError from parsl.dataflow.futures import AppFuture -from parsl.dataflow.memoization import Memoizer +from parsl.dataflow.memoization import BasicMemoizer, Memoizer from parsl.dataflow.rundirs import make_rundir from parsl.dataflow.states import FINAL_FAILURE_STATES, FINAL_STATES, States from parsl.dataflow.taskrecord import TaskRecord @@ -165,12 +165,8 @@ def __init__(self, config: Config) -> None: self.monitoring_radio.send((MessageType.WORKFLOW_INFO, workflow_info)) - self.memoizer = Memoizer(memoize=config.app_cache, - checkpoint_mode=config.checkpoint_mode, - checkpoint_files=config.checkpoint_files, - checkpoint_period=config.checkpoint_period) - self.memoizer.run_dir = self.run_dir - self.memoizer.start() + self.memoizer: Memoizer = config.memoizer if config.memoizer is not None else BasicMemoizer() + self.memoizer.start(run_dir=self.run_dir) # this must be set before executors are added since add_executors calls # job_status_poller.add_executors. 
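Taken together, the config.py and dflow.py hunks above replace Config's four checkpoint/caching parameters with a single `memoizer` plugin point, defaulting to `BasicMemoizer()` when unset. A sketch of the migration this implies for a user configuration (illustrative only, not part of the patch; `get_all_checkpoints` is the existing helper from `parsl.utils`):

```python
from parsl.config import Config
from parsl.dataflow.memoization import BasicMemoizer
from parsl.utils import get_all_checkpoints

# Before this patch:
#   Config(app_cache=True,
#          checkpoint_mode='task_exit',
#          checkpoint_files=get_all_checkpoints())
# With this patch, the same behaviour is requested through the
# BasicMemoizer plugin (as in the ASPIRE1 example config above):
config = Config(
    memoizer=BasicMemoizer(
        memoize=True,  # replaces app_cache
        checkpoint_mode='task_exit',
        checkpoint_files=get_all_checkpoints(),
    ),
)
```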
@@ -502,31 +498,6 @@ def handle_join_update(self, task_record: TaskRecord, inner_app_future: Optional self._log_std_streams(task_record) - def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None: - """This function is called as a callback when an AppFuture - is in its final state. - - It will trigger post-app processing such as checkpointing. - - Args: - task_record : Task record - future (Future) : The relevant app future (which should be - consistent with the task structure 'app_fu' entry - - """ - - task_id = task_record['id'] - - if not task_record['app_fu'].done(): - logger.error("Internal consistency error: app_fu is not done for task {}".format(task_id)) - if not task_record['app_fu'] == future: - logger.error("Internal consistency error: callback future is not the app_fu in task structure, for task {}".format(task_id)) - - self.memoizer.update_checkpoint(task_record) - - self.wipe_task(task_id) - return - def _complete_task_result(self, task_record: TaskRecord, new_state: States, result: Any) -> None: """Set a task into a completed state """ @@ -543,6 +514,8 @@ def _complete_task_result(self, task_record: TaskRecord, new_state: States, resu self._send_task_log_info(task_record) + self.wipe_task(task_record['id']) + with task_record['app_fu']._update_lock: task_record['app_fu'].set_result(result) @@ -562,6 +535,8 @@ def _complete_task_exception(self, task_record: TaskRecord, new_state: States, e self._send_task_log_info(task_record) + self.wipe_task(task_record['id']) + with task_record['app_fu']._update_lock: task_record['app_fu'].set_exception(exception) @@ -1053,7 +1028,6 @@ def submit(self, task_record['func_name'], waiting_message)) - app_fu.add_done_callback(partial(self.handle_app_update, task_record)) self._update_task_state(task_record, States.pending) logger.debug("Task {} set to pending state with AppFuture: {}".format(task_id, task_record['app_fu'])) @@ -1227,9 +1201,6 @@ def cleanup(self) -> None: # should still see it. logger.info("DFK cleanup complete") - def checkpoint(self) -> None: - self.memoizer.checkpoint() - @staticmethod def _log_std_streams(task_record: TaskRecord) -> None: tid = task_record['id']
diff --git a/parsl/dataflow/memoization.py b/parsl/dataflow/memoization.py index b68422beac..af30eb1cca 100644 --- a/parsl/dataflow/memoization.py +++ b/parsl/dataflow/memoization.py @@ -8,17 +8,25 @@ import types from concurrent.futures import Future from functools import lru_cache, singledispatch -from typing import Any, Dict, List, Literal, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple import typeguard from parsl.dataflow.errors import BadCheckpoint from parsl.dataflow.taskrecord import TaskRecord -from parsl.errors import ConfigurationError, InternalConsistencyError +from parsl.errors import ConfigurationError from parsl.utils import Timer, get_all_checkpoints logger = logging.getLogger(__name__) +# There is an XOR rule that isn't expressed in the type +# system: exactly one of the result (object) and the exception +# (BaseException) is meaningful. This could be made into something +# more dataclass-like to give the fields names; but since None is a +# legitimate result value, the None-ness of the result cannot distinguish +# the two cases - the None-ness of the exception is the differentiator.
+CheckpointCommand = Tuple[TaskRecord, Optional[object], Optional[BaseException]] + @singledispatch def id_for_memo(obj: object, output_ref: bool = False) -> bytes: @@ -118,7 +126,60 @@ def id_for_memo_function(f: types.FunctionType, output_ref: bool = False) -> byt return pickle.dumps(["types.FunctionType", f.__name__, f.__module__]) +def make_hash(task: TaskRecord) -> str: + """Create a hash of the task inputs. + + Args: + - task (dict) : Task dictionary from dfk.tasks + + Returns: + - hash (str) : A unique hash string + """ + + t: List[bytes] = [] + + # if kwargs contains an outputs parameter, that parameter is removed + # and normalised differently - with output_ref set to True. + # kwargs listed in ignore_for_cache will also be removed + + filtered_kw = task['kwargs'].copy() + + ignore_list = task['ignore_for_cache'] + + logger.debug("Ignoring these kwargs for checkpointing: %s", ignore_list) + for k in ignore_list: + logger.debug("Ignoring kwarg %s", k) + del filtered_kw[k] + + if 'outputs' in task['kwargs']: + outputs = task['kwargs']['outputs'] + del filtered_kw['outputs'] + t.append(id_for_memo(outputs, output_ref=True)) + + t.extend(map(id_for_memo, (filtered_kw, task['func'], task['args']))) + + x = b''.join(t) + return hashlib.md5(x).hexdigest() + + class Memoizer: + """Abstract plugin interface for memoization and checkpointing.""" + + def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None: + raise NotImplementedError + + def update_memo_result(self, task: TaskRecord, r: Any) -> None: + raise NotImplementedError + + def start(self, *, run_dir: str) -> None: + raise NotImplementedError + + def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: + raise NotImplementedError + + def close(self) -> None: + raise NotImplementedError + + +class BasicMemoizer(Memoizer): """Memoizer is responsible for ensuring that identical work is not repeated. When a task is repeated, i.e., the same function is called with the same exact arguments, the @@ -152,18 +213,17 @@ class Memoizer: run_dir: str def __init__(self, *, - memoize: bool = True, - checkpoint_files: Sequence[str] | None, - checkpoint_period: Optional[str], - checkpoint_mode: Literal['task_exit', 'periodic', 'dfk_exit', 'manual'] | None): + checkpoint_files: Sequence[str] | None = None, + checkpoint_period: Optional[str] = None, + checkpoint_mode: Literal['task_exit', 'periodic', 'dfk_exit', 'manual'] | None = None, + memoize: bool = True): # TODO: unlikely to need to set this to false, but it was in config API before... """Initialize the memoizer. Kwargs: - memoize (Bool): enable memoization or not. - checkpoint (Dict): A checkpoint loaded as a dict.
+ - checkpoint_files: sequence of checkpoint directories to load, or None. + - checkpoint_period: time interval (in "HH:MM:SS") between checkpoints; only has an effect with checkpoint_mode='periodic'. + - checkpoint_mode: one of 'task_exit', 'periodic', 'dfk_exit' or 'manual', or None to disable checkpointing. + - memoize: enable in-memory memoization of app results. """ - self.memoize = memoize - self.checkpointed_tasks = 0 self.checkpoint_lock = threading.Lock() @@ -172,11 +232,15 @@ def __init__(self, *, self.checkpoint_mode = checkpoint_mode self.checkpoint_period = checkpoint_period - self.checkpointable_tasks: List[TaskRecord] = [] + self.checkpointable_tasks: List[CheckpointCommand] = [] self._checkpoint_timer: Timer | None = None + self.memoize = memoize + + def start(self, *, run_dir: str) -> None: + + self.run_dir = run_dir - def start(self) -> None: if self.checkpoint_files is not None: checkpoint_files = self.checkpoint_files elif self.checkpoint_files is None and self.checkpoint_mode is not None: @@ -202,52 +266,17 @@ def start(self) -> None: except Exception: raise ConfigurationError("invalid checkpoint_period provided: {0} expected HH:MM:SS".format(self.checkpoint_period)) checkpoint_period = (h * 3600) + (m * 60) + s - self._checkpoint_timer = Timer(self.checkpoint, interval=checkpoint_period, name="Checkpoint") + self._checkpoint_timer = Timer(self.checkpoint_queue, interval=checkpoint_period, name="Checkpoint") def close(self) -> None: if self.checkpoint_mode is not None: logger.info("Making final checkpoint") - self.checkpoint() + self.checkpoint_queue() if self._checkpoint_timer: logger.info("Stopping checkpoint timer") self._checkpoint_timer.close() - def make_hash(self, task: TaskRecord) -> str: - """Create a hash of the task inputs. - - Args: - - task (dict) : Task dictionary from dfk.tasks - - Returns: - - hash (str) : A unique hash string - """ - - t: List[bytes] = [] - - # if kwargs contains an outputs parameter, that parameter is removed - # and normalised differently - with output_ref set to True. - # kwargs listed in ignore_for_cache will also be removed - - filtered_kw = task['kwargs'].copy() - - ignore_list = task['ignore_for_cache'] - - logger.debug("Ignoring these kwargs for checkpointing: %s", ignore_list) - for k in ignore_list: - logger.debug("Ignoring kwarg %s", k) - del filtered_kw[k] - - if 'outputs' in task['kwargs']: - outputs = task['kwargs']['outputs'] - del filtered_kw['outputs'] - t.append(id_for_memo(outputs, output_ref=True)) - - t.extend(map(id_for_memo, (filtered_kw, task['func'], task['args']))) - - x = b''.join(t) - return hashlib.md5(x).hexdigest() - def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: """Create a hash of the task and its inputs and check the lookup table for this hash.
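The abstract `Memoizer` interface introduced above has only five methods, so a complete plugin can be very small. As a hypothetical sketch (not part of this patch), a do-nothing implementation — roughly the "better do-nothing impl" wished for in a TODO in test_memoize_2.py later in this diff — might look like:

```python
from concurrent.futures import Future
from typing import Any, Optional

from parsl.dataflow.memoization import Memoizer
from parsl.dataflow.taskrecord import TaskRecord


class NullMemoizer(Memoizer):
    """Never reports a memo hit and never records anything.
    (Hypothetical example; not part of this diff.)"""

    def start(self, *, run_dir: str) -> None:
        pass  # no checkpoint files to load, no timer to start

    def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
        task['hashsum'] = None  # nothing hashed for this task
        return None  # always a miss: every task executes

    def update_memo_result(self, task: TaskRecord, r: Any) -> None:
        pass  # discard successful results

    def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None:
        pass  # discard failures

    def close(self) -> None:
        pass  # nothing to shut down
```

`Config(memoizer=NullMemoizer())` would then switch off memoization and checkpointing together without going through BasicMemoizer's checkpoint machinery.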
@@ -269,7 +298,7 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: logger.debug("Task {} will not be memoized".format(task_id)) return None - hashsum = self.make_hash(task) + hashsum = make_hash(task) logger.debug("Task {} has memoization hash {}".format(task_id, hashsum)) result = None if hashsum in self.memo_lookup_table: @@ -286,9 +315,29 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: def update_memo_result(self, task: TaskRecord, r: Any) -> None: self._update_memo(task) + if self.checkpoint_mode == 'task_exit': + self.checkpoint_one((task, r, None)) + elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'): + # with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock + self.checkpointable_tasks.append((task, r, None)) + elif self.checkpoint_mode is None: + pass + else: + assert False, f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization" + def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None: self._update_memo(task) + if self.checkpoint_mode == 'task_exit': + self.checkpoint_one((task, None, e)) + elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'): + # with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock + self.checkpointable_tasks.append((task, None, e)) + elif self.checkpoint_mode is None: + pass + else: + assert False, f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization" + def _update_memo(self, task: TaskRecord) -> None: """Updates the memoization lookup table with the result from a task. This doesn't move any values around but associates the memoization @@ -340,8 +389,12 @@ def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[A data = pickle.load(f) # Copy and hash only the input attributes memo_fu: Future = Future() - assert data['exception'] is None - memo_fu.set_result(data['result']) + + if data['exception'] is None: + memo_fu.set_result(data['result']) + else: + assert data['result'] is None + memo_fu.set_exception(data['exception']) memo_lookup_table[data['hash']] = memo_fu except EOFError: @@ -380,77 +433,84 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, else: return {} - def update_checkpoint(self, task_record: TaskRecord) -> None: - if self.checkpoint_mode == 'task_exit': - self.checkpoint(task=task_record) - elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'): - with self.checkpoint_lock: - self.checkpointable_tasks.append(task_record) - elif self.checkpoint_mode is None: - pass - else: - raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}") - - def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None: - """Checkpoint the dfk incrementally to a checkpoint file. - - When called with no argument, all tasks registered in self.checkpointable_tasks - will be checkpointed. When called with a single TaskRecord argument, that task will be - checkpointed. + def checkpoint_one(self, cc: CheckpointCommand) -> None: + """Checkpoint a single task to a checkpoint file. By default the checkpoints are written to the RUNDIR of the current run under RUNDIR/checkpoint/tasks.pkl Kwargs: - - task (Optional task records) : A task to checkpoint. Default=None, meaning all - tasks registered for checkpointing. + - cc : The CheckpointCommand to checkpoint. .. note:: Checkpointing only works if memoization is enabled """ with self.checkpoint_lock: + self._checkpoint_these_tasks([cc]) - if task: - checkpoint_queue = [task] - else: - checkpoint_queue = self.checkpointable_tasks - - checkpoint_dir = '{0}/checkpoint'.format(self.run_dir) - checkpoint_tasks = checkpoint_dir + '/tasks.pkl' - - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir, exist_ok=True) + def checkpoint_queue(self) -> None: + """Checkpoint all tasks registered in self.checkpointable_tasks. - count = 0 + By default the checkpoints are written to the RUNDIR of the current + run under RUNDIR/checkpoint/tasks.pkl - with open(checkpoint_tasks, 'ab') as f: - for task_record in checkpoint_queue: - task_id = task_record['id'] + .. note:: + Checkpointing only works if memoization is enabled + """ + with self.checkpoint_lock: + self._checkpoint_these_tasks(self.checkpointable_tasks) + self.checkpointable_tasks = [] - app_fu = task_record['app_fu'] + def checkpoint(self) -> None: + """This is the user-facing interface to manual checkpointing. + """ + self.checkpoint_queue() - if app_fu.done() and app_fu.exception() is None: - hashsum = task_record['hashsum'] - if not hashsum: - continue - t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()} + def _checkpoint_these_tasks(self, checkpoint_queue: List[CheckpointCommand]) -> None: + """Play a sequence of CheckpointCommands into a checkpoint file. - # We are using pickle here since pickle dumps to a file in 'ab' - # mode behave like a incremental log. - pickle.dump(t, f) - count += 1 - logger.debug("Task {} checkpointed".format(task_id)) + The checkpoint lock must be held when invoking this method. + """ + checkpoint_dir = '{0}/checkpoint'.format(self.run_dir) + checkpoint_tasks = checkpoint_dir + '/tasks.pkl' + + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir, exist_ok=True) + + count = 0 + + with open(checkpoint_tasks, 'ab') as f: + for cc in checkpoint_queue: + (task_record, result, exception) = cc + + if exception is None and self.filter_result_for_checkpoint(result): + t = {'hash': task_record['hashsum'], 'exception': None, 'result': result} + pickle.dump(t, f) + count += 1 + logger.debug("Task %s checkpointed result", task_record['id']) + elif exception is not None and self.filter_exception_for_checkpoint(exception): + t = {'hash': task_record['hashsum'], 'exception': exception, 'result': None} + pickle.dump(t, f) + count += 1 + logger.debug("Task %s checkpointed exception", task_record['id']) + else: + pass # TODO: maybe log at debug level - self.checkpointed_tasks += count + self.checkpointed_tasks += count - if count == 0: - if self.checkpointed_tasks == 0: - logger.warning("No tasks checkpointed so far in this run.
Please ensure caching is enabled") else: - logger.debug("No tasks checkpointed in this pass.") else: - logger.info("Done checkpointing {} tasks".format(count)) + logger.debug("No tasks checkpointed in this pass.") + else: + logger.info("Done checkpointing {} tasks".format(count)) + + def filter_result_for_checkpoint(self, result: Any) -> bool: + """Overridable method to decide if a task that ended with a successful result should be checkpointed""" + return True - if not task: - self.checkpointable_tasks = [] + def filter_exception_for_checkpoint(self, exception: BaseException) -> bool: + """Overridable method to decide if a task that ended with an exception should be checkpointed""" + return False
diff --git a/parsl/dataflow/memosql.py b/parsl/dataflow/memosql.py new file mode 100644 index 0000000000..002cd0d3b6 --- /dev/null +++ b/parsl/dataflow/memosql.py @@ -0,0 +1,140 @@ +import logging +import pickle +import sqlite3 +from concurrent.futures import Future +from pathlib import Path +from typing import Any, Optional + +from parsl.dataflow.memoization import Memoizer, make_hash +from parsl.dataflow.taskrecord import TaskRecord + +logger = logging.getLogger(__name__) + + +class SQLiteMemoizer(Memoizer): + """Memoize into an sqlite3 database on disk, rather than into in-process memory. + """ + + def __init__(self, *, checkpoint_dir: str | None = None): + self.checkpoint_dir = checkpoint_dir + + def start(self, *, run_dir: str) -> None: + """TODO: run_dir is the per-workflow run dir, but we need a broader checkpoint context... one level up + by default... get_all_checkpoints uses "runinfo/" as a relative path for that by default, so replicating + that choice would do here. Likewise, I think, for monitoring.""" + + self.run_dir = run_dir + + base_dir = self.checkpoint_dir if self.checkpoint_dir is not None else self.run_dir + + self.db_path = Path(base_dir) / "checkpoint.sqlite3" + logger.debug("starting with db_path %r", self.db_path) + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + + cursor.execute("CREATE TABLE IF NOT EXISTS checkpoints(key, result)") + # probably want an index on key, since all access goes via it. + + connection.commit() + connection.close() + logger.debug("checkpoint table created") + + def close(self) -> None: + """TODO: probably going to need some kind of shutdown now, to close the sqlite3 connection.""" + pass + + def check_memo(self, task: TaskRecord) -> Optional[Future]: + """TODO: document this: check_memo is required to set the task hashsum, + if that's how we're going to key checkpoints in update_memo. (That's not + a requirement though: other equalities are available.)""" + task_id = task['id'] + + if not task['memoize']: + task['hashsum'] = None + logger.debug("Task %s will not be memoized", task_id) + return None + + hashsum = make_hash(task) + logger.debug("Task %s has memoization hash %s", task_id, hashsum) + task['hashsum'] = hashsum + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + cursor.execute("SELECT result FROM checkpoints WHERE key = ?", (hashsum, )) + r = cursor.fetchone() + + if r is None: + connection.close() + return None + else: + data = pickle.loads(r[0]) + connection.close() + + memo_fu: Future = Future() + + if data['exception'] is None: + memo_fu.set_result(data['result']) + else: + assert data['result'] is None + memo_fu.set_exception(data['exception']) + + return memo_fu + + def update_memo_result(self, task: TaskRecord, result: Any) -> None: + logger.debug("updating memo") + + if not task['memoize'] or 'hashsum' not in task: + logger.debug("preconditions for memo not satisfied") + return + + if not isinstance(task['hashsum'], str): + logger.error(f"Attempting to update app cache entry but hashsum is not a string key: {task['hashsum']}") + return + + hashsum = task['hashsum'] + + # this comes from the original concatenation-based checkpoint code: + # assert app_fu.done(), "assumption: update_memo is called after future has a result" + t = {'hash': hashsum, 'exception': None, 'result': result} + # else: + # t = {'hash': hashsum, 'exception': app_fu.exception(), 'result': None} + + value = pickle.dumps(t) + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + + cursor.execute("INSERT INTO checkpoints VALUES(?, ?)", (hashsum, value)) + + connection.commit() + connection.close() + + def update_memo_exception(self, task: TaskRecord, exception: BaseException) -> None: + logger.debug("updating memo") + + if not task['memoize'] or 'hashsum' not in task: + logger.debug("preconditions for memo not satisfied") + return + + if not isinstance(task['hashsum'], str): + logger.error(f"Attempting to update app cache entry but hashsum is not a string key: {task['hashsum']}") + return + + hashsum = task['hashsum'] + + # this comes from the original concatenation-based checkpoint code: + # assert app_fu.done(), "assumption: update_memo is called after future has a result" + # t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()} + # else: + t = {'hash': hashsum, 'exception': exception, 'result': None} + + value = pickle.dumps(t) + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + + cursor.execute("INSERT INTO checkpoints VALUES(?, ?)", (hashsum, value)) + + connection.commit() + connection.close()
diff --git a/parsl/tests/configs/htex_local_alternate.py b/parsl/tests/configs/htex_local_alternate.py index 39f90c9be6..20a68ce074 100644 --- a/parsl/tests/configs/htex_local_alternate.py +++ b/parsl/tests/configs/htex_local_alternate.py @@ -22,6 +22,7 @@ from parsl.data_provider.ftp import FTPInTaskStaging from parsl.data_provider.http import HTTPInTaskStaging from parsl.data_provider.zip import ZipFileStaging +from parsl.dataflow.memosql import SQLiteMemoizer from parsl.executors import HighThroughputExecutor from parsl.launchers import SingleNodeLauncher @@ -56,14 +57,14 @@ def fresh_config(): ) ], strategy='simple', - app_cache=True, checkpoint_mode='task_exit', retries=2, monitoring=MonitoringHub( monitoring_debug=False, resource_monitoring_interval=1, ), usage_tracking=3, -
project_name="parsl htex_local_alternate test configuration" + project_name="parsl htex_local_alternate test configuration", + memoizer=SQLiteMemoizer() ) diff --git a/parsl/tests/configs/local_threads_checkpoint.py b/parsl/tests/configs/local_threads_checkpoint.py deleted file mode 100644 index 4fd33a27cd..0000000000 --- a/parsl/tests/configs/local_threads_checkpoint.py +++ /dev/null @@ -1,15 +0,0 @@ -from parsl.config import Config -from parsl.executors.threads import ThreadPoolExecutor - - -def fresh_config(): - return Config( - executors=[ - ThreadPoolExecutor( - label='local_threads_checkpoint', - ) - ] - ) - - -config = fresh_config() diff --git a/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py b/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py index 3a8cdbdaea..63cf96b570 100644 --- a/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py +++ b/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py @@ -1,4 +1,5 @@ from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor config = Config( @@ -7,5 +8,5 @@ label='local_threads_checkpoint_dfk_exit', ) ], - checkpoint_mode='dfk_exit' + memoizer=BasicMemoizer(checkpoint_mode='dfk_exit') ) diff --git a/parsl/tests/configs/local_threads_checkpoint_task_exit.py b/parsl/tests/configs/local_threads_checkpoint_task_exit.py index f743908542..f5b36e0670 100644 --- a/parsl/tests/configs/local_threads_checkpoint_task_exit.py +++ b/parsl/tests/configs/local_threads_checkpoint_task_exit.py @@ -1,4 +1,5 @@ from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor config = Config( @@ -7,5 +8,5 @@ label='local_threads_checkpoint_task_exit', ) ], - checkpoint_mode='task_exit' + memoizer=BasicMemoizer(checkpoint_mode='task_exit') ) diff --git a/parsl/tests/test_checkpointing/test_periodic.py b/parsl/tests/test_checkpointing/test_periodic.py index 19bde0faab..8e7eb2ed59 100644 --- a/parsl/tests/test_checkpointing/test_periodic.py +++ b/parsl/tests/test_checkpointing/test_periodic.py @@ -3,6 +3,7 @@ import parsl from parsl.app.app import python_app from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor @@ -10,8 +11,10 @@ def fresh_config(): tpe = ThreadPoolExecutor(label='local_threads_checkpoint_periodic', max_threads=1) return Config( executors=[tpe], - checkpoint_mode='periodic', - checkpoint_period='00:00:02' + memoizer=BasicMemoizer( + checkpoint_mode='periodic', + checkpoint_period='00:00:02' + ) ) @@ -32,7 +35,8 @@ def test_periodic(): """Test checkpointing with task_periodic behavior """ with parsl.load(fresh_config()): - h, m, s = map(int, parsl.dfk().config.checkpoint_period.split(":")) + memoizer = parsl.dfk().memoizer + h, m, s = map(int, memoizer.checkpoint_period.split(":")) assert h == 0, "Verify test setup" assert m == 0, "Verify test setup" assert s > 0, "Verify test setup" diff --git a/parsl/tests/test_checkpointing/test_python_checkpoint_1.py b/parsl/tests/test_checkpointing/test_python_checkpoint_1.py index 1d2b38db22..35d88d0518 100644 --- a/parsl/tests/test_checkpointing/test_python_checkpoint_1.py +++ b/parsl/tests/test_checkpointing/test_python_checkpoint_1.py @@ -5,13 +5,8 @@ import parsl from parsl import python_app -from parsl.tests.configs.local_threads import fresh_config - - -def local_config(): - config = fresh_config() - config.checkpoint_mode = "manual" - 
return config +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer @python_app(cache=True) @@ -21,14 +16,17 @@ def uuid_app(): @pytest.mark.local -def test_initial_checkpoint_write() -> None: +def test_manual_checkpoint() -> None: """Launch an app and write a manual checkpoint once it has completed """ - uuid_app().result() + memoizer = BasicMemoizer(checkpoint_mode="manual") + + with parsl.load(Config(memoizer=memoizer)): + uuid_app().result() - parsl.dfk().checkpoint() + memoizer.checkpoint() - cpt_dir = Path(parsl.dfk().run_dir) / 'checkpoint' + cpt_dir = Path(parsl.dfk().run_dir) / 'checkpoint' - cptpath = cpt_dir / 'tasks.pkl' - assert os.path.exists(cptpath), f"Tasks checkpoint missing: {cptpath}" + cptpath = cpt_dir / 'tasks.pkl' + assert os.path.exists(cptpath), f"Tasks checkpoint missing: {cptpath}"
diff --git a/parsl/tests/test_checkpointing/test_python_checkpoint_2.py b/parsl/tests/test_checkpointing/test_python_checkpoint_2.py index f219b704e4..f310a44974 100644 --- a/parsl/tests/test_checkpointing/test_python_checkpoint_2.py +++ b/parsl/tests/test_checkpointing/test_python_checkpoint_2.py @@ -5,21 +5,21 @@ import parsl from parsl import python_app -from parsl.tests.configs.local_threads_checkpoint import fresh_config +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.executors.threads import ThreadPoolExecutor -@contextlib.contextmanager -def parsl_configured(run_dir, **kw): - c = fresh_config() - c.run_dir = run_dir - for config_attr, config_val in kw.items(): - setattr(c, config_attr, config_val) - dfk = parsl.load(c) - for ex in dfk.executors.values(): - ex.working_dir = run_dir - yield dfk - - parsl.dfk().cleanup() +def parsl_configured(run_dir, memoizer): + return parsl.load(Config( + run_dir=str(run_dir), + executors=[ + ThreadPoolExecutor( + label='local_threads_checkpoint', + ) + ], + memoizer=memoizer + )) @python_app(cache=True) @@ -32,11 +32,11 @@ def uuid_app(): def test_loading_checkpoint(tmpd_cwd): """Load memoization table from previous checkpoint """ - with parsl_configured(tmpd_cwd, checkpoint_mode="task_exit"): + with parsl_configured(tmpd_cwd, BasicMemoizer(checkpoint_mode="task_exit")): checkpoint_files = [os.path.join(parsl.dfk().run_dir, "checkpoint")] result = uuid_app().result() - with parsl_configured(tmpd_cwd, checkpoint_files=checkpoint_files): + with parsl_configured(tmpd_cwd, BasicMemoizer(checkpoint_files=checkpoint_files)): relaunched = uuid_app().result() assert result == relaunched, "Expected following call to uuid_app to return cached uuid"
diff --git a/parsl/tests/test_checkpointing/test_python_checkpoint_2_sqlite.py b/parsl/tests/test_checkpointing/test_python_checkpoint_2_sqlite.py new file mode 100644 index 0000000000..0d220b915d --- /dev/null +++ b/parsl/tests/test_checkpointing/test_python_checkpoint_2_sqlite.py @@ -0,0 +1,32 @@ +import pytest + +import parsl +from parsl import python_app +from parsl.config import Config +from parsl.dataflow.memosql import SQLiteMemoizer + + +def parsl_configured(run_dir, memoizer): + return parsl.load(Config( + run_dir=str(run_dir), + memoizer=memoizer + )) + + +@python_app(cache=True) +def uuid_app(): + import uuid + return uuid.uuid4() + + +@pytest.mark.local +def test_loading_checkpoint(tmpd_cwd): + """Load memoization table from previous checkpoint + """ + with parsl_configured(tmpd_cwd, SQLiteMemoizer(checkpoint_dir=tmpd_cwd)): + result = uuid_app().result() + + with
parsl_configured(tmpd_cwd, SQLiteMemoizer(checkpoint_dir=tmpd_cwd)): + relaunched = uuid_app().result() + + assert result == relaunched, "Expected following call to uuid_app to return cached uuid"
diff --git a/parsl/tests/test_checkpointing/test_python_checkpoint_exceptions.py b/parsl/tests/test_checkpointing/test_python_checkpoint_exceptions.py new file mode 100644 index 0000000000..2fe9acfd17 --- /dev/null +++ b/parsl/tests/test_checkpointing/test_python_checkpoint_exceptions.py @@ -0,0 +1,47 @@ +import os + +import pytest + +import parsl +from parsl import python_app +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer + + +class CheckpointExceptionsMemoizer(BasicMemoizer): + def filter_exception_for_checkpoint(self, ex): + # TODO: the task record used to be available here (via app_fu.task_record, + # asserted non-None), but in moving to results-only mode it is lost. + # Maybe it's worth passing it in? What are the use cases for this + # deciding function? + + # override the default always-False, to be always-True + return True + + +def fresh_config(run_dir, memoizer): + return Config( + memoizer=memoizer, + run_dir=str(run_dir) + ) + + +@python_app(cache=True) +def uuid_app(): + import uuid + raise RuntimeError(str(uuid.uuid4())) + + +@pytest.mark.local +def test_loading_checkpoint(tmpd_cwd): + """Load memoization table from previous checkpoint + """ + with parsl.load(fresh_config(tmpd_cwd, CheckpointExceptionsMemoizer(checkpoint_mode="task_exit"))): + checkpoint_files = [os.path.join(parsl.dfk().run_dir, "checkpoint")] + result = uuid_app().exception() + + with parsl.load(fresh_config(tmpd_cwd, CheckpointExceptionsMemoizer(checkpoint_files=checkpoint_files))): + relaunched = uuid_app().exception() + + assert result.args == relaunched.args, "Expected following call to uuid_app to return cached uuid in exception"
diff --git a/parsl/tests/test_checkpointing/test_regression_233.py b/parsl/tests/test_checkpointing/test_regression_233.py index 04ca52bdbd..2abd55abb8 100644 --- a/parsl/tests/test_checkpointing/test_regression_233.py +++ b/parsl/tests/test_checkpointing/test_regression_233.py @@ -1,12 +1,13 @@ import pytest from parsl.app.app import python_app +from parsl.config import Config from parsl.dataflow.dflow import DataFlowKernel +from parsl.dataflow.memoization import BasicMemoizer def run_checkpointed(checkpoints): - from parsl.tests.configs.local_threads_checkpoint_task_exit import config - config.checkpoint_files = checkpoints + config = Config(memoizer=BasicMemoizer(checkpoint_files=checkpoints, checkpoint_mode='task_exit')) dfk = DataFlowKernel(config=config) @python_app(data_flow_kernel=dfk, cache=True) diff --git a/parsl/tests/test_python_apps/test_memoize_2.py b/parsl/tests/test_python_apps/test_memoize_2.py index e884812277..514f2b23fd 100644 --- a/parsl/tests/test_python_apps/test_memoize_2.py +++ b/parsl/tests/test_python_apps/test_memoize_2.py @@ -5,6 +5,7 @@ import parsl from parsl.app.app import python_app from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor @@ -13,7 +14,7 @@ def local_config(): executors=[ ThreadPoolExecutor(max_threads=4), ], - app_cache=False + memoizer=BasicMemoizer(memoize=False) # TODO: this should be a better do-nothing impl?
) diff --git a/parsl/tests/test_python_apps/test_memoize_exception.py b/parsl/tests/test_python_apps/test_memoize_exception.py index 86aac22f41..28a5368589 100644 --- a/parsl/tests/test_python_apps/test_memoize_exception.py +++ b/parsl/tests/test_python_apps/test_memoize_exception.py @@ -1,5 +1,17 @@ +import pytest + import parsl from parsl.app.app import python_app +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.executors.threads import ThreadPoolExecutor + + +def local_config(): + return Config( + executors=[ThreadPoolExecutor()], + memoizer=BasicMemoizer() + ) @python_app(cache=True) @@ -12,8 +24,9 @@ def raise_exception_nocache(x, cache=True): raise RuntimeError("exception from raise_exception_nocache") +@pytest.mark.local def test_python_memoization(n=2): - """Test Python memoization of exceptions, with cache=True""" + """Test BasicMemoizer memoization of exceptions, with cache=True""" x = raise_exception_cache(0) # wait for x to be done @@ -27,8 +40,9 @@ def test_python_memoization(n=2): assert fut.exception() is x.exception(), "Memoized exception should have been memoized" +@pytest.mark.local def test_python_no_memoization(n=2): - """Test Python non-memoization of exceptions, with cache=False""" + """Test BasicMemoizer non-memoization of exceptions, with cache=False""" x = raise_exception_nocache(0) # wait for x to be done
diff --git a/parsl/tests/test_python_apps/test_memoize_plugin.py b/parsl/tests/test_python_apps/test_memoize_plugin.py new file mode 100644 index 0000000000..724facf165 --- /dev/null +++ b/parsl/tests/test_python_apps/test_memoize_plugin.py @@ -0,0 +1,51 @@ +import pytest + +import parsl +from parsl.app.app import python_app +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.dataflow.taskrecord import TaskRecord + + +class DontReuseSevenMemoizer(BasicMemoizer): + def check_memo(self, task_record: TaskRecord): + if task_record['args'][0] == 7: + return None # we didn't find a suitable memo record... + else: + return super().check_memo(task_record) + + +def local_config(): + return Config(memoizer=DontReuseSevenMemoizer()) + + +@python_app(cache=True) +def random_uuid(x, cache=True): + import uuid + return str(uuid.uuid4()) + + +@pytest.mark.local +def test_python_memoization(n=10): + """Test that a BasicMemoizer subclass can selectively refuse to reuse memoized results. + """ + + # TODO: this .result() needs to happen here, not in the loop below: + # otherwise the submissions race to complete, and whether a memoized + # result is available for the loop becomes timing-dependent. + x = random_uuid(0).result() + + for i in range(0, n): + foo = random_uuid(0) + print(i) + print(foo.result()) + assert foo.result() == x, "Memoized results were incorrectly not used" + + y = random_uuid(7).result() + + for i in range(0, n): + foo = random_uuid(7) + print(i) + print(foo.result()) + assert foo.result() != y, "Memoized results were incorrectly used"
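The `filter_result_for_checkpoint` / `filter_exception_for_checkpoint` hooks added to BasicMemoizer combine naturally with the subclassing pattern this last test exercises. As a hypothetical sketch in the same style as CheckpointExceptionsMemoizer above (not part of the patch), a deployment could persist only a chosen class of failures so that known-bad inputs are not re-run:

```python
from parsl.config import Config
from parsl.dataflow.memoization import BasicMemoizer


class CheckpointValueErrorsMemoizer(BasicMemoizer):
    """Checkpoint successful results as usual, and additionally
    checkpoint tasks that failed with ValueError.
    (Hypothetical example; not part of this diff.)"""

    def filter_exception_for_checkpoint(self, exception: BaseException) -> bool:
        # the base class default is always-False; checkpoint only ValueErrors
        return isinstance(exception, ValueError)


config = Config(
    memoizer=CheckpointValueErrorsMemoizer(checkpoint_mode='task_exit'),
)
```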