diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a25817a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+logs/*
diff --git a/amc_search.py b/amc_search.py
index c252bfe..a80c78e 100755
--- a/amc_search.py
+++ b/amc_search.py
@@ -8,6 +8,7 @@ from copy import deepcopy

 import torch
 torch.backends.cudnn.deterministic = True
+import torchvision

 from env.channel_pruning_env import ChannelPruningEnv
 from lib.agent import DDPG
@@ -85,6 +86,8 @@ def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
     elif model == 'mobilenetv2' and dataset == 'imagenet':
         from models.mobilenet_v2 import MobileNetV2
         net = MobileNetV2(n_class=1000)
+    elif model == 'vgg16' and dataset == 'imagenet':
+        net = torchvision.models.vgg16(pretrained=True)
     else:
         raise NotImplementedError
     sd = torch.load(checkpoint_path)
diff --git a/env/channel_pruning_env.py b/env/channel_pruning_env.py
index c6e3abb..c16e52a 100755
--- a/env/channel_pruning_env.py
+++ b/env/channel_pruning_env.py
@@ -12,6 +12,7 @@
 import numpy as np
 import copy
+from lib.utils import least_square_sklearn

 class ChannelPruningEnv:
@@ -22,6 +23,7 @@ def __init__(self, model, checkpoint, data, preserve_ratio, args, n_data_worker=
                  batch_size=256, export_model=False, use_new_input=False):
         # default setting
         self.prunable_layer_types = [torch.nn.modules.conv.Conv2d, torch.nn.modules.linear.Linear]
+        self.prunable_layer_types = [torch.nn.modules.conv.Conv2d]  # conv layers only; remove this line to prune fc layers as well

         # save options
         self.model = model
@@ -225,14 +227,20 @@ def format_rank(x):
         mask[preserve_idx] = True

         # reconstruct, X, Y <= [N, C]
-        masked_X = X[:, mask]
+        if weight.shape[2] > 1 or weight.shape[3] > 1:  # conv layer with a spatial (KxK) kernel
+            k_size = X.shape[1] // weight.shape[1]  # K_h * K_w
+            XX = X.reshape((X.shape[0], -1, k_size))  # [N, C, K_h * K_w], channel-major
+            masked_X = XX[:, mask, :]
+            masked_X = masked_X.reshape((masked_X.shape[0], -1))
+        else:
+            masked_X = X[:, mask]
         if weight.shape[2] == 1:  # 1x1 conv or fc
-            from lib.utils import least_square_sklearn
             rec_weight = least_square_sklearn(X=masked_X, Y=Y)
-            rec_weight = rec_weight.reshape(-1, 1, 1, d_prime)  # (C_out, K_h, K_w, C_in')
-            rec_weight = np.transpose(rec_weight, (0, 3, 1, 2))  # (C_out, C_in', K_h, K_w)
+            rec_weight = rec_weight.reshape(-1, d_prime, 1, 1)  # (C_out, C_in', 1, 1)
         else:
-            raise NotImplementedError('Current code only supports 1x1 conv now!')
+            _, _, s1, s2 = weight.shape  # kernel height and width
+            rec_weight = least_square_sklearn(X=masked_X, Y=Y)
+            rec_weight = rec_weight.reshape(-1, d_prime, s1, s2)  # (C_out, C_in', K_h, K_w)
         if not self.export_model:  # pad, pseudo compress
             rec_weight_pad = np.zeros_like(weight)
             rec_weight_pad[:, mask, :, :] = rec_weight
@@ -252,6 +260,8 @@ def format_rank(x):
             m = m_list[idx]
             if type(m) == nn.Conv2d:  # depthwise
                 m.weight.data = torch.from_numpy(m.weight.data.cpu().numpy()[mask, :, :, :]).cuda()
+                if m.bias is not None:  # needed for nets like VGG, whose conv layers carry a bias
+                    m.bias.data = torch.from_numpy(m.bias.data.cpu().numpy()[mask]).cuda()
                 if m.groups == m.in_channels:
                     m.groups = int(np.sum(mask))
             elif type(m) == nn.BatchNorm2d:
@@ -435,21 +445,23 @@ def lambda_forward(x):
                 if len(f_in_np.shape) == 4:  # conv
                     if self.prunable_idx.index(idx) == 0:  # first conv
                         f_in2save, f_out2save = None, None
-                    elif m_list[idx].weight.size(3) > 1:  # normal conv
-                        f_in2save, f_out2save = f_in_np, f_out_np
-                    else:  # 1x1 conv
+                    else:  # all remaining convs, NxN and 1x1 alike
                         # assert f_out_np.shape[2] == f_in_np.shape[2]  # now support k=3
                         randx = np.random.randint(0, f_out_np.shape[2] - 0, self.n_points_per_layer)
                         randy = np.random.randint(0, f_out_np.shape[3] - 0, self.n_points_per_layer)
                         # input: [N, C, H, W]
                         self.layer_info_dict[idx][(i_b, 'randx')] = randx.copy()
                         self.layer_info_dict[idx][(i_b, 'randy')] = randy.copy()
-
-                        f_in2save = f_in_np[:, :, randx, randy].copy().transpose(0, 2, 1)\
-                            .reshape(self.batch_size * self.n_points_per_layer, -1)
-
-                        f_out2save = f_out_np[:, :, randx, randy].copy().transpose(0, 2, 1) \
-                            .reshape(self.batch_size * self.n_points_per_layer, -1)
+                        if m_list[idx].weight.size(3) > 1:  # NxN conv (N > 1)
+                            f_in2save = self.subfunc_generate_input_features(m_list, idx, i_b, f_in_np)
+                            ww = m_list[idx].weight.cpu().detach().numpy()
+                            ww = ww.reshape((ww.shape[0], -1)).T  # [C_in * K_h * K_w, C_out]
+                            f_out2save = f_in2save.dot(ww)  # conv output at the sampled points is just patch . W^T
+                        else:  # 1x1 conv
+                            f_in2save = f_in_np[:, :, randx, randy].copy().transpose(0, 2, 1)\
+                                .reshape(self.batch_size * self.n_points_per_layer, -1)
+                            f_out2save = f_out_np[:, :, randx, randy].copy().transpose(0, 2, 1) \
+                                .reshape(self.batch_size * self.n_points_per_layer, -1)
                 else:
                     assert len(f_in_np.shape) == 2
                     f_in2save = f_in_np.copy()
@@ -487,6 +499,8 @@ def _regenerate_input_feature(self):
                 if len(f_in_np.shape) == 4:  # conv
                     if self.prunable_idx.index(idx) == 0:  # first conv
                         f_in2save = None
+                    elif m_list[idx].weight.size(3) > 1:  # NxN conv (N > 1)
+                        f_in2save = self.subfunc_generate_input_features(m_list, idx, i_b, f_in_np)
                     else:
                         randx = self.layer_info_dict[idx][(i_b, 'randx')]
                         randy = self.layer_info_dict[idx][(i_b, 'randy')]
@@ -501,6 +515,21 @@ def _regenerate_input_feature(self):
                 self.layer_info_dict[idx]['input_feat'] = np.vstack(
                     (self.layer_info_dict[idx]['input_feat'], f_in2save))

+    def subfunc_generate_input_features(self, m_list, idx, i_b, f_in_np):
+        # Flatten the KxK input patches at the sampled (randx, randy) locations.
+        # Written for NxN convs with N > 1, though it works for 1x1 convs as well.
+        _, _, s1, s2 = m_list[idx].weight.shape  # kernel height and width
+        randx = self.layer_info_dict[idx][(i_b, 'randx')]
+        randy = self.layer_info_dict[idx][(i_b, 'randy')]
+        ph, pw = s1 // 2, s2 // 2  # zero padding along height and width
+        f_in_np_pad = np.pad(f_in_np, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
+        patches = []
+        for kk in range(self.n_points_per_layer):
+            j1, j2 = randx[kk], randy[kk]
+            x = f_in_np_pad[:, :, j1:j1 + s1, j2:j2 + s2]  # [N, C, K_h, K_w] patch
+            patches.append(x.reshape(x.shape[0], -1))  # flatten to [N, C * K_h * K_w]
+        return np.concatenate(patches, axis=0)
+
     def _build_state_embedding(self):
         # build the static part of the state embedding
         layer_embedding = []
diff --git a/scripts/search_vgg16_0.5flops.sh b/scripts/search_vgg16_0.5flops.sh
new file mode 100644
index 0000000..74f0524
--- /dev/null
+++ b/scripts/search_vgg16_0.5flops.sh
@@ -0,0 +1 @@
+python amc_search.py --job=train --model=vgg16 --ckpt_path=path_to_checkpoint --dataset=imagenet --data_root=path_to_data_root --preserve_ratio=0.5 --lbound=0.2 --rbound=1 --reward=acc_reward --n_calibration_batches=60 --seed=2018
\ No newline at end of file
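Note on the reconstruction step (the `@@ -225` hunk above): the patch extends AMC's least-squares channel reconstruction from 1x1 to KxK convolutions by flattening each sampled input patch to `[N, C * K * K]` in channel-major order, masking whole channel groups, and re-fitting the surviving weights against the original outputs. Below is a minimal self-contained sketch of that idea; the sizes are made up, and `np.linalg.lstsq` stands in for the repo's `least_square_sklearn`:

```python
import numpy as np

# Illustrative sizes, not taken from the repo.
c_out, c_in, k, n = 64, 32, 3, 512
rng = np.random.default_rng(0)

weight = rng.standard_normal((c_out, c_in, k, k))
X = rng.standard_normal((n, c_in * k * k))   # flattened patches, channel-major
Y = X @ weight.reshape(c_out, -1).T          # conv outputs at the sampled points

# Keep the first d_prime channels (stand-in for the learned channel mask).
d_prime = 16
mask = np.zeros(c_in, dtype=bool)
mask[:d_prime] = True

# Same masking as the patch: split columns into (channel, kernel) groups.
XX = X.reshape(n, c_in, k * k)
masked_X = XX[:, mask, :].reshape(n, -1)     # [N, d_prime * k * k]

# One least-squares solve recovers all C_out filters at once.
rec, *_ = np.linalg.lstsq(masked_X, Y, rcond=None)
rec_weight = rec.T.reshape(c_out, d_prime, k, k)   # (C_out, C_in', K_h, K_w)
print(rec_weight.shape)                            # (64, 16, 3, 3)
```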
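The shortcut `f_out2save = f_in2save.dot(ww)` in the `@@ -435` hunk relies on a basic identity: with stride 1 and zero padding of `K // 2` (exactly what `subfunc_generate_input_features` applies), the conv output at location `(i, j)` equals the flattened padded input patch at `(i, j)` dotted with the flattened filters. A quick check of that identity; the shapes and seed are arbitrary:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 8, 10, 10)    # [N, C_in, H, W]
w = torch.randn(4, 8, 3, 3)      # [C_out, C_in, K, K]
y = F.conv2d(x, w, padding=1)    # stride 1, zero padding K // 2

i, j = 5, 7                      # any sampled output location
patch = F.pad(x, (1, 1, 1, 1))[0, :, i:i + 3, j:j + 3].reshape(-1)
manual = w.reshape(4, -1) @ patch            # patch . W^T for all filters
print(torch.allclose(y[0, :, i, j], manual, atol=1e-5))  # True
```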