diff --git a/train.py b/train.py index 9e1cf63..159dde9 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ import torch +import pathlib import numpy as np import argparse import time @@ -55,7 +56,8 @@ def main(): if args.aptonly: supports = None - + pathlib.Path(args.save).mkdir(parents=True, exist_ok=True) + print(f"Save directory: {args.save} has been created") engine = trainer(scaler, args.in_dim, args.seq_length, args.num_nodes, args.nhid, args.dropout, args.learning_rate, args.weight_decay, device, supports, args.gcn_bool, args.addaptadj,