From 6ccc2cd2a00ac4bd3f148ed8e0fdc6692614cd53 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 25 May 2021 06:54:59 +0000
Subject: [PATCH] allow to load pretrained weights when the include_top
 variable is False

---
 efficientnet_pytorch/model.py | 2 +-
 efficientnet_pytorch/utils.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py
index ce850cd..5f4ee78 100755
--- a/efficientnet_pytorch/model.py
+++ b/efficientnet_pytorch/model.py
@@ -376,7 +376,7 @@ def from_pretrained(cls, model_name, weights_path=None, advprop=False,
         """
         model = cls.from_name(model_name, num_classes=num_classes, **override_params)
         load_pretrained_weights(model, model_name, weights_path=weights_path,
-                                load_fc=(num_classes == 1000), advprop=advprop)
+                                load_fc=(num_classes == 1000) and model._global_params.include_top, advprop=advprop)
         model._change_in_channels(in_channels)
         return model
 
diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py
index 826a627..c95317e 100755
--- a/efficientnet_pytorch/utils.py
+++ b/efficientnet_pytorch/utils.py
@@ -608,7 +608,7 @@ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True,
         state_dict.pop('_fc.weight')
         state_dict.pop('_fc.bias')
         ret = model.load_state_dict(state_dict, strict=False)
-        assert set(ret.missing_keys) == set(
+        assert not ret.missing_keys or set(ret.missing_keys) == set(
             ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
     assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
 