11from copy import deepcopy
22from typing import Optional , Dict , Tuple , Any
3+ from pathlib import Path
34
45import keepsake
56from pytorch_lightning .callbacks .base import Callback
@@ -56,15 +57,18 @@ def __init__(
5657 """
5758
5859 super ().__init__ ()
59- self .filepath = filepath
60+ self .filepath = Path ( filepath ). resolve ()
6061 self .params = params
6162 self .primary_metric = primary_metric
6263 self .period = period
6364 self .save_weights_only = save_weights_only
6465 self .last_global_step_saved = - 1
6566
def on_pretrain_routine_start(self, trainer, pl_module):
    """Create the keepsake experiment just before training starts.

    The experiment is rooted at the parent directory of the checkpoint
    file (``self.filepath`` is resolved to an absolute ``Path`` in
    ``__init__``), so keepsake records artifacts alongside the
    checkpoint rather than in the current working directory.

    Args:
        trainer: the Lightning trainer driving the run (unused here).
        pl_module: the LightningModule being trained (unused here).
    """
    self.experiment = keepsake.init(
        path=str(self.filepath.parent),
        params=self.params,
    )
6872
def on_epoch_end(self, trainer, pl_module):
    """Attempt a checkpoint save at every epoch boundary.

    All save-frequency logic (e.g. the ``period`` setting and the
    last-saved-step guard) lives in ``_save_model``; this hook only
    forwards the call.
    """
    self._save_model(trainer=trainer, pl_module=pl_module)
@@ -89,7 +93,7 @@ def _save_model(self, trainer, pl_module):
8993 return
9094
9195 if self .filepath != None :
92- trainer .save_checkpoint (self .filepath , self .save_weights_only )
96+ trainer .save_checkpoint (self .filepath . name , self .save_weights_only )
9397
9498 self .last_global_step_saved = global_step
9599
@@ -99,7 +103,7 @@ def _save_model(self, trainer, pl_module):
99103 metrics .update ({"global_step" : trainer .global_step })
100104
101105 self .experiment .checkpoint (
102- path = self .filepath ,
106+ path = self .filepath . name ,
103107 step = trainer .current_epoch ,
104108 metrics = metrics ,
105109 primary_metric = self .primary_metric ,
0 commit comments