Skip to content

Commit

Permalink
fix benchmark issues
Browse files Browse the repository at this point in the history
  • Loading branch information
gemfield committed May 25, 2021
1 parent d152328 commit f1c4637
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 5 additions & 5 deletions deepvac/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def testFly(self):
max_res = torch.max(softmaxs, dim=1)
max_probability, max_index = max_res
LOG.logI("path: {}, max_probability:{}, max_index:{}".format(path[0], max_probability.item(), max_index.item()))
self.config.sample = input_tensor

def auditConfig():
config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand All @@ -280,7 +281,7 @@ def auditConfig():
config.core.log_every = 100
config.core.num_workers = 4

config.core.net = ResNet50()
config.core.net = resnet50(pretrained=True)

config.core.optimizer = optim.SGD(
config.core.net.parameters(),
Expand Down Expand Up @@ -334,7 +335,7 @@ def auditConfig():
if op == 'test':
if(len(sys.argv) != 4):
LOG.logE("Usage: python -m deepvac.backbones.resnet test <pretrained_model.pth> <your_test_img_input_dir>", exit=True)

#config.core.network_audit_disabled=True
config.core.model_path = sys.argv[2]
config.cast.ScriptCast.model_dir = "./script.pt"
config.cast.ScriptCast.static_quantize_dir = "./static_quantize.pt"
Expand All @@ -348,11 +349,10 @@ def auditConfig():
if(len(sys.argv) != 4):
LOG.logE("Usage: python -m deepvac.backbones.resnet benchmark <pretrained_model.pth> <your_input_img.jpg>", exit=True)

config.core.model_reinterpret_cast = False
config.core.cast_state_dict_strict = False
# config.core.model_reinterpret_cast = False
# config.core.cast_state_dict_strict = False
# config.core.net_omit_keys = ['num_batches_tracked']
# config.core.net_omit_keys_strict = False
# config.core.network_audit_disabled=False
config.core.model_path = sys.argv[2]
config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img_path = sys.argv[3]
Expand Down
5 changes: 4 additions & 1 deletion deepvac/core/deepvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,10 @@ def process(self, input_tensor):
LOG.logI("You did not provide input with config.core.sample...")

LOG.logI("testFly() is your last chance, you must have already reimplemented testFly() in subclass {}, right?".format(self.name()))
return self.testFly()
x = self.testFly()
if self.config.sample is None:
LOG.logE("You must set self.config.sample in testFly() reimplementation in your subclass {}".format(self.name()), exit=True)
return x

def __call__(self, input_tensor=None):
self.auditConfig()
Expand Down

0 comments on commit f1c4637

Please sign in to comment.