@@ -270,6 +270,7 @@ def testFly(self):
270270
271271def auditConfig ():
272272 config .core .device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
273+ config .cast .ScriptCast = AttrDict ()
273274 config .cast .ScriptCast .model_dir = "./gemfield_script.pt"
274275
275276 config .core .disable_git = True
@@ -279,7 +280,7 @@ def auditConfig():
279280 config .core .log_every = 100
280281 config .core .num_workers = 4
281282
282- config .core .net = ResNet50 () # ResNet50() / resnet50()
283+ config .core .net = ResNet50 ()
283284
284285 config .core .optimizer = optim .SGD (
285286 config .core .net .parameters (),
@@ -289,6 +290,7 @@ def auditConfig():
289290 nesterov = False
290291 )
291292 config .core .scheduler = optim .lr_scheduler .MultiStepLR (config .core .optimizer , [50 , 70 , 90 ], 0.1 )
293+ config .core .criterion = torch .nn .CrossEntropyLoss ()
292294
293295 config .core .shuffle = True
294296 config .core .batch_size = 1
@@ -325,6 +327,7 @@ def auditConfig():
325327
326328 config .core .val_dataset = FileLineDataset (config , fileline_path = sys .argv [5 ], sample_path_prefix = sys .argv [3 ])
327329 config .core .val_loader = torch .utils .data .DataLoader (config .core .val_dataset , batch_size = 1 , pin_memory = False )
330+ config .core .test_loader = ''
328331 train = ResNet50Train (config )
329332 train ()
330333
@@ -333,18 +336,23 @@ def auditConfig():
333336 LOG .logE ("Usage: python -m deepvac.backbones.resnet test <pretrained_model.pth> <your_test_img_input_dir>" , exit = True )
334337
335338 config .core .model_path = sys .argv [2 ]
336- config .cast .ScriptCast .model_dir = ""
339+ config .cast .ScriptCast .model_dir = "./script.pt "
337340 config .cast .ScriptCast .static_quantize_dir = "./static_quantize.pt"
338341 config .core .test_dataset = ResnetClsTestDataset (config , sample_path = sys .argv [3 ])
339342 config .core .test_loader = torch .utils .data .DataLoader (config .core .test_dataset , batch_size = 1 , pin_memory = False )
340343 test = ResNet50Test (config )
341344 input_tensor = torch .rand (1 ,3 ,640 ,640 )
342- test (input_tensor )
345+ test ()
343346
344347 if op == 'benchmark' :
345348 if (len (sys .argv ) != 4 ):
346349 LOG .logE ("Usage: python -m deepvac.backbones.resnet benchmark <pretrained_model.pth> <your_input_img.jpg>" , exit = True )
347350
351+ config .core .model_reinterpret_cast = False
352+ config .core .cast_state_dict_strict = False
353+ # config.core.net_omit_keys = ['num_batches_tracked']
354+ # config.core.net_omit_keys_strict = False
355+ # config.core.network_audit_disabled=False
348356 config .core .model_path = sys .argv [2 ]
349357 config .core .device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
350358 img_path = sys .argv [3 ]
0 commit comments