35
35
36
36
parser = argparse .ArgumentParser ()
37
37
parser .add_argument ("--mode" , default = "run" , help = "run, array, or job" )
38
+ parser .add_argument ("--do_nonlinear" , action = "store_true" , default = False )
38
39
parser .add_argument (
39
40
'--time' ,
40
41
type = float ,
@@ -339,6 +340,42 @@ def forward(self, x):
339
340
# x = self.fc2(x)
340
341
return x
341
342
343
+
344
class NonLinearModel(nn.Module):
    """Classification head: a stack of Linear+ReLU blocks, then a linear classifier.

    Each hidden block maps ``in_channels -> in_channels`` and applies a ReLU;
    the final ``nn.Linear`` maps ``in_channels -> num_classes``.

    Args:
        in_channels: Size of the flattened per-sample input feature vector.
        num_classes: Number of output logits.
        num_layers: Number of hidden Linear+ReLU blocks placed before the
            final classifier layer.
    """

    def __init__(self, in_channels, num_classes, num_layers=3):
        # Must name this class; super(LinearModel, self) raises TypeError
        # because NonLinearModel does not derive from LinearModel.
        super(NonLinearModel, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_layers = num_layers

        # nn.Sequential takes modules as positional arguments, so the list of
        # hidden blocks is unpacked rather than passed as a single list.
        self.nonlinear = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(self.in_channels, self.in_channels),
                nn.ReLU(),
            )
            for _ in range(num_layers)
        ])

        # Final classifier: small-variance normal weights, zero bias.
        self.linear = nn.Linear(self.in_channels, self.num_classes)
        self.linear.weight.data.normal_(0, 0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        """Flatten ``x`` per sample and return the ``(batch, num_classes)`` logits."""
        # Collapse everything after the batch dimension into one feature axis.
        x = x.view(x.size(0), -1)
        x = self.nonlinear(x)
        x = self.linear(x)
        return x
378
+
342
379
#######################################################
343
380
# Main
344
381
#######################################################
@@ -370,6 +407,8 @@ def main(args):
370
407
371
408
# Build model
372
409
# path = "/misc/kcgscratch1/ChoGroup/resnick/spaceofmotion/zeping/bsn"
410
+ linear_cls = NonLinearModel if args .do_nonlinear else LinearModel
411
+
373
412
if args .model == "amdim" :
374
413
hparams = load_hparams_from_tags_csv ('/checkpoint/cinjon/amdim/meta_tags.csv' )
375
414
# hparams = load_hparams_from_tags_csv(os.path.join(path, "meta_tags.csv"))
@@ -380,7 +419,7 @@ def main(args):
380
419
model .load_state_dict (torch .load (_path )["state_dict" ])
381
420
else :
382
421
print ("AMDIM not loading checkpoint" ) # Debug
383
- linear_model = LinearModel (AMDIM_OUTPUT_DIM , args .num_classes )
422
+ linear_model = linear_cls (AMDIM_OUTPUT_DIM , args .num_classes )
384
423
elif args .model == "ccc" :
385
424
model = CCCModel (None )
386
425
if not args .not_pretrain :
@@ -393,7 +432,7 @@ def main(args):
393
432
model .load_state_dict (base_dict )
394
433
else :
395
434
print ("CCC not loading checkpoint" ) # Debug
396
- linear_model = LinearModel (CCC_OUTPUT_DIM , args .num_classes ).to (device )
435
+ linear_model = linaer_cls (CCC_OUTPUT_DIM , args .num_classes ) # .to(device)
397
436
elif args .model == "corrflow" :
398
437
model = CORRFLOWModel (None )
399
438
if not args .not_pretrain :
@@ -406,7 +445,7 @@ def main(args):
406
445
model .load_state_dict (base_dict )
407
446
else :
408
447
print ("CorrFlow not loading checkpoing" ) # Debug
409
- linear_model = LinearModel (CORRFLOW_OUTPUT_DIM , args .num_classes )
448
+ linear_model = linear_cls (CORRFLOW_OUTPUT_DIM , args .num_classes )
410
449
elif args .model == "resnet" :
411
450
if not args .not_pretrain :
412
451
resnet = torchvision .models .resnet50 (pretrained = True )
@@ -415,7 +454,7 @@ def main(args):
415
454
print ("ResNet not loading checkpoint" ) # Debug
416
455
modules = list (resnet .children ())[:- 1 ]
417
456
model = nn .Sequential (* modules )
418
- linear_model = LinearModel (RESNET_OUTPUT_DIM , args .num_classes )
457
+ linear_model = linear_cls (RESNET_OUTPUT_DIM , args .num_classes )
419
458
else :
420
459
raise Exception ("model type has to be amdim, ccc, corrflow or resnet" )
421
460
@@ -454,8 +493,9 @@ def main(args):
454
493
455
494
# Set up log dir
456
495
now = datetime .datetime .now ()
457
- log_dir = "{}{:%Y%m%dT%H%M}" .format (args .model , now )
458
- log_dir = os .path .join ("weights" , log_dir )
496
+ log_dir = '/checkpoint/cinjon/spaceofmotion/bsn/cifar-%d-weights/%s/%s' % (args .num_classes , args .model , args .name )
497
+ # log_dir = "{}{:%Y%m%dT%H%M}".format(args.model, now)
498
+ # log_dir = os.path.join("weights", log_dir)
459
499
if not os .path .exists (log_dir ):
460
500
os .makedirs (log_dir )
461
501
print ("Saving to {}" .format (log_dir ))
@@ -557,7 +597,7 @@ def main(args):
557
597
train_acc = 0
558
598
train_loss_sum = 0.0
559
599
for iter , input in enumerate (train_dataloader ):
560
- if time .time () - start_time > args .time * 3600 - 10 and comet_exp is not None :
600
+ if time .time () - start_time > args .time * 3600 - 300 and comet_exp is not None :
561
601
comet_exp .end ()
562
602
sys .exit (- 1 )
563
603
@@ -702,6 +742,10 @@ def main(args):
702
742
val_acc = 0
703
743
val_loss_sum = 0.0
704
744
for iter , input in enumerate (val_dataloader ):
745
+ if time .time () - start_time > args .time * 3600 - 300 and comet_exp is not None :
746
+ comet_exp .end ()
747
+ sys .exit (- 1 )
748
+
705
749
imgs = input [0 ].to (device )
706
750
if args .model != "resnet" :
707
751
imgs = imgs .unsqueeze (1 )
0 commit comments