
Commit 57fcfc6

out-of-memory checkpointing
goal: results should not (never? or at least not in a weak, small cache?) be stored in an in-memory memo table, so no such table is present in this implementation; instead, every memoization query goes to the sqlite3 database. This blurs the line between in-memory caching and disk-based checkpointing: the previous disk-based checkpointing model relied on repopulating the in-memory memo table cache. I hit thread problems when using one sqlite3 connection across threads, and the docs are unclear about what I can and cannot do, so this implementation opens the sqlite3 database on every access. That probably carries quite a performance hit, but it should be enough for basic validation of the idea.
1 parent 241171b commit 57fcfc6
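The commit message notes that opening the sqlite3 database on every access probably costs performance. One common way to avoid that, while still keeping each connection on a single thread, is to cache one connection per thread with threading.local(). This is only a sketch of that idea, not part of this commit; the names _PerThreadConnections, _local and _get_connection are hypothetical.

import sqlite3
import threading


class _PerThreadConnections:
    """Sketch (hypothetical, not in this commit): reuse one sqlite3
    connection per thread instead of reconnecting on every memo access."""

    def __init__(self, db_path: str) -> None:
        self.db_path = db_path
        self._local = threading.local()

    def _get_connection(self) -> sqlite3.Connection:
        # Each thread lazily opens its own connection and then reuses it,
        # so no connection object is ever shared between threads.
        conn = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(self.db_path)
            self._local.connection = conn
        return conn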

File tree

4 files changed: +175 -8 lines changed


parsl/dataflow/memoization.py

+10-7
@@ -6,7 +6,7 @@
 import pickle
 import threading
 from functools import lru_cache, singledispatch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
 
 import typeguard
 
@@ -161,13 +161,13 @@ class Memoizer:
     def start(self, *, dfk: DataFlowKernel, memoize: bool = True, checkpoint_files: Sequence[str], run_dir: str) -> None:
         raise NotImplementedError
 
-    def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
+    def update_memo(self, task: TaskRecord, r: Future) -> None:
         raise NotImplementedError
 
     def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
         raise NotImplementedError
 
-    def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
+    def check_memo(self, task: TaskRecord) -> Optional[Future]:
         raise NotImplementedError
 
     def close(self) -> None:
@@ -239,7 +239,10 @@ def start(self, *, dfk: DataFlowKernel, memoize: bool = True, checkpoint_files:
             logger.info("App caching disabled for all apps")
         self.memo_lookup_table = {}
 
-    def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
+    def close(self) -> None:
+        pass  # nothing to close but more should move here
+
+    def check_memo(self, task: TaskRecord) -> Optional[Future]:
         """Create a hash of the task and its inputs and check the lookup table for this hash.
 
         If present, the results are returned.
@@ -274,7 +277,7 @@ def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
         assert isinstance(result, Future) or result is None
         return result
 
-    def hash_lookup(self, hashsum: str) -> Future[Any]:
+    def hash_lookup(self, hashsum: str) -> Future:
         """Lookup a hash in the memoization table.
 
         Args:
@@ -288,7 +291,7 @@ def hash_lookup(self, hashsum: str) -> Future[Any]:
         """
         return self.memo_lookup_table[hashsum]
 
-    def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
+    def update_memo(self, task: TaskRecord, r: Future) -> None:
         """Updates the memoization lookup table with the result from a task.
 
         Args:
@@ -313,7 +316,7 @@ def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
             logger.debug(f"Storing app cache entry {task['hashsum']} with result from task {task_id}")
         self.memo_lookup_table[task['hashsum']] = r
 
-    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future[Any]]:
+    def _load_checkpoints(self, checkpointDirs: Sequence[str]) -> Dict[str, Future]:
         """Load a checkpoint file into a lookup table.
 
         The data being loaded from the pickle file mostly contains input

parsl/dataflow/memosql.py

+118
@@ -0,0 +1,118 @@
+import logging
+import pickle
+import sqlite3
+from concurrent.futures import Future
+from pathlib import Path
+from typing import Optional, Sequence
+
+from parsl.dataflow.dflow import DataFlowKernel
+from parsl.dataflow.memoization import Memoizer, make_hash
+from parsl.dataflow.taskrecord import TaskRecord
+
+logger = logging.getLogger(__name__)
+
+
+class SQLiteMemoizer(Memoizer):
+    """Memoize out of memory into an sqlite3 database.
+
+    TODO: probably going to need some kind of shutdown now, to close
+    the sqlite3 connection.
+    which might also be useful for driving final checkpoints in the
+    original impl?
+    """
+
+    def start(self, *, dfk: DataFlowKernel, memoize: bool = True, checkpoint_files: Sequence[str], run_dir: str) -> None:
+        """TODO: run_dir is the per-workflow run dir, but we need a broader checkpoint context... one level up
+        by default... get_all_checkpoints uses "runinfo/" as a relative path for that by default so replicating
+        that choice would do here. likewise I think for monitoring."""
+
+        self.db_path = Path(dfk.config.run_dir) / "checkpoint.sqlite3"
+        logger.debug("starting with db_path %r", self.db_path)
+
+        # TODO: api wart... turning memoization on or off should not be part of the plugin API
+        self.memoize = memoize
+
+        connection = sqlite3.connect(self.db_path)
+        cursor = connection.cursor()
+
+        cursor.execute("CREATE TABLE IF NOT EXISTS checkpoints(key, result)")
+        # probably want some index on key because that's what we're doing all the access via.
+
+        connection.commit()
+        connection.close()
+        logger.debug("checkpoint table created")
+
+    def close(self):
+        pass
+
+    def checkpoint(self, tasks: Sequence[TaskRecord]) -> None:
+        """All the behaviour for this memoizer is in check_memo and update_memo.
+        """
+        logger.debug("Explicit checkpoint call is a no-op with this memoizer")
+
+    def check_memo(self, task: TaskRecord) -> Optional[Future]:
+        """TODO: document this: check_memo is required to set the task hashsum,
+        if that's how we're going to key checkpoints in update_memo. (that's not
+        a requirement though: other equalities are available."""
+        task_id = task['id']
+
+        if not self.memoize or not task['memoize']:
+            task['hashsum'] = None
+            logger.debug("Task %s will not be memoized", task_id)
+            return None
+
+        hashsum = make_hash(task)
+        logger.debug("Task {} has memoization hash {}".format(task_id, hashsum))
+        task['hashsum'] = hashsum
+
+        connection = sqlite3.connect(self.db_path)
+        cursor = connection.cursor()
+        cursor.execute("SELECT result FROM checkpoints WHERE key = ?", (hashsum, ))
+        r = cursor.fetchone()
+
+        if r is None:
+            connection.close()
+            return None
+        else:
+            data = pickle.loads(r[0])
+            connection.close()
+
+            memo_fu: Future = Future()
+
+            if data['exception'] is None:
+                memo_fu.set_result(data['result'])
+            else:
+                assert data['result'] is None
+                memo_fu.set_exception(data['exception'])
+
+            return memo_fu
+
+    def update_memo(self, task: TaskRecord, r: Future) -> None:
+        logger.debug("updating memo")
+
+        if not self.memoize or not task['memoize'] or 'hashsum' not in task:
+            logger.debug("preconditions for memo not satisfied")
+            return
+
+        if not isinstance(task['hashsum'], str):
+            logger.error(f"Attempting to update app cache entry but hashsum is not a string key: {task['hashsum']}")
+            return
+
+        app_fu = task['app_fu']
+        hashsum = task['hashsum']
+
+        # this comes from the original concatenation-based checkpoint code:
+        if app_fu.exception() is None:
+            t = {'hash': hashsum, 'exception': None, 'result': app_fu.result()}
+        else:
+            t = {'hash': hashsum, 'exception': app_fu.exception(), 'result': None}
+
+        value = pickle.dumps(t)
+
+        connection = sqlite3.connect(self.db_path)
+        cursor = connection.cursor()
+
+        cursor.execute("INSERT INTO checkpoints VALUES(?, ?)", (hashsum, value))
+
+        connection.commit()
+        connection.close()
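A comment in start() above notes that an index on key is probably wanted, since every lookup in check_memo filters on it. A minimal sketch of what that could look like next to the existing CREATE TABLE call follows; the index name idx_checkpoints_key is an assumption, not something in this commit.

# Hypothetical addition to SQLiteMemoizer.start(): index the lookup key so
# that "SELECT result FROM checkpoints WHERE key = ?" avoids a full table scan.
cursor.execute("CREATE TABLE IF NOT EXISTS checkpoints(key, result)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_checkpoints_key ON checkpoints(key)")
connection.commit()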

parsl/tests/configs/htex_local_alternate.py

+3-1
@@ -22,6 +22,7 @@
 from parsl.data_provider.ftp import FTPInTaskStaging
 from parsl.data_provider.http import HTTPInTaskStaging
 from parsl.data_provider.zip import ZipFileStaging
+from parsl.dataflow.memosql import SQLiteMemoizer
 from parsl.executors import HighThroughputExecutor
 from parsl.launchers import SingleNodeLauncher
 
@@ -64,7 +65,8 @@ def fresh_config():
             resource_monitoring_interval=1,
         ),
         usage_tracking=3,
-        project_name="parsl htex_local_alternate test configuration"
+        project_name="parsl htex_local_alternate test configuration",
+        memoizer=SQLiteMemoizer()
     )
 

@@ -0,0 +1,44 @@
+import contextlib
+import os
+
+import pytest
+
+import parsl
+from parsl import python_app
+from parsl.dataflow.memosql import SQLiteMemoizer
+from parsl.tests.configs.local_threads_checkpoint import fresh_config
+
+
+@contextlib.contextmanager
+def parsl_configured(run_dir, **kw):
+    c = fresh_config()
+    c.memoizer = SQLiteMemoizer()
+    c.run_dir = run_dir
+    for config_attr, config_val in kw.items():
+        setattr(c, config_attr, config_val)
+    dfk = parsl.load(c)
+    for ex in dfk.executors.values():
+        ex.working_dir = run_dir
+    yield dfk
+
+    parsl.dfk().cleanup()
+
+
+@python_app(cache=True)
+def uuid_app():
+    import uuid
+    return uuid.uuid4()
+
+
+@pytest.mark.local
+def test_loading_checkpoint(tmpd_cwd):
+    """Load memoization table from previous checkpoint
+    """
+    with parsl_configured(tmpd_cwd, checkpoint_mode="task_exit"):
+        checkpoint_files = [os.path.join(parsl.dfk().run_dir, "checkpoint")]
+        result = uuid_app().result()
+
+    with parsl_configured(tmpd_cwd, checkpoint_files=checkpoint_files):
+        relaunched = uuid_app().result()
+
+    assert result == relaunched, "Expected following call to uuid_app to return cached uuid"
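For hand-checking what this test leaves behind, the rows written by update_memo can be read back from checkpoint.sqlite3 with the standard sqlite3 module. A small sketch, assuming the database file sits at <run_dir>/checkpoint.sqlite3 as chosen in SQLiteMemoizer.start(); the helper name dump_checkpoints is hypothetical.

import pickle
import sqlite3


def dump_checkpoints(db_path: str) -> None:
    # Print every checkpoint entry: the task hash plus the pickled
    # result/exception dictionary stored by update_memo.
    connection = sqlite3.connect(db_path)
    cursor = connection.cursor()
    for key, blob in cursor.execute("SELECT key, result FROM checkpoints"):
        entry = pickle.loads(blob)
        print(key, entry['result'], entry['exception'])
    connection.close()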
