Skip to content

Commit 169ba3f

Browse files
committed
Added ability to set directory in pytorch lightning
Signed-off-by: Adam Fishman <[email protected]>
1 parent a8032b4 commit 169ba3f

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

python/keepsake/pl_callback.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import deepcopy
22
from typing import Optional, Dict, Tuple, Any
3+
from pathlib import Path
34

45
import keepsake
56
from pytorch_lightning.callbacks.base import Callback
@@ -56,15 +57,18 @@ def __init__(
5657
"""
5758

5859
super().__init__()
59-
self.filepath = filepath
60+
self.filepath = Path(filepath).resolve()
6061
self.params = params
6162
self.primary_metric = primary_metric
6263
self.period = period
6364
self.save_weights_only = save_weights_only
6465
self.last_global_step_saved = -1
6566

6667
def on_pretrain_routine_start(self, trainer, pl_module):
67-
self.experiment = keepsake.init(path=".", params=self.params)
68+
self.experiment = keepsake.init(
69+
path=str(self.filepath.parent),
70+
params=self.params,
71+
)
6872

6973
def on_epoch_end(self, trainer, pl_module):
7074
self._save_model(trainer, pl_module)
@@ -89,7 +93,7 @@ def _save_model(self, trainer, pl_module):
8993
return
9094

9195
if self.filepath != None:
92-
trainer.save_checkpoint(self.filepath, self.save_weights_only)
96+
trainer.save_checkpoint(self.filepath.name, self.save_weights_only)
9397

9498
self.last_global_step_saved = global_step
9599

@@ -99,7 +103,7 @@ def _save_model(self, trainer, pl_module):
99103
metrics.update({"global_step": trainer.global_step})
100104

101105
self.experiment.checkpoint(
102-
path=self.filepath,
106+
path=self.filepath.name,
103107
step=trainer.current_epoch,
104108
metrics=metrics,
105109
primary_metric=self.primary_metric,

0 commit comments

Comments
 (0)