diff --git a/parsl/benchmark/perf.py b/parsl/benchmark/perf.py index f2183c6b0c..b73e4ffe45 100644 --- a/parsl/benchmark/perf.py +++ b/parsl/benchmark/perf.py @@ -48,7 +48,7 @@ def performance(*, resources: dict, target_t: float, args_extra_size: int, itera iteration = 1 - args_extra_payload = "x" * args_extra_size + # args_extra_payload = "x" * args_extra_size iterate = True @@ -59,7 +59,13 @@ def performance(*, resources: dict, target_t: float, args_extra_size: int, itera fs = [] print("Submitting tasks / invoking apps") - for _ in range(n): + for index in range(n): + # this means there is a different argument for each iteration, + # which will make checkpointing/memo behave differently + # so this could be switchable in parsl-perf dev branch + # args_extra_payload = index # always a new one (except for run repeats) + + args_extra_payload = index % 10 fs.append(app(args_extra_payload, parsl_resource_specification=resources)) submitted_t = time.time() diff --git a/parsl/config.py b/parsl/config.py index 1358e99d28..8cf3e69a6f 100644 --- a/parsl/config.py +++ b/parsl/config.py @@ -5,6 +5,7 @@ from typing_extensions import Literal from parsl.dataflow.dependency_resolvers import DependencyResolver +from parsl.dataflow.memoization import Memoizer from parsl.dataflow.taskrecord import TaskRecord from parsl.errors import ConfigurationError from parsl.executors.base import ParslExecutor @@ -27,17 +28,6 @@ class Config(RepresentationMixin, UsageInformation): executors : sequence of ParslExecutor, optional List (or other iterable) of `ParslExecutor` instances to use for executing tasks. Default is (:class:`~parsl.executors.threads.ThreadPoolExecutor()`,). - app_cache : bool, optional - Enable app caching. Default is True. - checkpoint_files : sequence of str, optional - List of paths to checkpoint files. See :func:`parsl.utils.get_all_checkpoints` and - :func:`parsl.utils.get_last_checkpoint` for helpers. Default is None. - checkpoint_mode : str, optional - Checkpoint mode to use, can be ``'dfk_exit'``, ``'task_exit'``, ``'periodic'`` or ``'manual'``. - If set to `None`, checkpointing will be disabled. Default is None. - checkpoint_period : str, optional - Time interval (in "HH:MM:SS") at which to checkpoint completed tasks. Only has an effect if - ``checkpoint_mode='periodic'``. dependency_resolver: plugin point for custom dependency resolvers. Default: only resolve Futures, using the `SHALLOW_DEPENDENCY_RESOLVER`. 
exit_mode: str, optional @@ -100,14 +90,7 @@ class Config(RepresentationMixin, UsageInformation): @typeguard.typechecked def __init__(self, executors: Optional[Iterable[ParslExecutor]] = None, - app_cache: bool = True, - checkpoint_files: Optional[Sequence[str]] = None, - checkpoint_mode: Union[None, - Literal['task_exit'], - Literal['periodic'], - Literal['dfk_exit'], - Literal['manual']] = None, - checkpoint_period: Optional[str] = None, + memoizer: Optional[Memoizer] = None, dependency_resolver: Optional[DependencyResolver] = None, exit_mode: Literal['cleanup', 'skip', 'wait'] = 'cleanup', garbage_collect: bool = True, @@ -131,21 +114,7 @@ def __init__(self, self._executors: Sequence[ParslExecutor] = executors self._validate_executors() - self.app_cache = app_cache - self.checkpoint_files = checkpoint_files - self.checkpoint_mode = checkpoint_mode - if checkpoint_period is not None: - if checkpoint_mode is None: - logger.debug('The requested `checkpoint_period={}` will have no effect because `checkpoint_mode=None`'.format( - checkpoint_period) - ) - elif checkpoint_mode != 'periodic': - logger.debug("Requested checkpoint period of {} only has an effect with checkpoint_mode='periodic'".format( - checkpoint_period) - ) - if checkpoint_mode == 'periodic' and checkpoint_period is None: - checkpoint_period = "00:30:00" - self.checkpoint_period = checkpoint_period + self.memoizer = memoizer self.dependency_resolver = dependency_resolver self.exit_mode = exit_mode self.garbage_collect = garbage_collect diff --git a/parsl/configs/ASPIRE1.py b/parsl/configs/ASPIRE1.py index 017e1061d7..84d131f210 100644 --- a/parsl/configs/ASPIRE1.py +++ b/parsl/configs/ASPIRE1.py @@ -1,5 +1,6 @@ from parsl.addresses import address_by_interface from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors import HighThroughputExecutor from parsl.launchers import MpiRunLauncher from parsl.monitoring.monitoring import MonitoringHub @@ -38,7 +39,6 @@ ), strategy='simple', retries=3, - app_cache=True, - checkpoint_mode='task_exit', + memoizer=BasicMemoizer(checkpoint_mode='task_exit'), usage_tracking=LEVEL_1, ) diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py index 83000b73cc..cc2a112e76 100644 --- a/parsl/dataflow/dflow.py +++ b/parsl/dataflow/dflow.py @@ -29,7 +29,7 @@ from parsl.dataflow.dependency_resolvers import SHALLOW_DEPENDENCY_RESOLVER from parsl.dataflow.errors import DependencyError, JoinError from parsl.dataflow.futures import AppFuture -from parsl.dataflow.memoization import Memoizer +from parsl.dataflow.memoization import BasicMemoizer, Memoizer from parsl.dataflow.rundirs import make_rundir from parsl.dataflow.states import FINAL_FAILURE_STATES, FINAL_STATES, States from parsl.dataflow.taskrecord import TaskRecord @@ -165,12 +165,8 @@ def __init__(self, config: Config) -> None: self.monitoring_radio.send((MessageType.WORKFLOW_INFO, workflow_info)) - self.memoizer = Memoizer(memoize=config.app_cache, - checkpoint_mode=config.checkpoint_mode, - checkpoint_files=config.checkpoint_files, - checkpoint_period=config.checkpoint_period) - self.memoizer.run_dir = self.run_dir - self.memoizer.start() + self.memoizer: Memoizer = config.memoizer if config.memoizer is not None else BasicMemoizer() + self.memoizer.start(run_dir=self.run_dir) # this must be set before executors are added since add_executors calls # job_status_poller.add_executors. 
@@ -352,14 +348,8 @@ def handle_exec_update(self, task_record: TaskRecord, future: Future) -> None: task_record['fail_cost'] += 1 if isinstance(e, DependencyError): - # was this sending two task log infos? if so would I see the row twice in the monitoring db? - self.update_task_state(task_record, States.dep_fail) logger.info("Task {} failed due to dependency failure so skipping retries".format(task_id)) - task_record['time_returned'] = datetime.datetime.now() - self._send_task_log_info(task_record) - self.memoizer.update_memo(task_record) - with task_record['app_fu']._update_lock: - task_record['app_fu'].set_exception(e) + self._complete_task_exception(task_record, States.dep_fail, e) elif task_record['fail_cost'] <= self._config.retries: @@ -379,12 +369,7 @@ def handle_exec_update(self, task_record: TaskRecord, future: Future) -> None: else: logger.exception("Task {} failed after {} retry attempts".format(task_id, task_record['try_id'])) - self.update_task_state(task_record, States.failed) - task_record['time_returned'] = datetime.datetime.now() - self._send_task_log_info(task_record) - self.memoizer.update_memo(task_record) - with task_record['app_fu']._update_lock: - task_record['app_fu'].set_exception(e) + self._complete_task_exception(task_record, States.failed, e) else: if task_record['from_memo']: @@ -422,13 +407,10 @@ def handle_exec_update(self, task_record: TaskRecord, future: Future) -> None: for inner_future in joinable: inner_future.add_done_callback(partial(self.handle_join_update, task_record)) else: - self.update_task_state(task_record, States.failed) - task_record['time_returned'] = datetime.datetime.now() - self._send_task_log_info(task_record) - self.memoizer.update_memo(task_record) - with task_record['app_fu']._update_lock: - task_record['app_fu'].set_exception( - TypeError(f"join_app body must return a Future or list of Futures, got {joinable} of type {type(joinable)}")) + self._complete_task_exception( + task_record, + States.failed, + TypeError(f"join_app body must return a Future or list of Futures, got {joinable} of type {type(joinable)}")) self._log_std_streams(task_record) @@ -499,12 +481,7 @@ def handle_join_update(self, task_record: TaskRecord, inner_app_future: Optional # no need to update the fail cost because join apps are never # retried - self.update_task_state(task_record, States.failed) - task_record['time_returned'] = datetime.datetime.now() - self.memoizer.update_memo(task_record) - with task_record['app_fu']._update_lock: - task_record['app_fu'].set_exception(e) - self._send_task_log_info(task_record) + self._complete_task_exception(task_record, States.failed, e) else: # all the joinables succeeded, so construct a result: @@ -521,49 +498,47 @@ def handle_join_update(self, task_record: TaskRecord, inner_app_future: Optional self._log_std_streams(task_record) - def handle_app_update(self, task_record: TaskRecord, future: AppFuture) -> None: - """This function is called as a callback when an AppFuture - is in its final state. - - It will trigger post-app processing such as checkpointing. 
+ def _complete_task_result(self, task_record: TaskRecord, new_state: States, result: Any) -> None: + """Set a task into a completed state + """ + assert new_state in FINAL_STATES + assert new_state not in FINAL_FAILURE_STATES + old_state = task_record['status'] - Args: - task_record : Task record - future (Future) : The relevant app future (which should be - consistent with the task structure 'app_fu' entry + self.update_task_state(task_record, new_state) - """ + logger.info(f"Task {task_record['id']} completed ({old_state.name} -> {new_state.name})") + task_record['time_returned'] = datetime.datetime.now() - task_id = task_record['id'] + self.memoizer.update_memo_result(task_record, result) - if not task_record['app_fu'].done(): - logger.error("Internal consistency error: app_fu is not done for task {}".format(task_id)) - if not task_record['app_fu'] == future: - logger.error("Internal consistency error: callback future is not the app_fu in task structure, for task {}".format(task_id)) + self._send_task_log_info(task_record) - self.memoizer.update_checkpoint(task_record) + self.wipe_task(task_record['id']) - self.wipe_task(task_id) - return + with task_record['app_fu']._update_lock: + task_record['app_fu'].set_result(result) - def _complete_task_result(self, task_record: TaskRecord, new_state: States, result: Any) -> None: - """Set a task into a completed state + def _complete_task_exception(self, task_record: TaskRecord, new_state: States, exception: BaseException) -> None: + """Set a task into a failure state """ assert new_state in FINAL_STATES - assert new_state not in FINAL_FAILURE_STATES + assert new_state in FINAL_FAILURE_STATES old_state = task_record['status'] self.update_task_state(task_record, new_state) - logger.info(f"Task {task_record['id']} completed ({old_state.name} -> {new_state.name})") + logger.info(f"Task {task_record['id']} failed ({old_state.name} -> {new_state.name})") task_record['time_returned'] = datetime.datetime.now() - self.memoizer.update_memo(task_record) + self.memoizer.update_memo_exception(task_record, exception) self._send_task_log_info(task_record) + self.wipe_task(task_record['id']) + with task_record['app_fu']._update_lock: - task_record['app_fu'].set_result(result) + task_record['app_fu'].set_exception(exception) def update_task_state(self, task_record: TaskRecord, new_state: States) -> None: """Updates a task record state, and recording an appropriate change @@ -1053,7 +1028,6 @@ def submit(self, task_record['func_name'], waiting_message)) - app_fu.add_done_callback(partial(self.handle_app_update, task_record)) self.update_task_state(task_record, States.pending) logger.debug("Task {} set to pending state with AppFuture: {}".format(task_id, task_record['app_fu'])) @@ -1227,6 +1201,10 @@ def cleanup(self) -> None: # should still see it. logger.info("DFK cleanup complete") + # TODO: this should maybe go away: manual explicit checkpointing is + # a property of the (upcoming) BasicMemoizer, not of a memoization + # plugin in general -- configure a BasicMemoizer separately from the + # DFK and call checkpoint on that...
def checkpoint(self) -> None: self.memoizer.checkpoint() diff --git a/parsl/dataflow/memoization.py b/parsl/dataflow/memoization.py index ae7486d239..118b46e7b6 100644 --- a/parsl/dataflow/memoization.py +++ b/parsl/dataflow/memoization.py @@ -14,7 +14,7 @@ from parsl.dataflow.errors import BadCheckpoint from parsl.dataflow.taskrecord import TaskRecord -from parsl.errors import ConfigurationError, InternalConsistencyError +from parsl.errors import ConfigurationError from parsl.utils import Timer, get_all_checkpoints logger = logging.getLogger(__name__) @@ -118,7 +118,63 @@ def id_for_memo_function(f: types.FunctionType, output_ref: bool = False) -> byt return pickle.dumps(["types.FunctionType", f.__name__, f.__module__]) +def make_hash(task: TaskRecord) -> str: + """Create a hash of the task inputs. + + Args: + - task (dict) : Task dictionary from dfk.tasks + + Returns: + - hash (str) : A unique hash string + """ + + t: List[bytes] = [] + + # if kwargs contains an outputs parameter, that parameter is removed + # and normalised differently - with output_ref set to True. + # kwargs listed in ignore_for_cache will also be removed + + filtered_kw = task['kwargs'].copy() + + ignore_list = task['ignore_for_cache'] + + logger.debug("Ignoring these kwargs for checkpointing: %s", ignore_list) + for k in ignore_list: + logger.debug("Ignoring kwarg %s", k) + del filtered_kw[k] + + if 'outputs' in task['kwargs']: + outputs = task['kwargs']['outputs'] + del filtered_kw['outputs'] + t.append(id_for_memo(outputs, output_ref=True)) + + t.extend(map(id_for_memo, (filtered_kw, task['func'], task['args']))) + + x = b''.join(t) + return hashlib.md5(x).hexdigest() + + class Memoizer: + def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None: + raise NotImplementedError + + def update_memo_result(self, task: TaskRecord, r: Any) -> None: + raise NotImplementedError + + def start(self, *, run_dir: str) -> None: + raise NotImplementedError + + def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None: + raise NotImplementedError + + def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: + raise NotImplementedError + + def close(self) -> None: + raise NotImplementedError + + +class BasicMemoizer(Memoizer): """Memoizer is responsible for ensuring that identical work is not repeated. When a task is repeated, i.e., the same function is called with the same exact arguments, the @@ -152,18 +208,17 @@ class Memoizer: run_dir: str def __init__(self, *, - memoize: bool = True, - checkpoint_files: Sequence[str] | None, - checkpoint_period: Optional[str], - checkpoint_mode: Literal['task_exit', 'periodic', 'dfk_exit', 'manual'] | None): + checkpoint_files: Sequence[str] | None = None, + checkpoint_period: Optional[str] = None, + checkpoint_mode: Literal['task_exit', 'periodic', 'dfk_exit', 'manual'] | None = None, + memoize: bool = True): # TODO: unlikely to need to set this to false, but it was in config API before... """Initialize the memoizer. KWargs: - memoize (Bool): enable memoization or not. - checkpoint (Dict): A checkpoint loaded as a dict. 
+ TODO: update """ - self.memoize = memoize - self.checkpointed_tasks = 0 self.checkpoint_lock = threading.Lock() @@ -175,8 +230,12 @@ def __init__(self, *, self.checkpointable_tasks: List[TaskRecord] = [] self._checkpoint_timer: Timer | None = None + self.memoize = memoize + + def start(self, *, run_dir: str) -> None: + + self.run_dir = run_dir - def start(self) -> None: if self.checkpoint_files is not None: checkpoint_files = self.checkpoint_files elif self.checkpoint_files is None and self.checkpoint_mode is not None: @@ -213,41 +272,6 @@ def close(self) -> None: logger.info("Stopping checkpoint timer") self._checkpoint_timer.close() - def make_hash(self, task: TaskRecord) -> str: - """Create a hash of the task inputs. - - Args: - - task (dict) : Task dictionary from dfk.tasks - - Returns: - - hash (str) : A unique hash string - """ - - t: List[bytes] = [] - - # if kwargs contains an outputs parameter, that parameter is removed - # and normalised differently - with output_ref set to True. - # kwargs listed in ignore_for_cache will also be removed - - filtered_kw = task['kwargs'].copy() - - ignore_list = task['ignore_for_cache'] - - logger.debug("Ignoring these kwargs for checkpointing: %s", ignore_list) - for k in ignore_list: - logger.debug("Ignoring kwarg %s", k) - del filtered_kw[k] - - if 'outputs' in task['kwargs']: - outputs = task['kwargs']['outputs'] - del filtered_kw['outputs'] - t.append(id_for_memo(outputs, output_ref=True)) - - t.extend(map(id_for_memo, (filtered_kw, task['func'], task['args']))) - - x = b''.join(t) - return hashlib.md5(x).hexdigest() - def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: """Create a hash of the task and its inputs and check the lookup table for this hash. @@ -269,7 +293,7 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: logger.debug("Task {} will not be memoized".format(task_id)) return None - hashsum = self.make_hash(task) + hashsum = make_hash(task) logger.debug("Task {} has memoization hash {}".format(task_id, hashsum)) result = None if hashsum in self.memo_lookup_table: @@ -283,7 +307,33 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]: assert isinstance(result, Future) or result is None return result - def update_memo(self, task: TaskRecord) -> None: + def update_memo_result(self, task: TaskRecord, r: Any) -> None: + self._update_memo(task) + + if self.checkpoint_mode == 'task_exit': + self.checkpoint(task=task, result=r) + elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'): + # with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock + self.checkpointable_tasks.append(task) + elif self.checkpoint_mode is None: + pass + else: + assert False, f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization" + + def update_memo_exception(self, task: TaskRecord, e: BaseException) -> None: + self._update_memo(task) + + if self.checkpoint_mode == 'task_exit': + self.checkpoint(task=task, exception=e) + elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'): + # with self._modify_checkpointable_tasks_lock: # TODO: sort out use of this lock + self.checkpointable_tasks.append(task) + elif self.checkpoint_mode is None: + pass + else: + assert False, f"Invalid checkpoint mode {self.checkpoint_mode} - should have been validated at initialization" + + def _update_memo(self, task: TaskRecord) -> None: """Updates the memoization lookup table with the result from a task.
This doesn't move any values around but associates the memoization + hashsum with the completed (by success or failure) AppFuture. @@ -334,8 +384,12 @@ def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[A data = pickle.load(f) # Copy and hash only the input attributes memo_fu: Future = Future() - assert data['exception'] is None - memo_fu.set_result(data['result']) + + if data['exception'] is None: + memo_fu.set_result(data['result']) + else: + assert data['result'] is None + memo_fu.set_exception(data['exception']) memo_lookup_table[data['hash']] = memo_fu except EOFError: @@ -374,18 +428,17 @@ def load_checkpoints(self, checkpointDirs: Optional[Sequence[str]]) -> Dict[str, else: return {} - def update_checkpoint(self, task_record: TaskRecord) -> None: - if self.checkpoint_mode == 'task_exit': - self.checkpoint(task=task_record) - elif self.checkpoint_mode in ('manual', 'periodic', 'dfk_exit'): - with self.checkpoint_lock: - self.checkpointable_tasks.append(task_record) - elif self.checkpoint_mode is None: - pass - else: - raise InternalConsistencyError(f"Invalid checkpoint mode {self.checkpoint_mode}") - - def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None: + # TODO: this call becomes even more multiplexed... + # called with no parameters, we write out the queued checkpointable tasks + # called with a task record, we can now no longer expect to get the + # result from the task record future, because it will not be + # populated yet. + # so then either we have an exception, or if exception is None, then + # we checkpoint the result. it's possible that result can be None as a + # real result: in the case that exception is None. + # what a horrible API that needs refactoring... + + def checkpoint(self, *, task: Optional[TaskRecord] = None, exception: Optional[BaseException] = None, result: Any = None) -> None: """Checkpoint the dfk incrementally to a checkpoint file. When called with no argument, all tasks registered in self.checkpointable_tasks @@ -405,11 +458,6 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None: """ with self.checkpoint_lock: - if task: - checkpoint_queue = [task] - else: - checkpoint_queue = self.checkpointable_tasks - checkpoint_dir = '{0}/checkpoint'.format(self.run_dir) checkpoint_tasks = checkpoint_dir + '/tasks.pkl' @@ -419,22 +467,53 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None: count = 0 with open(checkpoint_tasks, 'ab') as f: + if task: + # TODO: refactor with below + + task_id = task['id'] + hashsum = task['hashsum'] + if not hashsum: + pass # TODO: log an error? see below discussion + else: + if exception is None and self.filter_result_for_checkpoint(result): + t = {'hash': hashsum, 'exception': None, 'result': result} + pickle.dump(t, f) + count += 1 + logger.debug("Task {} checkpointed result".format(task_id)) + elif exception is not None and self.filter_exception_for_checkpoint(exception): + t = {'hash': hashsum, 'exception': exception, 'result': None} + pickle.dump(t, f) + count += 1 + logger.debug("Task {} checkpointed exception".format(task_id)) + else: + pass # no checkpoint - maybe debug log?
TODO + else: + checkpoint_queue = self.checkpointable_tasks + + for task_record in checkpoint_queue: + task_id = task_record['id'] + + app_fu = task_record['app_fu'] + + assert app_fu.done(), "trying to checkpoint a task that is not done" - if app_fu.done() and app_fu.exception() is None: hashsum = task_record['hashsum'] if not hashsum: - continue - t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()} - - # We are using pickle here since pickle dumps to a file in 'ab' - # mode behave like a incremental log. - pickle.dump(t, f) - count += 1 - logger.debug("Task {} checkpointed".format(task_id)) + continue # TODO: log an error? maybe some tasks don't have hashsums legitimately? + + if app_fu.exception() is None and self.filter_result_for_checkpoint(app_fu.result()): + t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()} + pickle.dump(t, f) + count += 1 + logger.debug("Task {} checkpointed result".format(task_id)) + elif (e := app_fu.exception()) is not None and self.filter_exception_for_checkpoint(e): + t = {'hash': hashsum, 'exception': app_fu.exception(), 'result': None} + pickle.dump(t, f) + count += 1 + logger.debug("Task {} checkpointed exception".format(task_id)) + else: + pass # TODO: maybe log at debug level self.checkpointed_tasks += count @@ -448,3 +527,11 @@ def checkpoint(self, *, task: Optional[TaskRecord] = None) -> None: if not task: self.checkpointable_tasks = [] + + def filter_result_for_checkpoint(self, result: Any) -> bool: + """Overridable method to decide if a task that ended with a successful result should be checkpointed""" + return True + + def filter_exception_for_checkpoint(self, exception: BaseException) -> bool: + """Overridable method to decide if an entry that ended with an exception should be checkpointed""" + return False diff --git a/parsl/dataflow/memosql.py b/parsl/dataflow/memosql.py new file mode 100644 index 0000000000..cbf1c661db --- /dev/null +++ b/parsl/dataflow/memosql.py @@ -0,0 +1,145 @@ +import logging +import pickle +import sqlite3 +from concurrent.futures import Future +from pathlib import Path +from typing import Any, Optional + +from parsl.dataflow.memoization import Memoizer, make_hash +from parsl.dataflow.taskrecord import TaskRecord + +logger = logging.getLogger(__name__) + + +class SQLiteMemoizer(Memoizer): + """Memoize out of memory into an sqlite3 database. + """ + +    def __init__(self, *, checkpoint_dir: str | None = None): + self.checkpoint_dir = checkpoint_dir + + def start(self, *, run_dir: str) -> None: + """TODO: run_dir is the per-workflow run dir, but we need a broader checkpoint context... one level up + by default... get_all_checkpoints uses "runinfo/" as a relative path for that by default so replicating + that choice would do here. likewise I think for monitoring.""" + + self.run_dir = run_dir + + dir = self.checkpoint_dir if self.checkpoint_dir is not None else self.run_dir + + self.db_path = Path(dir) / "checkpoint.sqlite3" + logger.debug("starting with db_path %r", self.db_path) + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + + cursor.execute("CREATE TABLE IF NOT EXISTS checkpoints(key, result)") + # probably want some index on key because that's what we're doing all the access via.
+ + connection.commit() + connection.close() + logger.debug("checkpoint table created") + + def close(self): + """TODO: probably going to need some kind of shutdown now, to close the sqlite3 connection.""" + pass + + def checkpoint(self, *, task: TaskRecord | None = None) -> None: + """All the behaviour for this memoizer is in check_memo and update_memo. + """ + logger.debug("Explicit checkpoint call is a no-op with this memoizer") + + def check_memo(self, task: TaskRecord) -> Optional[Future]: + """TODO: document this: check_memo is required to set the task hashsum, + if that's how we're going to key checkpoints in update_memo. (that's not + a requirement though: other equalities are available.""" + task_id = task['id'] + + if not task['memoize']: + task['hashsum'] = None + logger.debug("Task %s will not be memoized", task_id) + return None + + hashsum = make_hash(task) + logger.debug("Task {} has memoization hash {}".format(task_id, hashsum)) + task['hashsum'] = hashsum + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + cursor.execute("SELECT result FROM checkpoints WHERE key = ?", (hashsum, )) + r = cursor.fetchone() + + if r is None: + connection.close() + return None + else: + data = pickle.loads(r[0]) + connection.close() + + memo_fu: Future = Future() + + if data['exception'] is None: + memo_fu.set_result(data['result']) + else: + assert data['result'] is None + memo_fu.set_exception(data['exception']) + + return memo_fu + + def update_memo_result(self, task: TaskRecord, result: Any) -> None: + logger.debug("updating memo") + + if not task['memoize'] or 'hashsum' not in task: + logger.debug("preconditions for memo not satisfied") + return + + if not isinstance(task['hashsum'], str): + logger.error(f"Attempting to update app cache entry but hashsum is not a string key: {task['hashsum']}") + return + + hashsum = task['hashsum'] + + # this comes from the original concatenation-based checkpoint code: + # assert app_fu.done(), "assumption: update_memo is called after future has a result" + t = {'hash': hashsum, 'exception': None, 'result': result} + # else: + # t = {'hash': hashsum, 'exception': app_fu.exception(), 'result': None} + + value = pickle.dumps(t) + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + + cursor.execute("INSERT INTO checkpoints VALUES(?, ?)", (hashsum, value)) + + connection.commit() + connection.close() + + def update_memo_exception(self, task: TaskRecord, exception: BaseException) -> None: + logger.debug("updating memo") + + if not task['memoize'] or 'hashsum' not in task: + logger.debug("preconditions for memo not satisfied") + return + + if not isinstance(task['hashsum'], str): + logger.error(f"Attempting to update app cache entry but hashsum is not a string key: {task['hashsum']}") + return + + hashsum = task['hashsum'] + + # this comes from the original concatenation-based checkpoint code: + # assert app_fu.done(), "assumption: update_memo is called after future has a result" + # t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()} + # else: + t = {'hash': hashsum, 'exception': exception, 'result': None} + + value = pickle.dumps(t) + + connection = sqlite3.connect(self.db_path) + cursor = connection.cursor() + + cursor.execute("INSERT INTO checkpoints VALUES(?, ?)", (hashsum, value)) + + connection.commit() + connection.close() diff --git a/parsl/tests/configs/htex_local_alternate.py b/parsl/tests/configs/htex_local_alternate.py index 39f90c9be6..20a68ce074 100644 --- 
a/parsl/tests/configs/htex_local_alternate.py +++ b/parsl/tests/configs/htex_local_alternate.py @@ -22,6 +22,7 @@ from parsl.data_provider.ftp import FTPInTaskStaging from parsl.data_provider.http import HTTPInTaskStaging from parsl.data_provider.zip import ZipFileStaging +from parsl.dataflow.memosql import SQLiteMemoizer from parsl.executors import HighThroughputExecutor from parsl.launchers import SingleNodeLauncher @@ -56,14 +57,14 @@ def fresh_config(): ) ], strategy='simple', - app_cache=True, checkpoint_mode='task_exit', retries=2, monitoring=MonitoringHub( monitoring_debug=False, resource_monitoring_interval=1, ), usage_tracking=3, - project_name="parsl htex_local_alternate test configuration" + project_name="parsl htex_local_alternate test configuration", + memoizer=SQLiteMemoizer() ) diff --git a/parsl/tests/configs/local_threads_checkpoint.py b/parsl/tests/configs/local_threads_checkpoint.py deleted file mode 100644 index 4fd33a27cd..0000000000 --- a/parsl/tests/configs/local_threads_checkpoint.py +++ /dev/null @@ -1,15 +0,0 @@ -from parsl.config import Config -from parsl.executors.threads import ThreadPoolExecutor - - -def fresh_config(): - return Config( - executors=[ - ThreadPoolExecutor( - label='local_threads_checkpoint', - ) - ] - ) - - -config = fresh_config() diff --git a/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py b/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py index 3a8cdbdaea..63cf96b570 100644 --- a/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py +++ b/parsl/tests/configs/local_threads_checkpoint_dfk_exit.py @@ -1,4 +1,5 @@ from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor config = Config( @@ -7,5 +8,5 @@ label='local_threads_checkpoint_dfk_exit', ) ], - checkpoint_mode='dfk_exit' + memoizer=BasicMemoizer(checkpoint_mode='dfk_exit') ) diff --git a/parsl/tests/configs/local_threads_checkpoint_task_exit.py b/parsl/tests/configs/local_threads_checkpoint_task_exit.py index f743908542..f5b36e0670 100644 --- a/parsl/tests/configs/local_threads_checkpoint_task_exit.py +++ b/parsl/tests/configs/local_threads_checkpoint_task_exit.py @@ -1,4 +1,5 @@ from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor config = Config( @@ -7,5 +8,5 @@ label='local_threads_checkpoint_task_exit', ) ], - checkpoint_mode='task_exit' + memoizer=BasicMemoizer(checkpoint_mode='task_exit') ) diff --git a/parsl/tests/test_checkpointing/test_periodic.py b/parsl/tests/test_checkpointing/test_periodic.py index 19bde0faab..8e7eb2ed59 100644 --- a/parsl/tests/test_checkpointing/test_periodic.py +++ b/parsl/tests/test_checkpointing/test_periodic.py @@ -3,6 +3,7 @@ import parsl from parsl.app.app import python_app from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor @@ -10,8 +11,10 @@ def fresh_config(): tpe = ThreadPoolExecutor(label='local_threads_checkpoint_periodic', max_threads=1) return Config( executors=[tpe], - checkpoint_mode='periodic', - checkpoint_period='00:00:02' + memoizer=BasicMemoizer( + checkpoint_mode='periodic', + checkpoint_period='00:00:02' + ) ) @@ -32,7 +35,8 @@ def test_periodic(): """Test checkpointing with task_periodic behavior """ with parsl.load(fresh_config()): - h, m, s = map(int, parsl.dfk().config.checkpoint_period.split(":")) + memoizer = parsl.dfk().memoizer + h, m, s = map(int, 
memoizer.checkpoint_period.split(":")) assert h == 0, "Verify test setup" assert m == 0, "Verify test setup" assert s > 0, "Verify test setup" diff --git a/parsl/tests/test_checkpointing/test_python_checkpoint_2.py b/parsl/tests/test_checkpointing/test_python_checkpoint_2.py index f219b704e4..f310a44974 100644 --- a/parsl/tests/test_checkpointing/test_python_checkpoint_2.py +++ b/parsl/tests/test_checkpointing/test_python_checkpoint_2.py @@ -5,21 +5,21 @@ import parsl from parsl import python_app -from parsl.tests.configs.local_threads_checkpoint import fresh_config +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.executors.threads import ThreadPoolExecutor -@contextlib.contextmanager -def parsl_configured(run_dir, **kw): - c = fresh_config() - c.run_dir = run_dir - for config_attr, config_val in kw.items(): - setattr(c, config_attr, config_val) - dfk = parsl.load(c) - for ex in dfk.executors.values(): - ex.working_dir = run_dir - yield dfk - - parsl.dfk().cleanup() +def parsl_configured(run_dir, memoizer): + return parsl.load(Config( + run_dir=str(run_dir), + executors=[ + ThreadPoolExecutor( + label='local_threads_checkpoint', + ) + ], + memoizer=memoizer + )) @python_app(cache=True) @@ -32,11 +32,11 @@ def uuid_app(): def test_loading_checkpoint(tmpd_cwd): """Load memoization table from previous checkpoint """ - with parsl_configured(tmpd_cwd, checkpoint_mode="task_exit"): + with parsl_configured(tmpd_cwd, BasicMemoizer(checkpoint_mode="task_exit")): checkpoint_files = [os.path.join(parsl.dfk().run_dir, "checkpoint")] result = uuid_app().result() - with parsl_configured(tmpd_cwd, checkpoint_files=checkpoint_files): + with parsl_configured(tmpd_cwd, BasicMemoizer(checkpoint_files=checkpoint_files)): relaunched = uuid_app().result() assert result == relaunched, "Expected following call to uuid_app to return cached uuid" diff --git a/parsl/tests/test_checkpointing/test_python_checkpoint_2_sqlite.py b/parsl/tests/test_checkpointing/test_python_checkpoint_2_sqlite.py new file mode 100644 index 0000000000..0d220b915d --- /dev/null +++ b/parsl/tests/test_checkpointing/test_python_checkpoint_2_sqlite.py @@ -0,0 +1,35 @@ +import contextlib +import os + +import pytest + +import parsl +from parsl import python_app +from parsl.config import Config +from parsl.dataflow.memosql import SQLiteMemoizer + + +def parsl_configured(run_dir, memoizer): + return parsl.load(Config( + run_dir=str(run_dir), + memoizer=memoizer + )) + + +@python_app(cache=True) +def uuid_app(): + import uuid + return uuid.uuid4() + + +@pytest.mark.local +def test_loading_checkpoint(tmpd_cwd): + """Load memoization table from previous checkpoint + """ + with parsl_configured(tmpd_cwd, SQLiteMemoizer(checkpoint_dir=tmpd_cwd)): + result = uuid_app().result() + + with parsl_configured(tmpd_cwd, SQLiteMemoizer(checkpoint_dir=tmpd_cwd)): + relaunched = uuid_app().result() + + assert result == relaunched, "Expected following call to uuid_app to return cached uuid" diff --git a/parsl/tests/test_checkpointing/test_python_checkpoint_exceptions.py b/parsl/tests/test_checkpointing/test_python_checkpoint_exceptions.py new file mode 100644 index 0000000000..2fe9acfd17 --- /dev/null +++ b/parsl/tests/test_checkpointing/test_python_checkpoint_exceptions.py @@ -0,0 +1,49 @@ +import contextlib +import os + +import pytest + +import parsl +from parsl import python_app +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.executors.threads import 
ThreadPoolExecutor + + +class CheckpointExceptionsMemoizer(BasicMemoizer): + def filter_exception_for_checkpoint(self, ex): + # TODO: this used to be the case, but in moving to results-only mode, + # the task record is lost. Maybe it's useful to pass it in? What + # are the use cases for this deciding function? + # task record is available from app_fu.task_record + # assert app_fu.task_record is not None + + # override the default always-False, to be always-True + return True + + +def fresh_config(run_dir, memoizer): + return Config( + memoizer=memoizer, + run_dir=str(run_dir) + ) + + +@python_app(cache=True) +def uuid_app(): + import uuid + raise RuntimeError(str(uuid.uuid4())) + + +@pytest.mark.local +def test_loading_checkpoint(tmpd_cwd): + """Load memoization table from previous checkpoint + """ + with parsl.load(fresh_config(tmpd_cwd, CheckpointExceptionsMemoizer(checkpoint_mode="task_exit"))): + checkpoint_files = [os.path.join(parsl.dfk().run_dir, "checkpoint")] + result = uuid_app().exception() + + with parsl.load(fresh_config(tmpd_cwd, CheckpointExceptionsMemoizer(checkpoint_files=checkpoint_files))): + relaunched = uuid_app().exception() + + assert result.args == relaunched.args, "Expected following call to uuid_app to return cached uuid in exception" diff --git a/parsl/tests/test_checkpointing/test_regression_233.py b/parsl/tests/test_checkpointing/test_regression_233.py index 04ca52bdbd..2abd55abb8 100644 --- a/parsl/tests/test_checkpointing/test_regression_233.py +++ b/parsl/tests/test_checkpointing/test_regression_233.py @@ -1,12 +1,13 @@ import pytest from parsl.app.app import python_app +from parsl.config import Config from parsl.dataflow.dflow import DataFlowKernel +from parsl.dataflow.memoization import BasicMemoizer def run_checkpointed(checkpoints): - from parsl.tests.configs.local_threads_checkpoint_task_exit import config - config.checkpoint_files = checkpoints + config = Config(memoizer=BasicMemoizer(checkpoint_files=checkpoints, checkpoint_mode='task_exit')) dfk = DataFlowKernel(config=config) @python_app(data_flow_kernel=dfk, cache=True) diff --git a/parsl/tests/test_python_apps/test_memoize_2.py b/parsl/tests/test_python_apps/test_memoize_2.py index e884812277..514f2b23fd 100644 --- a/parsl/tests/test_python_apps/test_memoize_2.py +++ b/parsl/tests/test_python_apps/test_memoize_2.py @@ -5,6 +5,7 @@ import parsl from parsl.app.app import python_app from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer from parsl.executors.threads import ThreadPoolExecutor @@ -13,7 +14,7 @@ def local_config(): executors=[ ThreadPoolExecutor(max_threads=4), ], - app_cache=False + memoizer=BasicMemoizer(memoize=False) # TODO: this should be a better do-nothing impl? 
) diff --git a/parsl/tests/test_python_apps/test_memoize_exception.py b/parsl/tests/test_python_apps/test_memoize_exception.py index 86aac22f41..28a5368589 100644 --- a/parsl/tests/test_python_apps/test_memoize_exception.py +++ b/parsl/tests/test_python_apps/test_memoize_exception.py @@ -1,5 +1,17 @@ +import pytest + import parsl from parsl.app.app import python_app +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.executors.threads import ThreadPoolExecutor + + +def local_config(): + return Config( + executors=[ThreadPoolExecutor()], + memoizer=BasicMemoizer() + ) @python_app(cache=True) @@ -12,8 +24,9 @@ def raise_exception_nocache(x, cache=True): raise RuntimeError("exception from raise_exception_nocache") +@pytest.mark.local def test_python_memoization(n=2): - """Test Python memoization of exceptions, with cache=True""" + """Test BasicMemoizer memoization of exceptions, with cache=True""" x = raise_exception_cache(0) # wait for x to be done @@ -27,8 +40,9 @@ def test_python_memoization(n=2): assert fut.exception() is x.exception(), "Memoized exception should have been memoized" +@pytest.mark.local def test_python_no_memoization(n=2): - """Test Python non-memoization of exceptions, with cache=False""" + """Test BasicMemoizer non-memoization of exceptions, with cache=False""" x = raise_exception_nocache(0) # wait for x to be done diff --git a/parsl/tests/test_python_apps/test_memoize_plugin.py b/parsl/tests/test_python_apps/test_memoize_plugin.py new file mode 100644 index 0000000000..724facf165 --- /dev/null +++ b/parsl/tests/test_python_apps/test_memoize_plugin.py @@ -0,0 +1,53 @@ +import argparse + +import pytest + +import parsl +from parsl.app.app import python_app +from parsl.config import Config +from parsl.dataflow.memoization import BasicMemoizer +from parsl.dataflow.taskrecord import TaskRecord + + +class DontReuseSevenMemoizer(BasicMemoizer): + def check_memo(self, task_record: TaskRecord): + if task_record['args'][0] == 7: + return None # we didn't find a suitable memo record... + else: + return super().check_memo(task_record) + + +def local_config(): + return Config(memoizer=DontReuseSevenMemoizer()) + + +@python_app(cache=True) +def random_uuid(x, cache=True): + import uuid + return str(uuid.uuid4()) + + +@pytest.mark.local +def test_python_memoization(n=10): + """Testing python memoization disable + """ + + # TODO: this .result() needs to be here, not in the loop + # because otherwise we race to complete... and then + # we might sometimes get a memoization before the loop + # and sometimes not... + x = random_uuid(0).result() + + for i in range(0, n): + foo = random_uuid(0) + print(i) + print(foo.result()) + assert foo.result() == x, "Memoized results were incorrectly not used" + + y = random_uuid(7).result() + + for i in range(0, n): + foo = random_uuid(7) + print(i) + print(foo.result()) + assert foo.result() != y, "Memoized results were incorrectly used"
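
Usage sketch (illustrative only, not part of the patch): after this change, app caching and checkpointing are configured through a Memoizer instance passed to Config, rather than through the removed app_cache/checkpoint_* keyword arguments. A minimal before/after, assuming the BasicMemoizer defaults introduced above:

from parsl.config import Config
from parsl.dataflow.memoization import BasicMemoizer
from parsl.executors.threads import ThreadPoolExecutor

# Previously: Config(app_cache=True, checkpoint_mode='task_exit', ...)
# With this patch, the same behaviour is expressed as a memoizer plugin:
config = Config(
    executors=[ThreadPoolExecutor()],
    memoizer=BasicMemoizer(checkpoint_mode='task_exit'),
)

# Omitting the memoizer keeps the previous default behaviour: the DFK
# constructs a plain BasicMemoizer() (memoize=True, no checkpointing).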
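A second sketch of the plugin point itself (again illustrative, mirroring the CheckpointExceptionsMemoizer used in the new exception-checkpointing test): BasicMemoizer exposes filter_result_for_checkpoint and filter_exception_for_checkpoint as overridable hooks, so a subclass can opt in to checkpointing failed tasks, which the default (return False) does not do.

from parsl.config import Config
from parsl.dataflow.memoization import BasicMemoizer


class CheckpointEverythingMemoizer(BasicMemoizer):
    """Hypothetical subclass: checkpoint tasks that raised an exception
    as well as tasks that returned a result."""

    def filter_exception_for_checkpoint(self, exception: BaseException) -> bool:
        return True


config = Config(memoizer=CheckpointEverythingMemoizer(checkpoint_mode='task_exit'))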