diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py
old mode 100755
new mode 100644
index e89e97b5..a3450b33
--- a/efficientnet_pytorch/model.py
+++ b/efficientnet_pytorch/model.py
@@ -40,6 +40,7 @@ class MBConvBlock(nn.Module):
         block_args (namedtuple): BlockArgs, defined in utils.py.
         global_params (namedtuple): GlobalParam, defined in utils.py.
         image_size (tuple or list): [image_height, image_width].
+        decoder_mode (bool): Reverse the block (deconvolution) if true.
 
     References:
         [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
@@ -47,19 +48,20 @@ class MBConvBlock(nn.Module):
         [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
     """
 
-    def __init__(self, block_args, global_params, image_size=None):
+    def __init__(self, block_args, global_params, image_size=None, decoder_mode=False, decoder_output_image_size=None):
         super().__init__()
         self._block_args = block_args
         self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
         self._bn_eps = global_params.batch_norm_epsilon
         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
         self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
+        self.decoder_mode = decoder_mode
 
         # Expansion phase (Inverted Bottleneck)
         inp = self._block_args.input_filters  # number of input channels
         oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
         if self._block_args.expand_ratio != 1:
-            Conv2d = get_same_padding_conv2d(image_size=image_size)
+            Conv2d = get_same_padding_conv2d(image_size=image_size, transposed=self.decoder_mode)
             self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
             self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
             # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
@@ -67,23 +69,27 @@ def __init__(self, block_args, global_params, image_size=None):
         # Depthwise convolution phase
         k = self._block_args.kernel_size
         s = self._block_args.stride
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        if self.decoder_mode:
+            # assert decoder_output_image_size
+            image_size = decoder_output_image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size, transposed=self.decoder_mode)
         self._depthwise_conv = Conv2d(
             in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
             kernel_size=k, stride=s, bias=False)
         self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
-        image_size = calculate_output_image_size(image_size, s)
+        if not self.decoder_mode:
+            image_size = calculate_output_image_size(image_size, s)
 
         # Squeeze and Excitation layer, if desired
         if self.has_se:
-            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+            Conv2d = get_same_padding_conv2d(image_size=(1, 1), transposed=self.decoder_mode)
             num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
             self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
             self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
 
         # Pointwise convolution phase
         final_oup = self._block_args.output_filters
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        Conv2d = get_same_padding_conv2d(image_size=image_size, transposed=self.decoder_mode)
         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
         self._swish = MemoryEfficientSwish()
@@ -152,9 +158,7 @@ class EfficientNet(nn.Module):
         [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
 
     Example:
-
-
-        import torch
+        >>> import torch
         >>> from efficientnet.model import EfficientNet
         >>> inputs = torch.rand(1, 3, 224, 224)
         >>> model = EfficientNet.from_pretrained('efficientnet-b0')
@@ -170,8 +174,8 @@ def __init__(self, blocks_args=None, global_params=None):
         self._blocks_args = blocks_args
 
         # Batch norm parameters
-        bn_mom = 1 - self._global_params.batch_norm_momentum
-        bn_eps = self._global_params.batch_norm_epsilon
+        self._bn_mom = bn_mom = 1 - self._global_params.batch_norm_momentum
+        self._bn_eps = bn_eps = self._global_params.batch_norm_epsilon
 
         # Get stem static or dynamic convolution depending on image size
         image_size = global_params.image_size
@@ -186,6 +190,7 @@ def __init__(self, blocks_args=None, global_params=None):
 
         # Build blocks
         self._blocks = nn.ModuleList([])
+        self._blocks_image_size = [image_size]
         for block_args in self._blocks_args:
 
             # Update block input and output filters based on depth multiplier.
@@ -198,6 +203,7 @@ def __init__(self, blocks_args=None, global_params=None):
             # The first block needs to take care of stride and filter size increase.
             self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
             image_size = calculate_output_image_size(image_size, block_args.stride)
+            self._blocks_image_size.append(image_size)
             if block_args.num_repeat > 1:  # modify block_args to keep same output size
                 block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
             for _ in range(block_args.num_repeat - 1):
@@ -217,6 +223,10 @@ def __init__(self, blocks_args=None, global_params=None):
             self._fc = nn.Linear(out_channels, self._global_params.num_classes)
         self._swish = MemoryEfficientSwish()
+        self._image_size = image_size
+        self._last_block_args = block_args
+        self._last_out_channels = out_channels
+
 
     def set_swish(self, memory_efficient=True):
         """Sets swish function as memory efficient (for training) or standard (for export).
 
@@ -239,6 +249,8 @@ def extract_endpoints(self, inputs):
             Dictionary of last intermediate features
             with reduction levels i in [1, 2, 3, 4, 5].
             Example:
+
+
                 >>> import torch
                 >>> from efficientnet.model import EfficientNet
                 >>> inputs = torch.rand(1, 3, 224, 224)
@@ -284,7 +296,6 @@ def extract_features(self, inputs):
         """
         # Stem
         x = self._swish(self._bn0(self._conv_stem(inputs)))
-
         # Blocks
         for idx, block in enumerate(self._blocks):
             drop_connect_rate = self._global_params.drop_connect_rate
@@ -294,7 +305,6 @@ def extract_features(self, inputs):
 
         # Head
         x = self._swish(self._bn1(self._conv_head(x)))
-
         return x
 
     def forward(self, inputs):
@@ -309,6 +319,7 @@ def forward(self, inputs):
         """
         # Convolution layers
         x = self.extract_features(inputs)
+
         # Pooling and final linear layer
         x = self._avg_pooling(x)
         if self._global_params.include_top:
@@ -413,3 +424,148 @@ def _change_in_channels(self, in_channels):
             Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
             out_channels = round_filters(32, self._global_params)
             self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+
+class EfficientNetAutoEncoder(EfficientNet):
+    """EfficientNet AutoEncoder model.
+       Most easily loaded with the .from_name or .from_pretrained methods.
+
+    Args:
+        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
+        global_params (namedtuple): A set of GlobalParams shared between blocks.
+
+    References:
+        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
+
+    Example:
+
+
+        >>> import torch
+        >>> from efficientnet.model import EfficientNetAutoEncoder
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> model = EfficientNetAutoEncoder.from_pretrained('efficientnet-b0')
+        >>> model.eval()
+        >>> ae_output, latent_fc_output = model(inputs)
+    """
+
+    def __init__(self, blocks_args=None, global_params=None):
+        super().__init__(blocks_args=blocks_args, global_params=global_params)
+        bn_mom = self._bn_mom
+        bn_eps = self._bn_eps
+        image_size = self._image_size
+        block_args = self._last_block_args
+
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._feature_downsample = Conv2d(self._last_out_channels, 8, kernel_size=1, bias=False)
+        self._downsample_bn = nn.BatchNorm2d(num_features=8, momentum=bn_mom, eps=bn_eps)
+        self._feature_upsample = Conv2d(8, self._last_out_channels, kernel_size=1, bias=False)
+        self._upsample_bn = nn.BatchNorm2d(num_features=self._last_out_channels, momentum=bn_mom, eps=bn_eps)
+        self.feature_size = 8 * image_size[0] * image_size[1]
+
+        # EfficientNet Decoder
+        # the decoder uses static transposed 'SAME' padding sized from the encoder stages
+        TransposedConv2d = get_same_padding_conv2d(image_size=image_size, transposed=True)
+
+        # Stem
+        # self._decoder_conv_stem mirrors self._conv_head
+        in_channels = round_filters(1280, self._global_params)
+        out_channels = block_args.output_filters  # output of final block
+        self._decoder_conv_stem = TransposedConv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        image_size = calculate_output_image_size(image_size, 1, transposed=True)
+        self._decoder_bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        # image_size = calculate_output_image_size(image_size, 2)
+
+        # Build blocks
+        self._decoder_blocks = nn.ModuleList([])
+        assert len(self._blocks_image_size) == len(self._blocks_args) + 1
+        self._blocks_image_size = list(reversed(self._blocks_image_size))
+        for i, block_args in enumerate(reversed(self._blocks_args)):
+            image_size = self._blocks_image_size[i]
+            # Update block input and output filters based on depth multiplier.
+            # input/output filters are flipped here to support deconvolution
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.output_filters, self._global_params),
+                output_filters=round_filters(block_args.input_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+            # The first block needs to take care of stride and filter size increase.
+            self._decoder_blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size,
+                                                    decoder_mode=True, decoder_output_image_size=self._blocks_image_size[i+1]))
+            image_size = self._blocks_image_size[i+1]
+            if block_args.num_repeat > 1:  # modify block_args to keep same output size
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._decoder_blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size,
+                                                        decoder_mode=True, decoder_output_image_size=image_size))
+                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1
+
+        # Head
+        in_channels = round_filters(32, self._global_params)  # number of stem output channels in the encoder
+        out_channels = 3  # rgb
+        TransposedConv2d = get_same_padding_conv2d(image_size=global_params.image_size, transposed=True)
+        self._decoder_conv_head = TransposedConv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._decoder_bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+    def extract_features(self, inputs):
+        """Uses the encoder's convolution layers to extract features,
+        then applies an additional 1x1 down-sampling convolution to produce the compact latent feature map.
+
+        Args:
+            inputs (tensor): Input tensor.
+
+        Returns:
+            Output of the down-sampling convolution applied to
+            the final convolution layer of the efficientnet encoder.
+        """
+        x = super().extract_features(inputs)
+        x = self._swish(self._downsample_bn(self._feature_downsample(x)))
+        return x
+
+
+    def decode_features(self, inputs):
+        """Decoder portion of this autoencoder.
+
+        Args:
+            inputs (tensor): Input tensor to the decoder,
+                usually from self.extract_features
+
+        Returns:
+            Output of the final transposed convolution
+            layer of the decoder (the reconstructed image tensor).
+        """
+        # upsample
+        x = self._swish(self._upsample_bn(self._feature_upsample(inputs)))
+        # Stem
+        x = self._swish(self._decoder_bn0(self._decoder_conv_stem(x)))
+        # Blocks
+        for idx, block in enumerate(self._decoder_blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                # scale drop connect_rate
+                drop_connect_rate *= float(idx) / len(self._decoder_blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+
+        # Head
+        x = self._swish(self._decoder_bn1(self._decoder_conv_head(x)))
+        return x
+
+
+    def forward(self, inputs):
+        """EfficientNet AutoEncoder's forward function.
+        Calls extract_features to extract features,
+        then calls decode_features to reconstruct the original inputs.
+
+        Args:
+            inputs (tensor): Input tensor.
+
+        Returns:
+            (AE output tensor, latent representation tensor)
+        """
+        # Convolution layers
+        x = self.extract_features(inputs)
+
+        # Flatten the bottleneck features into the latent representation
+        latent_rep = x.flatten(start_dim=1)
+
+        # Deconvolution - decoder
+        x = self.decode_features(x)
+        return x, latent_rep
\ No newline at end of file
diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py
old mode 100755
new mode 100644
index 6a843458..42bc6568
--- a/efficientnet_pytorch/utils.py
+++ b/efficientnet_pytorch/utils.py
@@ -167,7 +167,7 @@ def get_width_and_height_from_size(x):
         raise TypeError()
 
 
-def calculate_output_image_size(input_image_size, stride):
+def calculate_output_image_size(input_image_size, stride, transposed=False):
    """Calculates the output image size when using Conv2dSamePadding with a stride.
       Necessary for static padding. Thanks to mannatsingh for pointing this out.
 
@@ -182,8 +182,12 @@ def calculate_output_image_size(input_image_size, stride):
         return None
     image_height, image_width = get_width_and_height_from_size(input_image_size)
     stride = stride if isinstance(stride, int) else stride[0]
-    image_height = int(math.ceil(image_height / stride))
-    image_width = int(math.ceil(image_width / stride))
+    if transposed:
+        image_height = int(image_height * stride)
+        image_width = int(image_width * stride)
+    else:
+        image_height = int(math.ceil(image_height / stride))
+        image_width = int(math.ceil(image_width / stride))
     return [image_height, image_width]
 
 
@@ -192,16 +196,23 @@ def calculate_output_image_size(input_image_size, stride):
 # Only when stride equals 1, can the output size be the same as input size.
 # Don't be confused by their function names ! ! !
 
-def get_same_padding_conv2d(image_size=None):
+def get_same_padding_conv2d(image_size=None, transposed=False):
     """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
        Static padding is necessary for ONNX exporting of models.
 
     Args:
         image_size (int or tuple): Size of the image.
+        transposed (bool): return a transposed (ConvTranspose2d-based) layer if true, a regular Conv2d-based layer otherwise.
 
     Returns:
         Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
     """
+    if transposed:
+        if image_size is None:
+            raise NotImplementedError('Dynamic padding is not supported for transposed convolutions; please specify image_size.')
+        else:
+            return partial(TransposedConv2dStaticSamePadding, image_size=image_size)
+
     if image_size is None:
         return Conv2dDynamicSamePadding
     else:
@@ -271,6 +282,55 @@ def forward(self, x):
         x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x
 
+class TransposedConv2dStaticSamePadding(nn.ConvTranspose2d):
+    """2D transposed convolutions like TensorFlow's 'SAME' mode, for a fixed (static) image size.
+       The output is cropped in the forward function to match the requested output size.
+    """
+
+    # Tips for 'SAME' mode padding.
+    #     Given the following:
+    #         i: width or height
+    #         s: stride
+    #         k: kernel size
+    #         d: dilation
+    #         p: padding
+    #         op: output padding
+    #     Output after ConvTranspose2d:
+    #         (i-1)*s - 2*p + (k-1)*d + op + 1  (p = 0 here, since padding is fixed to 0)
+
+    def __init__(self, in_channels, out_channels, kernel_size, image_size, stride=1, output_padding=0, groups=1, bias=True, dilation=1):
+        super().__init__(in_channels, out_channels, kernel_size, stride, 0, output_padding, groups, bias, dilation)
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+        self.output_padding = output_padding
+        # NOTE: image_size here represents the desired output image_size
+        oh, ow = (image_size, image_size) if isinstance(image_size, int) else image_size
+        self._oh, self._ow = oh, ow
+        sh, sw = self.stride
+        ih, iw = math.ceil(oh / sh), math.ceil(ow / sw)  # using same calculation as in Conv2dStaticSamePadding
+        self._ih, self._iw = ih, iw
+        kh, kw = self.weight.size()[-2:]
+        # actual height/width after TransposedConv2d
+        actual_oh = (ih - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + self.output_padding + 1
+        actual_ow = (iw - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + self.output_padding + 1
+        crop_h = actual_oh - oh
+        crop_w = actual_ow - ow
+        assert crop_h >= 0 and crop_w >= 0
+        self._crop_h = crop_h
+        self._crop_w = crop_w
+        self._actual_oh = actual_oh
+        self._actual_ow = actual_ow
+
+    def forward(self, x):
+        # assert x.size()[-2:] == (self._ih, self._iw)
+        x = F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding,
+                               self.output_padding, self.groups, self.dilation)
+        # assert x.size()[-2:] == (self._actual_oh, self._actual_ow)
+        crop_h, crop_w = self._crop_h, self._crop_w
+        if crop_h > 0 or crop_w > 0:
+            x = x[:, :, crop_h // 2: x.size(2) - (crop_h - crop_h // 2), crop_w // 2: x.size(3) - (crop_w - crop_w // 2)]
+        # assert x.size()[-2:] == (self._oh, self._ow)
+        return x
+
 
 def get_same_padding_maxPool2d(image_size=None):
     """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
@@ -598,13 +658,23 @@
 
     if load_fc:
         ret = model.load_state_dict(state_dict, strict=False)
-        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
+
+        # weights for decoder are not loaded
+        # TODO: add initialization to missing layers
+        missing_keys = []
+        for key in ret.missing_keys:
+            if not key.startswith(('_decoder', '_feature', '_upsample', '_downsample')):
+                missing_keys.append(key)
+
+        assert not missing_keys, 'Missing keys when loading pretrained weights: {}'.format(
+            missing_keys)
     else:
         state_dict.pop('_fc.weight')
         state_dict.pop('_fc.bias')
         ret = model.load_state_dict(state_dict, strict=False)
         assert set(ret.missing_keys) == set(
             ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
-    assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
+    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(
+        ret.unexpected_keys)
 
     print('Loaded pretrained weights for {}'.format(model_name))
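
A minimal usage sketch of the new autoencoder, in the same doctest style as the class docstring above. The import path (efficientnet_pytorch.model, per the file paths in this patch), the no_grad wrapper, and the MSE reconstruction objective are assumptions for illustration; the patch itself defines no training loss.

    >>> import torch
    >>> import torch.nn.functional as F
    >>> from efficientnet_pytorch.model import EfficientNetAutoEncoder
    >>> inputs = torch.rand(1, 3, 224, 224)
    >>> model = EfficientNetAutoEncoder.from_pretrained('efficientnet-b0')  # decoder weights stay randomly initialized
    >>> model.eval()
    >>> with torch.no_grad():
    ...     reconstruction, latent = model(inputs)
    >>> # reconstruction has the input's shape (1, 3, 224, 224); latent is the flattened 8-channel
    >>> # bottleneck, i.e. model.feature_size values per example (8 * 7 * 7 = 392 for a 224x224 input)
    >>> loss = F.mse_loss(reconstruction, inputs)  # assumed reconstruction objective, not part of the patch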