from collections import namedtuple
import os
+ import sys
import cv2
import time
import math
from glob import glob
from PIL import Image
- from comet_ml import Experiment as CometExperiment
+ from comet_ml import Experiment as CometExperiment, OfflineExperiment
import torch
import torch.nn as nn
import torch.optim as optim

parser = argparse.ArgumentParser()
parser.add_argument("--mode", default="run", help="run, array, or job")
+ parser.add_argument(
+     '--time',
+     type=float,
+     default=10,
+     help='the number of hours',
+ )
parser.add_argument("--not_pretrain", action="store_true", default=False)
+ parser.add_argument('--local_comet_dir',
+                     type=str,
+                     default=None,
+                     help='local dir to process comet locally only. '
+                          'primarily for fb, will stop remote calls.')
parser.add_argument('--name',
                    type=str,
                    help='the identifying name of this experiment.',
@@ -334,31 +346,47 @@ def forward(self, x):
def main(args):
    print('Pretrain? ', not args.not_pretrain)
    print(args.model)
-
-     comet_exp = CometExperiment(api_key="hIXq6lDzWzz24zgKv7RYz6blo",
-                                 project_name="selfcifar",
-                                 workspace="cinjon",
-                                 auto_metric_logging=True,
-                                 auto_output_logging=None,
-                                 auto_param_logging=False)
+     start_time = time.time()
+
+     if args.local_comet_dir:
+         comet_exp = OfflineExperiment(
+             api_key="hIXq6lDzWzz24zgKv7RYz6blo",
+             project_name="selfcifar",
+             workspace="cinjon",
+             auto_metric_logging=True,
+             auto_output_logging=None,
+             auto_param_logging=False,
+             offline_directory=args.local_comet_dir)
+     else:
+         comet_exp = CometExperiment(
+             api_key="hIXq6lDzWzz24zgKv7RYz6blo",
+             project_name="selfcifar",
+             workspace="cinjon",
+             auto_metric_logging=True,
+             auto_output_logging=None,
+             auto_param_logging=False)
    comet_exp.log_parameters(vars(args))
    comet_exp.set_name(args.name)

    # Build model
-     path = "/misc/kcgscratch1/ChoGroup/resnick/spaceofmotion/zeping/bsn"
+     # path = "/misc/kcgscratch1/ChoGroup/resnick/spaceofmotion/zeping/bsn"
    if args.model == "amdim":
-         hparams = load_hparams_from_tags_csv(os.path.join(path, "meta_tags.csv"))
+         hparams = load_hparams_from_tags_csv('/checkpoint/cinjon/amdim/meta_tags.csv')
+         # hparams = load_hparams_from_tags_csv(os.path.join(path, "meta_tags.csv"))
        model = AMDIMModel(hparams)
        if not args.not_pretrain:
-             model.load_state_dict(
-                 torch.load(os.path.join(path, "_ckpt_epoch_434.ckpt"))["state_dict"])
+             # _path = os.path.join(path, "_ckpt_epoch_434.ckpt")
+             _path = '/checkpoint/cinjon/amdim/_ckpt_epoch_434.ckpt'
+             model.load_state_dict(torch.load(_path)["state_dict"])
        else:
            print("AMDIM not loading checkpoint")  # Debug
        linear_model = LinearModel(AMDIM_OUTPUT_DIM, args.num_classes)
    elif args.model == "ccc":
        model = CCCModel(None)
        if not args.not_pretrain:
-             checkpoint = torch.load(os.path.join(path, "TimeCycleCkpt14.pth"))
+             # _path = os.path.join(path, "TimeCycleCkpt14.pth")
+             _path = '/checkpoint/cinjon/spaceofmotion/bsn/TimeCycleCkpt14.pth'
+             checkpoint = torch.load(_path)
            base_dict = {
                '.'.join(k.split('.')[1:]): v
                for k, v in list(checkpoint['state_dict'].items())}
@@ -369,7 +397,9 @@ def main(args):
    elif args.model == "corrflow":
        model = CORRFLOWModel(None)
        if not args.not_pretrain:
-             checkpoint = torch.load(os.path.join(path, "corrflow.kineticsmodel.pth"))
+             _path = '/checkpoint/cinjon/spaceofmotion/supercons/corrflow.kineticsmodel.pth'
+             # _path = os.path.join(path, "corrflow.kineticsmodel.pth")
+             checkpoint = torch.load(_path)
            base_dict = {
                '.'.join(k.split('.')[1:]): v
                for k, v in list(checkpoint['state_dict'].items())}
@@ -433,8 +463,7 @@ def main(args):
    batch_size = args.batch_size * torch.cuda.device_count()
    # CIFAR-10
    if args.num_classes == 10:
-         data_path = ("/misc/kcgscratch1/ChoGroup/resnick/spaceofmotion/zeping/"
-                      "bsn/data/cifar-10-batches-py")
+         data_path = ("/private/home/cinjon/cifar-data/cifar-10-batches-py")
        _train_dataset = CIFAR_dataset(
            glob(os.path.join(data_path, "data*")),
            args.num_classes,
@@ -484,8 +513,7 @@ def main(args):
    #     val_dev_dataset, shuffle=False, batch_size=batch_size, num_workers=args.num_workers)
    # CIFAR-100
    elif args.num_classes == 100:
-         data_path = ("/misc/kcgscratch1/ChoGroup/resnick/spaceofmotion/zeping/"
-                      "bsn/data/cifar-100-python")
+         data_path = ("/private/home/cinjon/cifar-data/cifar-100-python")
        _train_dataset = CIFAR_dataset(
            [os.path.join(data_path, "train")],
            args.num_classes,
@@ -529,6 +557,10 @@ def main(args):
        train_acc = 0
        train_loss_sum = 0.0
        for iter, input in enumerate(train_dataloader):
+             if time.time() - start_time > args.time * 3600 - 10 and comet_exp is not None:
+                 comet_exp.end()
+                 sys.exit(-1)
+
            imgs = input[0].to(device)
            if args.model != "resnet":
                imgs = imgs.unsqueeze(1)
@@ -704,8 +736,10 @@ def main(args):
        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
-             save_path = os.path.join(log_dir, "{}.pth".format(epoch))
-             torch.save(linear_model.state_dict(), save_path)
+             linear_save_path = os.path.join(log_dir, "{}.linear.pth".format(epoch))
+             model_save_path = os.path.join(log_dir, "{}.model.pth".format(epoch))
+             torch.save(linear_model.state_dict(), linear_save_path)
+             torch.save(model.state_dict(), model_save_path)

        # Check bias and variance
        print("Epoch {} lr {} total: train_loss:{} train_acc:{} val_loss:{} val_acc:{}".format(