Skip to content
This repository was archived by the owner on Jun 15, 2021. It is now read-only.

Commit bd9d575

Browse files
authored
PyTorch Support 🚀 🚀 🚀 (#16)
* PyTorch Support 🚀 🚀 🚀 * Fix CI
1 parent 8071973 commit bd9d575

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+835
-101
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
tests_results/
12

23
# Created by https://www.gitignore.io/api/python,pycharm
34
# Edit at https://www.gitignore.io/?templates=python,pycharm

.travis.yml

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
language: python
22

33
python:
4-
- 3.6
4+
- "3.6"
55

66
cache:
77
- pip
88

99
install:
1010
- pip install -qU pip setuptools
11-
- pip install -q pytest==3.6 pytest-cov coverage
11+
- pip install -q pytest==3.6 pytest-cov coverage matplotlib
1212
- pip install -q chainer networkx chainer_computational_cost
1313
- pip install -q chainercv
14+
- pip install -q https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl
15+
- pip install -q https://download.pytorch.org/whl/cpu/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl
1416
- pip install -q -e .
1517

1618
script:
17-
- pytest -vv tests
19+
- pytest -vv --cov=chainerpruner --cov-report=html --color=auto -s --basetemp tests_results tests

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# ChainerPruner
22

33
[![Build Status](https://travis-ci.org/DeNA/ChainerPruner.svg?branch=master)](https://travis-ci.org/DeNA/ChainerPruner)
4+
![License](https://img.shields.io/badge/license-MIT-brightgreen.svg)
45

56
ChainerPruner: Channel Pruning framework for [Chainer](https://github.com/chainer/chainer)
67

chainerpruner/__init__.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,33 @@
55
__author_email__ = '[email protected]'
66
__url__ = 'https://github.com/DeNA/ChainerPruner'
77

8+
try:
9+
import chainer
10+
11+
avalable_chainer = True
12+
except ImportError:
13+
avalable_chainer = False
14+
15+
try:
16+
import torch
17+
18+
avalable_pytorch = True
19+
except ImportError:
20+
avalable_pytorch = False
21+
22+
import os
23+
from chainerpruner import utils
824
from chainerpruner import masks
925
from chainerpruner import pruning
1026
from chainerpruner import rebuild
1127
from chainerpruner import serializers
12-
from chainerpruner import utils
13-
from chainerpruner import trace
1428

15-
from chainerpruner.graph import Graph
16-
from chainerpruner.node import Node
1729
from chainerpruner.pruner import Pruner
30+
from chainerpruner.graph import Graph
31+
32+
disable_patch = os.getenv('CHAINERPRUNER_DISABLE_PYTORCH_LOAD_PATCH', False)
33+
34+
if avalable_pytorch and not disable_patch:
35+
from chainerpruner.serializers.pytorch import enable_custom_load_state_dict
36+
37+
enable_custom_load_state_dict()

chainerpruner/graph/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from chainerpruner.graph.graph import Graph

chainerpruner/graph.py chainerpruner/graph/chainer/graph.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import networkx as nx
66

77
import chainer
8-
from chainerpruner.trace import TraceLinkHook, TraceFunctionHook
9-
from chainerpruner.node import Node
8+
from chainerpruner.graph.chainer.trace import TraceLinkHook, TraceFunctionHook
9+
from chainerpruner.graph.node import Node
1010

1111

12-
class Graph():
12+
class ChainerGraph():
1313
"""Computation Graph Parser
1414
1515
Chainerの計算グラフはLinkとFunctionから構成される

chainerpruner/trace.py chainerpruner/graph/chainer/trace.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import weakref
66
import chainer
77

8-
from chainerpruner.node import Node
8+
from chainerpruner.graph.node import Node
99

1010

1111
class TraceFunctionHook(chainer.FunctionHook):

chainerpruner/graph/graph.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) 2018 DeNA Co., Ltd.
2+
# Licensed under The MIT License [see LICENSE for details]
3+
4+
from chainerpruner import utils
5+
6+
7+
class Graph():
8+
"""Computation Graph Parser
9+
10+
Chainerの計算グラフはLinkとFunctionから構成される
11+
Channel Pruningする上では、現在の層をPruningしたときに
12+
後続する層のどこまでその影響があるかを知ることが重要になる
13+
ここではモデルをパースして、linksに各層とその影響する層の情報をList[Node]として格納している
14+
"""
15+
16+
def __init__(self, model, args):
17+
if utils.is_chainer_model(model):
18+
from chainerpruner.graph.chainer.graph import ChainerGraph
19+
graph = ChainerGraph(model, args)
20+
self.is_chainer = True
21+
self.is_pytorch = False
22+
else:
23+
from chainerpruner.graph.pytorch.graph import PyTorchGraph
24+
graph = PyTorchGraph(model, args)
25+
self.is_chainer = False
26+
self.is_pytorch = True
27+
28+
self.graph = graph.graph
29+
30+
def plot(self, options=None):
31+
import networkx as nx
32+
if not options:
33+
options = dict()
34+
nx.draw(self.graph, with_label=True, **options)

chainerpruner/node.py chainerpruner/graph/node.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
# Copyright (c) 2018 DeNA Co., Ltd.
22
# Licensed under The MIT License [see LICENSE for details]
33

4-
from collections import Sequence
5-
6-
import chainer
7-
84

95
class Node():
106
"""計算グラフのノード。channel pruningに必要な情報を抽出している。
117
128
入出力のテンソルのサイズや、ノードの型、chainerのLinkやFuntionのオブジェクト自体もアクセスできる
139
"""
1410

15-
def __init__(self, id_, type_, args, out, link=None, function=None):
11+
def __init__(self, id_, type_, args=None, out=None, link=None, function=None, module=None):
1612
"""
1713
1814
nextは、List[Node or List[Node]]
@@ -26,10 +22,10 @@ def __init__(self, id_, type_, args, out, link=None, function=None):
2622
function:
2723
"""
2824

29-
if not all([isinstance(a, chainer.variable.VariableNode) for a in args]):
30-
raise TypeError()
31-
if not all([isinstance(o, chainer.variable.VariableNode) for o in out]):
32-
raise TypeError()
25+
if not args:
26+
args = []
27+
if not out:
28+
out = []
3329

3430
self.name = id_
3531
self.id = id_
@@ -40,6 +36,7 @@ def __init__(self, id_, type_, args, out, link=None, function=None):
4036
self.output_shape = [o.shape if hasattr(o, 'shape') else None for o in out]
4137
self.link = link
4238
self.function = function
39+
# self.module = module
4340

4441
def __repr__(self):
4542
if self.link:

chainerpruner/graph/pytorch/graph.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2018 DeNA Co., Ltd.
2+
# Licensed under The MIT License [see LICENSE for details]
3+
4+
from typing import Sequence
5+
import networkx as nx
6+
7+
from chainerpruner.graph.pytorch.trace import ModuleTraceHook
8+
from chainerpruner.graph.node import Node
9+
10+
11+
class PyTorchGraph():
12+
"""Computation Graph Parser
13+
"""
14+
15+
def __init__(self, model, args):
16+
# convert to tuple of Variable
17+
if not isinstance(args, Sequence):
18+
args = [args]
19+
20+
# 計算グラフを構築しながら、hookを利用して各Moduleの情報を取得する
21+
hook = ModuleTraceHook(model)
22+
model(*args)
23+
24+
nodes = hook.graph # type: Sequence[Node]
25+
self.graph = self._traverse_connections(nodes)
26+
27+
def _traverse_connections(self, nodes):
28+
# Hookで得たlinkとfunctionのList[Node]をマージしてグラフにする
29+
30+
# node.nameをキーにしたGraph
31+
# node attrでnodeの実体へアクセス
32+
graph = nx.DiGraph()
33+
34+
for node in nodes:
35+
graph.add_node(node)
36+
for next_node in nodes:
37+
if set(node.output_id) & set(next_node.input_id):
38+
graph.add_edge(node, next_node)
39+
40+
return graph

chainerpruner/graph/pytorch/trace.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) 2018 DeNA Co., Ltd.
2+
# Licensed under The MIT License [see LICENSE for details]
3+
4+
from typing import Sequence
5+
6+
from chainerpruner.graph.node import Node
7+
8+
9+
class ModuleTraceHook():
10+
"""modelをトレースしてpruningに必要な計算グラフの情報を取得する
11+
12+
トレース後のグラフは`graph`属性として得られる
13+
14+
"""
15+
16+
def __init__(self, model):
17+
self.graph = []
18+
# グラフ内各moduleをuniqueに特定するための名前
19+
self._module_global_names = dict()
20+
self._regist(model)
21+
22+
def _regist(self, model):
23+
for name, module in model.named_modules():
24+
# nn.Sequentialやユーザー定義の複数のmoduleをもつmoduleではなく、
25+
# Conv2dなどの単一のmoduleのみ解析したいため
26+
n_modules = len(list(module.modules()))
27+
if n_modules == 1:
28+
# print('regist {}'.format(name))
29+
self._module_global_names[module] = name
30+
module.register_forward_hook(self._hook)
31+
32+
def _hook(self, module, input, output):
33+
"""
34+
35+
Args:
36+
module:
37+
input: tuple(Tensor)
38+
output: Tensor
39+
40+
Returns:
41+
42+
"""
43+
if not isinstance(input, Sequence):
44+
input = [input]
45+
if not isinstance(output, Sequence):
46+
output = [output]
47+
48+
name = self._module_global_names[module]
49+
50+
node = Node(id_=name,
51+
type_=type(module),
52+
args=None,
53+
out=None,
54+
# module=module,
55+
link=module,
56+
)
57+
58+
# TODO(tkat0) mv
59+
node.input_id = [i._cdata for i in input]
60+
node.input_shape = [tuple(i.shape) if hasattr(i, 'shape') else None for i in input]
61+
node.output_id = [o._cdata for o in output if o is not None]
62+
node.output_shape = [tuple(o.shape) if hasattr(o, 'shape') else None for o in output]
63+
64+
if set(node.input_id) == set(node.output_id):
65+
pass # skip in-place op (like ReLU)
66+
else:
67+
self.graph.append(node)

chainerpruner/mask.py

+47-6
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
import logging
55
import numpy as np
66

7-
from chainer import links as L
8-
97
from chainerpruner.rebuild import calc_pruning_connection
8+
from chainerpruner.utils import named_modules
109

1110
logger = logging.getLogger(__name__)
1211

@@ -17,10 +16,11 @@ class Mask():
1716

1817
def __init__(self, model, graph, target_layers, mask_layer=None):
1918
self.model = model
20-
self._model_dict = {name: link for name, link in model.namedlinks()}
19+
self._model_dict = {name: link for name, link in named_modules(model)}
2120
self.graph = graph
2221
self.target_layers = target_layers
2322
self.logger = logger
23+
self._is_chainer = graph.is_chainer
2424
self.pruning_connection_info = calc_pruning_connection(graph)
2525
self.masks = dict()
2626
self._mask_layer = mask_layer
@@ -31,6 +31,9 @@ def __init__(self, model, graph, target_layers, mask_layer=None):
3131
if mask_layer not in cand_mask_layer:
3232
raise AttributeError('mask_layer is expected which {}'.format(cand_mask_layer))
3333

34+
def is_chainer(self):
35+
return self._is_chainer
36+
3437
def get_filter_norm(self, mask):
3538
"""get mask for pruning
3639
@@ -58,11 +61,44 @@ def get_thresholds(self, name, mask):
5861
raise NotImplementedError()
5962

6063
def _get_mask(self, name):
64+
"""
65+
66+
Args:
67+
name:
68+
69+
Returns:
70+
(NDArray, ndarray): (conv-weight, mask-tensor)
71+
conv-weight: (oc, ic, k, k) kernel order
72+
mask-tensor: (oc, ic, k, k) or (oc, 1, 1, 1)
73+
74+
"""
75+
if self.is_chainer():
76+
return self._get_mask_chainer(name)
77+
else:
78+
return self._get_mask_pytorch(name)
79+
80+
def _get_mask_pytorch(self, name):
81+
from torch import nn
82+
conv = self._model_dict[name]
83+
if self._mask_layer is None:
84+
mask = conv.weight.data.clone()
85+
elif self._mask_layer == 'batchnorm':
86+
# propagate mask bn: conv-bn
87+
post_conv_bn_name = self.pruning_connection_info[name][0]
88+
bn = self._model_dict[post_conv_bn_name]
89+
if not isinstance(bn, nn.BatchNorm2d):
90+
raise ValueError('expected {}(Conv) -> {}(BatchNorm)'.format(name, post_conv_bn_name))
91+
mask = bn.weight.data.clone()
92+
mask = mask.reshape(-1, 1, 1, 1) # to mask conv weight (oc, ic, kh, kw)
93+
return conv.weight.data, mask
94+
95+
def _get_mask_chainer(self, name):
96+
from chainer import links as L
6197
conv = self._model_dict[name]
6298
if self._mask_layer is None:
6399
mask = conv.W.array.copy()
64100
elif self._mask_layer == 'batchnorm':
65-
# conv-bn
101+
# propagate mask bn: conv-bn
66102
post_conv_bn_name = self.pruning_connection_info[name][0]
67103
bn = self._model_dict[post_conv_bn_name]
68104
if not isinstance(bn, L.BatchNormalization):
@@ -80,7 +116,8 @@ def __call__(self):
80116

81117
# get mask vector
82118
target_weights = []
83-
for name, link in self.model.namedlinks(skipself=True):
119+
options = {'skipself': True} if self.is_chainer() else dict()
120+
for name, link in named_modules(self.model, **options):
84121

85122
self.logger.debug('name: %s', name)
86123

@@ -93,6 +130,7 @@ def __call__(self):
93130
out_channels = mask.shape[0]
94131
mask = self.get_filter_norm(mask)
95132
if mask.shape != (out_channels, 1, 1, 1):
133+
# expected (oc, ic, k, k) kernel order
96134
raise RuntimeError()
97135

98136
self.masks[name] = mask
@@ -109,7 +147,10 @@ def __call__(self):
109147

110148
# apply mask
111149
mask = mask_ >= threshold # 0: pruning, 1: non-pruning
112-
mask = mask.astype(np.float32)
150+
try:
151+
mask = mask.astype(np.float32)
152+
except AttributeError:
153+
mask = mask.type_as(target_weight)
113154

114155
info_ = {
115156
'name': name,

0 commit comments

Comments
 (0)