Skip to content

Commit a40fd26

Browse files
committed
fix benchmark
1 parent 0472ae4 commit a40fd26

3 files changed

Lines changed: 54 additions & 10 deletions

File tree

deepvac/backbones/resnet.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def testFly(self):
270270

271271
def auditConfig():
272272
config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
273+
config.cast.ScriptCast = AttrDict()
273274
config.cast.ScriptCast.model_dir = "./gemfield_script.pt"
274275

275276
config.core.disable_git = True
@@ -279,7 +280,7 @@ def auditConfig():
279280
config.core.log_every = 100
280281
config.core.num_workers = 4
281282

282-
config.core.net = ResNet50() # ResNet50() / resnet50()
283+
config.core.net = ResNet50()
283284

284285
config.core.optimizer = optim.SGD(
285286
config.core.net.parameters(),
@@ -289,6 +290,7 @@ def auditConfig():
289290
nesterov=False
290291
)
291292
config.core.scheduler = optim.lr_scheduler.MultiStepLR(config.core.optimizer, [50, 70, 90], 0.1)
293+
config.core.criterion=torch.nn.CrossEntropyLoss()
292294

293295
config.core.shuffle = True
294296
config.core.batch_size = 1
@@ -325,6 +327,7 @@ def auditConfig():
325327

326328
config.core.val_dataset = FileLineDataset(config, fileline_path=sys.argv[5], sample_path_prefix=sys.argv[3])
327329
config.core.val_loader = torch.utils.data.DataLoader(config.core.val_dataset, batch_size=1, pin_memory=False)
330+
config.core.test_loader = ''
328331
train = ResNet50Train(config)
329332
train()
330333

@@ -333,18 +336,23 @@ def auditConfig():
333336
LOG.logE("Usage: python -m deepvac.backbones.resnet test <pretrained_model.pth> <your_test_img_input_dir>", exit=True)
334337

335338
config.core.model_path = sys.argv[2]
336-
config.cast.ScriptCast.model_dir = ""
339+
config.cast.ScriptCast.model_dir = "./script.pt"
337340
config.cast.ScriptCast.static_quantize_dir = "./static_quantize.pt"
338341
config.core.test_dataset = ResnetClsTestDataset(config, sample_path=sys.argv[3])
339342
config.core.test_loader = torch.utils.data.DataLoader(config.core.test_dataset, batch_size=1, pin_memory=False)
340343
test = ResNet50Test(config)
341344
input_tensor = torch.rand(1,3,640,640)
342-
test(input_tensor)
345+
test()
343346

344347
if op == 'benchmark':
345348
if(len(sys.argv) != 4):
346349
LOG.logE("Usage: python -m deepvac.backbones.resnet benchmark <pretrained_model.pth> <your_input_img.jpg>", exit=True)
347350

351+
config.core.model_reinterpret_cast = False
352+
config.core.cast_state_dict_strict = False
353+
# config.core.net_omit_keys = ['num_batches_tracked']
354+
# config.core.net_omit_keys_strict = False
355+
# config.core.network_audit_disabled=False
348356
config.core.model_path = sys.argv[2]
349357
config.core.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
350358
img_path = sys.argv[3]

deepvac/core/deepvac.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
except ImportError:
1919
LOG.logE("Deepvac has dependency on tensorboard, please install tensorboard first, e.g. [pip3 install tensorboard]", exit=True)
2020

21-
from ..utils import syszux_once, LOG, assertAndGetGitBranch, getPrintTime, AverageMeter
21+
from ..utils import syszux_once, LOG, assertAndGetGitBranch, getPrintTime, AverageMeter, anyFieldsInConfig
2222
from ..cast import export3rd
2323

2424
#deepvac implemented based on PyTorch Framework
@@ -108,15 +108,38 @@ def initStateDict(self):
108108

109109
def castStateDict(self, config):
110110
LOG.logI("model_reinterpret_cast set to True in config.py, Try to reinterpret cast the model")
111+
112+
if self.config.model_path_omit_keys:
113+
LOG.logI("You have set config.core.model_path_omit_keys: {}".format(self.config.model_path_omit_keys))
114+
for k in self.config.model_path_omit_keys:
115+
LOG.logI("remove key {} from config.core.model_path {}".format(k, self.config.model_path))
116+
config.state_dict.pop(k, None)
117+
111118
state_dict = collections.OrderedDict()
112119
keys = list(config.state_dict.keys())
113-
for idx, name in enumerate(config.net.state_dict()):
114-
if config.net.state_dict()[name].size() == config.state_dict[keys[idx]].size():
115-
LOG.logI("cast pretrained model [{}] => config.core.net [{}]".format(keys[idx], name))
116-
state_dict[name] = config.state_dict[keys[idx]]
120+
model_path_keys_len = len(keys)
121+
if len(config.net.state_dict() ) > model_path_keys_len:
122+
LOG.logW("config.core.net has more parameters than config.core.model_path({}), may has cast issues.".format(self.config.model_path))
123+
124+
real_idx = 0
125+
for _, name in enumerate(config.net.state_dict()):
126+
if real_idx >= model_path_keys_len:
127+
LOG.logI("There alreay has no corresponding parameter in {} for {}".format(self.config.model_path, name))
128+
continue
129+
130+
if anyFieldsInConfig(name, self.config.net_omit_keys, self.config.net_omit_keys_strict):
131+
LOG.logI('found key to omit in config.core.net: {}, continue...'.format(name))
117132
continue
118-
LOG.logE("cannot cast pretrained model [{}] => config.net [{}] due to parameter shape mismatch!".format(keys[idx], name))
133+
134+
if config.net.state_dict()[name].size() == config.state_dict[keys[real_idx]].size():
135+
LOG.logI("cast pretrained model [{}] => config.core.net [{}]".format(keys[real_idx], name))
136+
state_dict[name] = config.state_dict[keys[real_idx]]
137+
real_idx += 1
138+
continue
139+
140+
LOG.logE("cannot cast pretrained model [{}] => config.net [{}] due to parameter shape mismatch!".format(keys[real_idx], name))
119141
if config.cast_state_dict_strict is False:
142+
real_idx += 1
120143
continue
121144
LOG.logE("If you know above risk, set cast_state_dict_strict=False in config.py to omit this audit.", exit=True)
122145

deepvac/utils/user_config.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,17 @@ def addUserConfig(module_name, config_name, user_give=None, developer_give=None,
77
return user_give
88
if developer_give is not None:
99
return developer_give
10-
LOG.logE("value missing for configuration: {}.{} in config.py".format(module_name, config_name), exit=True)
10+
LOG.logE("value missing for configuration: {}.{} in config.py".format(module_name, config_name), exit=True)
11+
12+
def anyFieldsInConfig(name, c, c_strict=True):
13+
if not c:
14+
return False
15+
16+
if not isinstance(c, list):
17+
return False
18+
19+
fields = name.split('.') if c_strict is False else [name]
20+
for field in fields:
21+
if field in c:
22+
return True
23+
return False

0 commit comments

Comments
 (0)