implement NxN conv (N>1) #25

Open · wants to merge 2 commits into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
*.pyc
logs/*
3 changes: 3 additions & 0 deletions amc_search.py
@@ -8,6 +8,7 @@
from copy import deepcopy
import torch
torch.backends.cudnn.deterministic = True
import torchvision

from env.channel_pruning_env import ChannelPruningEnv
from lib.agent import DDPG
@@ -85,6 +86,8 @@ def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
elif model == 'mobilenetv2' and dataset == 'imagenet':
from models.mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=1000)
elif model == 'vgg16' and dataset == 'imagenet':
net = torchvision.models.vgg16(pretrained=True)
else:
raise NotImplementedError
sd = torch.load(checkpoint_path)
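For context on why the next file needs an NxN path: torchvision's VGG-16, the model this elif wires in, contains only 3x3 convolutions, and each of them carries a bias term, so the existing 1x1-only reconstruction cannot prune it. A quick standalone check, illustrative and not part of the diff:

```python
# Illustrative sanity check (not part of the diff): torchvision's VGG-16 has
# 13 conv layers, every one with a 3x3 kernel and a bias term.
import torch
import torchvision

net = torchvision.models.vgg16(pretrained=False)  # pretrained=True also downloads the weights
convs = [m for m in net.modules() if isinstance(m, torch.nn.Conv2d)]
print(len(convs))                               # 13
print({m.kernel_size for m in convs})           # {(3, 3)}
print(all(m.bias is not None for m in convs))   # True
```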
66 changes: 52 additions & 14 deletions env/channel_pruning_env.py
@@ -12,6 +12,7 @@

import numpy as np
import copy
from lib.utils import least_square_sklearn


class ChannelPruningEnv:
@@ -22,6 +23,7 @@ def __init__(self, model, checkpoint, data, preserve_ratio, args, n_data_worker=
batch_size=256, export_model=False, use_new_input=False):
# default setting
self.prunable_layer_types = [torch.nn.modules.conv.Conv2d, torch.nn.modules.linear.Linear]
self.prunable_layer_types = [torch.nn.modules.conv.Conv2d] # comment out this line if you want to prune both conv layers and fc layers

# save options
self.model = model
@@ -225,14 +227,23 @@ def format_rank(x):
mask[preserve_idx] = True

# reconstruct, X, Y <= [N, C]
masked_X = X[:, mask]
if weight.shape[2] > 1 or weight.shape[3] > 1: # NxN conv with N > 1; 1x1 convs and fc layers take the else branch
k_size = int(X.shape[1] / weight.shape[1])
XX = X.reshape((X.shape[0],-1,k_size))
masked_X = XX[:, mask, :]
masked_X = masked_X.reshape((masked_X.shape[0],-1))
else:
masked_X = X[:, mask]
if weight.shape[2] == 1: # 1x1 conv or fc
from lib.utils import least_square_sklearn
rec_weight = least_square_sklearn(X=masked_X, Y=Y)
rec_weight = rec_weight.reshape(-1, 1, 1, d_prime) # (C_out, K_h, K_w, C_in')
rec_weight = np.transpose(rec_weight, (0, 3, 1, 2)) # (C_out, C_in', K_h, K_w)
# rec_weight = rec_weight.reshape(-1, 1, 1, d_prime) # (C_out, K_h, K_w, C_in')
# rec_weight = np.transpose(rec_weight, (0, 3, 1, 2)) # (C_out, C_in', K_h, K_w)
rec_weight = rec_weight.reshape(-1, d_prime, 1, 1)
else:
raise NotImplementedError('Current code only supports 1x1 conv now!')
# raise NotImplementedError('Current code only supports 1x1 conv now!')
_,_,s1,s2 = weight.shape
rec_weight = least_square_sklearn(X=masked_X, Y=Y)
rec_weight = rec_weight.reshape(-1, d_prime, s1, s2) # (C_out, C_in', K_h, K_w)
if not self.export_model: # pad, pseudo compress
rec_weight_pad = np.zeros_like(weight)
rec_weight_pad[:, mask, :, :] = rec_weight
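A minimal standalone sketch of the reconstruction the NxN branch above performs, under the assumption that least_square_sklearn wraps an ordinary least-squares fit (scikit-learn's LinearRegression stands in for it here); X, Y, mask and k are illustrative stand-ins for the sampled patches, sampled outputs, kept-channel mask and kernel size:

```python
# Sketch with assumed shapes: X is (n_samples, C * k * k) with patches laid out
# channel-major, Y is (n_samples, C_out), mask is a boolean vector over the C
# input channels with d_prime entries set to True.
import numpy as np
from sklearn.linear_model import LinearRegression

def reconstruct_kxk_weight(X, Y, mask, k):
    d_prime = int(mask.sum())
    XX = X.reshape(X.shape[0], -1, k * k)              # (n_samples, C, k*k)
    masked_X = XX[:, mask, :].reshape(X.shape[0], -1)  # (n_samples, d_prime * k*k)
    reg = LinearRegression(fit_intercept=False)        # assumption: plain least squares, no intercept
    reg.fit(masked_X, Y)
    rec_weight = reg.coef_                             # (C_out, d_prime * k*k)
    return rec_weight.reshape(-1, d_prime, k, k)       # (C_out, C_in', K_h, K_w), PyTorch conv layout
```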
Expand All @@ -252,6 +263,8 @@ def format_rank(x):
m = m_list[idx]
if type(m) == nn.Conv2d: # depthwise
m.weight.data = torch.from_numpy(m.weight.data.cpu().numpy()[mask, :, :, :]).cuda()
if m.bias is not None: # needed for networks like VGG whose conv layers include a bias term
m.bias.data = torch.from_numpy(m.bias.data.cpu().numpy()[mask]).cuda()
if m.groups == m.in_channels:
m.groups = int(np.sum(mask))
elif type(m) == nn.BatchNorm2d:
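The added bias handling matters because VGG-style convs carry a bias term; slicing it with the same channel mask as the weight keeps the pruned layer consistent. A small illustrative snippet, not taken from the PR:

```python
# Illustrative only: pruning a conv's output channels must slice its bias with
# the same mask as its weight, otherwise the shapes no longer match.
import torch

conv = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
keep = torch.zeros(128, dtype=torch.bool)
keep[:96] = True                                # hypothetical mask: keep the first 96 output channels
conv.weight.data = conv.weight.data[keep]       # (96, 64, 3, 3)
if conv.bias is not None:
    conv.bias.data = conv.bias.data[keep]       # (96,)
conv.out_channels = int(keep.sum())
```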
@@ -435,21 +448,23 @@ def lambda_forward(x):
if len(f_in_np.shape) == 4: # conv
if self.prunable_idx.index(idx) == 0: # first conv
f_in2save, f_out2save = None, None
elif m_list[idx].weight.size(3) > 1: # normal conv
f_in2save, f_out2save = f_in_np, f_out_np
else: # 1x1 conv
else:
# assert f_out_np.shape[2] == f_in_np.shape[2] # now support k=3
randx = np.random.randint(0, f_out_np.shape[2] - 0, self.n_points_per_layer)
randy = np.random.randint(0, f_out_np.shape[3] - 0, self.n_points_per_layer)
# input: [N, C, H, W]
self.layer_info_dict[idx][(i_b, 'randx')] = randx.copy()
self.layer_info_dict[idx][(i_b, 'randy')] = randy.copy()

f_in2save = f_in_np[:, :, randx, randy].copy().transpose(0, 2, 1)\
.reshape(self.batch_size * self.n_points_per_layer, -1)

f_out2save = f_out_np[:, :, randx, randy].copy().transpose(0, 2, 1) \
.reshape(self.batch_size * self.n_points_per_layer, -1)
if m_list[idx].weight.size(3) > 1: # NxN conv (N>1)
f_in2save = self.subfunc_generate_input_features(m_list, idx, i_b, f_in_np)
ww = m_list[idx].weight.cpu().detach().numpy()
ww = ww.reshape((ww.shape[0],-1)).T
f_out2save = f_in2save.dot(ww) # corresponding f_out can be computed by simple multiplication
else: # 1x1 conv
f_in2save = f_in_np[:, :, randx, randy].copy().transpose(0, 2, 1)\
.reshape(self.batch_size * self.n_points_per_layer, -1)
f_out2save = f_out_np[:, :, randx, randy].copy().transpose(0, 2, 1) \
.reshape(self.batch_size * self.n_points_per_layer, -1)
else:
assert len(f_in_np.shape) == 2
f_in2save = f_in_np.copy()
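The f_out2save = f_in2save.dot(ww) shortcut works because, for a stride-1, 'same'-padded conv (bias aside), the output at a sampled location equals the flattened k x k input patch multiplied by the flattened kernel. A small self-contained check of that identity with illustrative sizes:

```python
# Self-contained check (illustrative sizes): a k x k patch of the zero-padded
# input times the flattened kernel reproduces the conv output at that location
# for a stride-1, 'same'-padded conv without bias.
import numpy as np
import torch
import torch.nn.functional as F

N, C, H, W, C_out, k = 2, 4, 8, 8, 6, 3
inp = torch.randn(N, C, H, W)
w = torch.randn(C_out, C, k, k)
out = F.conv2d(inp, w, padding=k // 2)              # (N, C_out, H, W)

pad = F.pad(inp, (k // 2,) * 4).numpy()             # zero-pad H and W, as subfunc_generate_input_features does
ww = w.reshape(C_out, -1).numpy().T                 # (C * k * k, C_out), same flattening as the hunk's ww
x, y = 3, 5                                         # one sampled output location
patch = pad[:, :, x:x + k, y:y + k].reshape(N, -1)  # (N, C * k * k)
print(np.allclose(patch.dot(ww), out[:, :, x, y].numpy(), atol=1e-4))  # True
```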
@@ -487,6 +502,8 @@ def _regenerate_input_feature(self):
if len(f_in_np.shape) == 4: # conv
if self.prunable_idx.index(idx) == 0: # first conv
f_in2save = None
elif m_list[idx].weight.size(3) > 1: # NxN conv (N>1)
f_in2save = self.subfunc_generate_input_features(m_list, idx, i_b, f_in_np)
else:
randx = self.layer_info_dict[idx][(i_b, 'randx')]
randy = self.layer_info_dict[idx][(i_b, 'randy')]
@@ -501,6 +518,27 @@
self.layer_info_dict[idx]['input_feat'] = np.vstack(
(self.layer_info_dict[idx]['input_feat'], f_in2save))

def subfunc_generate_input_features(self, m_list, idx, i_b, f_in_np):
# generate input features for an NxN conv (N > 1); it should also work for 1x1 convs
_, _, s1, s2 = m_list[idx].weight.shape # kernel shape (K_h, K_w)
c2, c1, t1, t2 = f_in_np.shape # feature map shape (batch, C, H, W)
randx = self.layer_info_dict[idx][(i_b, 'randx')]
randy = self.layer_info_dict[idx][(i_b, 'randy')]
ph = int(s1 / 2) # padding along the height axis
pw = int(s2 / 2) # padding along the width axis
f_in_np_pad = np.pad(f_in_np, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
patches = []
for kk in range(self.n_points_per_layer):
j1 = randx[kk]
j2 = randy[kk]
x = f_in_np_pad[:, :, j1:j1 + s1, j2:j2 + s2] # k x k patch around the sampled output location
patches.append(x.reshape([x.shape[0], -1])) # flatten to (batch, C * K_h * K_w)
f_in2save = np.concatenate(patches, axis=0) # (n_points_per_layer * batch, C * K_h * K_w)
return f_in2save

def _build_state_embedding(self):
# build the static part of the state embedding
layer_embedding = []
1 change: 1 addition & 0 deletions scripts/search_vgg16_0.5flops.sh
@@ -0,0 +1 @@
python amc_search.py --job=train --model=vgg16 --ckpt_path=path_to_checkpoint --dataset=imagenet --data_root=path_to_data_root --preserve_ratio=0.5 --lbound=0.2 --rbound=1 --reward=acc_reward --n_calibration_batches=60 --seed=2018