From f1c4637eccf051851333f89ea0c43fb57bea88d8 Mon Sep 17 00:00:00 2001 From: Gemfield Date: Tue, 25 May 2021 15:02:49 +0800 Subject: [PATCH] fix benchmark issues --- deepvac/backbones/resnet.py | 10 +++++----- deepvac/core/deepvac.py | 5 ++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/deepvac/backbones/resnet.py b/deepvac/backbones/resnet.py index ff9161b..70960c0 100644 --- a/deepvac/backbones/resnet.py +++ b/deepvac/backbones/resnet.py @@ -267,6 +267,7 @@ def testFly(self): max_res = torch.max(softmaxs, dim=1) max_probability, max_index = max_res LOG.logI("path: {}, max_probability:{}, max_index:{}".format(path[0], max_probability.item(), max_index.item())) + self.config.sample = input_tensor def auditConfig(): config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -280,7 +281,7 @@ def auditConfig(): config.core.log_every = 100 config.core.num_workers = 4 - config.core.net = ResNet50() + config.core.net = resnet50(pretrained=True) config.core.optimizer = optim.SGD( config.core.net.parameters(), @@ -334,7 +335,7 @@ def auditConfig(): if op == 'test': if(len(sys.argv) != 4): LOG.logE("Usage: python -m deepvac.backbones.resnet test ", exit=True) - + #config.core.network_audit_disabled=True config.core.model_path = sys.argv[2] config.cast.ScriptCast.model_dir = "./script.pt" config.cast.ScriptCast.static_quantize_dir = "./static_quantize.pt" @@ -348,11 +349,10 @@ def auditConfig(): if(len(sys.argv) != 4): LOG.logE("Usage: python -m deepvac.backbones.resnet benchmark ", exit=True) - config.core.model_reinterpret_cast = False - config.core.cast_state_dict_strict = False + # config.core.model_reinterpret_cast = False + # config.core.cast_state_dict_strict = False # config.core.net_omit_keys = ['num_batches_tracked'] # config.core.net_omit_keys_strict = False - # config.core.network_audit_disabled=False config.core.model_path = sys.argv[2] config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') img_path = sys.argv[3] diff --git a/deepvac/core/deepvac.py b/deepvac/core/deepvac.py index 276e433..9020ee3 100644 --- a/deepvac/core/deepvac.py +++ b/deepvac/core/deepvac.py @@ -261,7 +261,10 @@ def process(self, input_tensor): LOG.logI("You did not provide input with config.core.sample...") LOG.logI("testFly() is your last chance, you must have already reimplemented testFly() in subclass {}, right?".format(self.name())) - return self.testFly() + x = self.testFly() + if self.config.sample is None: + LOG.logE("You must set self.config.sample in testFly() reimplementation in your subclass {}".format(self.name()), exit=True) + return x def __call__(self, input_tensor=None): self.auditConfig()