add some example
FindDefinition committed Feb 2, 2023
1 parent 2309ebe commit 004effb
Showing 3 changed files with 353 additions and 129 deletions.
223 changes: 223 additions & 0 deletions example/mnist/custom_fx2trt.py
@@ -0,0 +1,223 @@
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to write custom fx2trt like tool to convert
pytorch model to tensorrt.
"""

from __future__ import print_function

import numpy as np
import tensorrt as trt
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import Tracer

from spconv.pytorch.quantization.interpreter import (NetworkInterpreter,
                                                     register_node_handler)

def _simple_repr(x):
    return f"Tensor[{x.shape}|{x.dtype}]"

# add a verbose repr for ITensor so interpreter logs are readable
trt.ITensor.__repr__ = _simple_repr

class NetDense(nn.Module):
def __init__(self):
super(NetDense, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.conv_pool = nn.Conv2d(64, 64, 2, 2)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
        x = self.conv_pool(x)  # strided 2x2 conv stands in for max_pool2d
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
if self.training:
x = F.log_softmax(x, dim=1)
return x

def _activation(net, x, act_type, alpha=None, beta=None, name=None):
layer = net.add_activation(x, act_type)
if alpha is not None:
layer.alpha = alpha
if beta is not None:
layer.beta = beta
output = layer.get_output(0)
if name is not None:
output.name = name
layer.name = name
return output

def _trt_reshape(net, inp, shape, name):
layer = net.add_shuffle(inp)
layer.reshape_dims = shape
output = layer.get_output(0)
layer.name = name
output.name = name
return output

# module handlers: convert torch.nn modules to TensorRT layers
@register_node_handler(nn.Conv2d)
def _conv2d(net, target: nn.Conv2d, args, kwargs, name: str):
x = args[0]
    if target.bias is None:
        bias = None
    else:
        bias = target.bias.detach().cpu().numpy()
weight = target.weight.detach().cpu().numpy()

O, I_groups, *ksize = weight.shape
I = I_groups * target.groups
stride = target.stride
padding = target.padding
dilation = target.dilation
weight_qdq = None
if not isinstance(weight, np.ndarray):
weight_qdq = weight
weight = trt.Weights()
else:
weight = trt.Weights(weight)
if bias is None:
bias = trt.Weights()
else:
bias = trt.Weights(bias)
layer = net.add_convolution_nd(x, O, tuple(ksize), weight, bias)
if weight_qdq is not None:
# in explicit quantization, we need this
layer.set_input(1, weight_qdq)
layer.stride_nd = tuple(stride)
layer.padding_nd = tuple(padding)
layer.dilation_nd = tuple(dilation)
layer.num_groups = target.groups
output = layer.get_output(0)
output.name = name
layer.name = name
return output

@register_node_handler(F.relu)
def _relu(net, target, args, kwargs, name: str):
return _activation(net, args[0], trt.ActivationType.RELU, name=name)


@register_node_handler(nn.Dropout)
@register_node_handler(nn.Dropout1d)
@register_node_handler(nn.Dropout2d)
@register_node_handler(nn.Dropout3d)
def _identity_single(net, target, args, kwargs, name: str):
return args[0]

@register_node_handler(torch.flatten)
def _flatten(net, target, args, kwargs, name: str):
start_dim = args[1]
x = args[0]
return _trt_reshape(net, x, [*x.shape[:start_dim], int(np.prod(x.shape[start_dim:]))], name)

def _dot(net, x, y, transpose_x=False, transpose_y=False, name=None):
mode_x = trt.MatrixOperation.NONE
if transpose_x:
mode_x = trt.MatrixOperation.TRANSPOSE
mode_y = trt.MatrixOperation.NONE
if transpose_y:
mode_y = trt.MatrixOperation.TRANSPOSE
layer = net.add_matrix_multiply(x, mode_x, y, mode_y)

output = layer.get_output(0)
assert name is not None

output.name = name
layer.name = name
return output

def _constant(net, array, name):
array = np.array(array)
layer = net.add_constant(array.shape, trt.Weights(array.reshape(-1)))
out = layer.get_output(0)
layer.name = name
out.name = name
return out

@register_node_handler(nn.Linear)
def _linear(net, target: nn.Linear, args, kwargs, name: str):
x = args[0]
    if target.bias is None:
        bias = None
    else:
        bias = target.bias.detach().cpu().numpy()
weight = target.weight.detach().cpu().numpy()
weight_trt = _constant(net, weight, name + "/weight")
res = _dot(net, x, weight_trt, transpose_y=True, name=name)
if bias is not None:
bias_trt = _constant(net, bias.reshape(1, -1), name + "/bias")
layer = net.add_elementwise(res, bias_trt, trt.ElementWiseOperation.SUM)
res = layer.get_output(0)
add_name = name + "/add"
res.name = add_name
layer.name = add_name
return res

def main():
model = NetDense()
model = model.eval()
tc = Tracer()
graph_trace = tc.trace(model)
gm = torch.fx.GraphModule(tc.root, graph_trace)
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# try:
# import pycuda.autoprimaryctx
# except ModuleNotFoundError:
# import pycuda.autoinit
with trt.Runtime(TRT_LOGGER) as rt:
with trt.Builder(TRT_LOGGER) as builder:
            flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
            with builder.create_network(flags) as network:
config = builder.create_builder_config()
                config.max_workspace_size = 1 << 30  # 1 GiB; newer TRT uses set_memory_pool_limit
input_tensor = network.add_input(name="inp", dtype=trt.float32, shape=[1, 1, 28, 28])
interp = NetworkInterpreter(network, gm, [input_tensor], verbose=True)
# get converted outputs from interp
outputs = interp.run()
network.mark_output(tensor=outputs[0])
plan = builder.build_serialized_network(network, config)
engine = rt.deserialize_cuda_engine(plan)

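# A usage sketch (hypothetical, not part of the example above): synchronous
# inference on the built engine via execute_v2, with torch-managed device
# buffers. Assumes a CUDA device and the fixed [1, 1, 28, 28] input shape
# declared in main(); the helper name is illustrative only.
def _run_engine_sketch(engine):
    with engine.create_execution_context() as ctx:
        inp = torch.randn(1, 1, 28, 28, device="cuda")
        out = torch.empty(1, 10, device="cuda")
        # bindings are raw device pointers, ordered [input, output]
        ctx.execute_v2(bindings=[inp.data_ptr(), out.data_ptr()])
        return out.argmax(dim=1)
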
if __name__ == '__main__':
main()
129 changes: 129 additions & 0 deletions spconv/pytorch/interpreter.py
@@ -0,0 +1,129 @@
from typing import Any, Dict, List, Optional

import torch
import torch.fx

REGISTERED_NODE_HANDLERS: Dict[Any, Any] = {}


def register_node_handler(*names):
    """Register a converter for module classes or function targets."""

    def wrap_func(handler):
        for n in names:
            REGISTERED_NODE_HANDLERS[n] = handler
        return handler

    return wrap_func


def register_method_handler(name: str, tensor_classes):
    """Register a converter for a tensor method, keyed by (tensor type, method name)."""
    if not isinstance(tensor_classes, (list, tuple)):
        tensor_classes = [tensor_classes]

    def wrap_func(handler):
        for tcls in tensor_classes:
            REGISTERED_NODE_HANDLERS[(tcls, name)] = handler
        return handler

    return wrap_func
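
# Usage sketch: handlers are keyed by a module class, a function object, or a
# (tensor type, method name) pair. The method handler below is illustrative
# only; nothing registers it in this file.
#
#     @register_node_handler(torch.flatten)
#     def _flatten(net, target, args, kwargs, name): ...
#
#     @register_method_handler("view", MyTensorClass)
#     def _view(net, target, args, kwargs, name): ...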


def get_node_handler(name):
    msg = "missing handler " + str(name)
    msg += ", available handlers: {}".format(
        list(REGISTERED_NODE_HANDLERS.keys()))
    assert name in REGISTERED_NODE_HANDLERS, msg
    return REGISTERED_NODE_HANDLERS[name]


class NetworkInterpreter(torch.fx.Interpreter):

def __init__(self,
network_ctx,
module: torch.fx.GraphModule,
inputs: List[Any],
verbose: bool = False):
super().__init__(module)
self.network_ctx = network_ctx
self._inputs = inputs
self._outputs = None
self._cur_node_name: Optional[str] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
self._verbose = verbose

def run(self):
super().run(*self._inputs)
assert self._outputs is not None
return self._outputs

def run_node(self, n):
self._cur_node_name = str(n)
return super().run_node(n)

def call_module(self, target, args, kwargs):
assert isinstance(target, str)
submod = self.fetch_attr(target)
submod_type = getattr(submod, "_base_class_origin", type(submod))
type_str = submod_type.__qualname__
type_str_parts = type_str.split(".")
msg = f"[Module.{type_str_parts[-1]}]{target}({args}|{kwargs}) => "

try:
converter = get_node_handler(submod_type)
res = converter(self.network_ctx, submod, args, kwargs,
self._cur_node_name)
msg += f"{res}"
if self._verbose:
print(msg)
return res
except Exception as e:
if self._verbose:
print(msg)
            raise

def call_function(self, target, args, kwargs):
msg = f"[Func]{target}({args}|{kwargs}) => "
try:
converter = get_node_handler(target)
res = converter(self.network_ctx, target, args, kwargs,
self._cur_node_name)
msg += f"{res}"
if self._verbose:
print(msg)
return res
except Exception as e:
if self._verbose:
print(msg)
            raise

def call_method(self, target, args, kwargs):
msg = f"[Method]{target}({args}|{kwargs}) => "
assert isinstance(target, str)
try:
key = (type(args[0]), target)
converter = get_node_handler(key)
res = converter(self.network_ctx, target, args, kwargs,
self._cur_node_name)
msg += f"{res}"
if self._verbose:
print(msg)
return res
except Exception as e:
if self._verbose:
print(msg)
            raise

    def output(self, target, args, kwargs):
        # args of the fx "output" node: a tuple holding the converted outputs
        self._outputs = args
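
# End-to-end usage (mirrors example/mnist/custom_fx2trt.py): trace a module
# with torch.fx, then run the interpreter against a backend network context
# such as a TensorRT INetworkDefinition:
#
#     tc = Tracer()
#     gm = torch.fx.GraphModule(tc.root, tc.trace(model))
#     interp = NetworkInterpreter(network, gm, [input_tensor], verbose=True)
#     outputs = interp.run()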