-
Notifications
You must be signed in to change notification settings - Fork 316
/
Copy pathtrt_dep.cu
308 lines (235 loc) · 10.1 KB
/
trt_dep.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
#include <iostream>
#include <string>
#include <fstream>
#include <vector>
#include <array>
#include <unordered_map>
#include <sstream>
#include <chrono>
#include <iterator>
#include "trt_dep.hpp"
#include "argmax_plugin.h"
#include "batch_stream.hpp"
#include "entropy_calibrator.hpp"
using nvinfer1::IHostMemory;
using nvinfer1::IBuilder;
using nvinfer1::INetworkDefinition;
using nvinfer1::ICudaEngine;
using nvinfer1::IInt8Calibrator;
using nvinfer1::IBuilderConfig;
using nvinfer1::IRuntime;
using nvinfer1::IExecutionContext;
using nvinfer1::ILogger;
using nvinfer1::Dims;
using nvinfer1::Dims4;
using nvinfer1::OptProfileSelector;
using Severity = nvinfer1::ILogger::Severity;
using std::string;
using std::ios;
using std::ofstream;
using std::ifstream;
using std::vector;
using std::cout;
using std::endl;
using std::array;
// Process-wide TensorRT logger; passed to the builder, the ONNX parser and the
// runtimes created in this file.
Logger gLogger;
// Abort the process with a diagnostic when `condition` is false.
//
// Used throughout this file as a hard assertion around TensorRT / CUDA calls
// whose failure cannot be recovered from. The message is written to stderr
// (flushed immediately and kept separate from normal output on stdout), then
// std::terminate() is called. Does nothing when `condition` is true.
void CHECK(bool condition, std::string msg) {
    if (!condition) {
        std::cerr << msg << std::endl;
        std::terminate();
    }
}
// Create the ArgMax plugin creator (owned by this object through
// `plugin_creator`) and register it, with an empty namespace, in the global
// TensorRT plugin registry. Terminates via CHECK if registration fails.
// Per the original author's note, this must be done before the ONNX model is
// parsed.
void SemanticSegmentTrt::register_plugins() {
    plugin_creator.reset(new ArgMaxPluginCreator{});
    plugin_creator->setPluginNamespace("");

    auto* registry = getPluginRegistry();
    const bool registered = registry->registerCreator(*plugin_creator, "");
    CHECK(registered, "failed to register plugin");
}
// Build a TensorRT engine from an ONNX model and store it in `engine`
// (together with the `runtime` used to deserialize it).
//
// Parameters:
//   onnx_pth  - path of the ONNX model file.
//   quant     - precision mode: "fp16", "bf16", "fp8" or "int8"; any other
//               value builds with the default precision. "int8" also enables
//               fp16 (when the platform supports it).
//   data_root - calibration data root, forwarded to BatchStream (int8 only).
//   data_file - calibration data list, forwarded to BatchStream (int8 only).
//
// The model must have exactly one input of rank 4 (N, C, H, W as named below)
// and at least one output. Only dimension 0 is dynamic: the optimization
// profile uses min = 1, opt = opt_bsize, max = 32. input_name / output_name are
// taken from the network's first input / output. Any failure terminates the
// process (CHECK / std::terminate).
void SemanticSegmentTrt::parse_to_engine(string onnx_pth,
        string quant, string data_root, string data_file) {
    auto builder = TrtUnqPtr<IBuilder>(nvinfer1::createInferBuilder(gLogger));
    CHECK(static_cast<bool>(builder), "create builder failed");
    auto network = TrtUnqPtr<INetworkDefinition>(builder->createNetworkV2(0));
    CHECK(static_cast<bool>(network), "create network failed");
    auto parser = TrtUnqPtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger));
    CHECK(static_cast<bool>(parser), "create parser failed");

    int verbosity = (int)nvinfer1::ILogger::Severity::kWARNING;
    bool success = parser->parseFromFile(onnx_pth.c_str(), verbosity);
    CHECK(success, "parse onnx file failed");

    if (network->getNbInputs() != 1) {
        cout << "expect model to have only one input, but this model has "
            << network->getNbInputs() << endl;
        std::terminate();
    }
    // getOutput(0) is only meaningful when the network has an output.
    CHECK(network->getNbOutputs() >= 1, "expect model to have at least one output");

    auto input = network->getInput(0);
    auto output = network->getOutput(0);
    input_name = input->getName();
    output_name = output->getName();

    // The profile below indexes d[1], d[2] and d[3], so the input must be rank 4.
    Dims in_dims = input->getDimensions();
    CHECK(in_dims.nbDims == 4, "expect model input to have 4 dimensions (N, C, H, W)");

    auto config = TrtUnqPtr<IBuilderConfig>(builder->createBuilderConfig());
    CHECK(static_cast<bool>(config), "create builder config failed");
    config->setProfileStream(*stream);

    // Only the batch dimension is dynamic; C, H and W are taken from the model.
    auto profile = builder->createOptimizationProfile();
    CHECK(profile != nullptr, "create optimization profile failed");
    int32_t C = in_dims.d[1], H = in_dims.d[2], W = in_dims.d[3];
    Dims dmin = Dims4{1, C, H, W};
    Dims dopt = Dims4{opt_bsize, C, H, W};
    Dims dmax = Dims4{32, C, H, W};
    const bool profile_ok =
        profile->setDimensions(input->getName(), OptProfileSelector::kMIN, dmin)
        && profile->setDimensions(input->getName(), OptProfileSelector::kOPT, dopt)
        && profile->setDimensions(input->getName(), OptProfileSelector::kMAX, dmax);
    CHECK(profile_ok, "set optimization profile dimensions failed");
    // addOptimizationProfile returns the profile index, or a negative value on error.
    CHECK(config->addOptimizationProfile(profile) >= 0, "add optimization profile failed");

    // 4 GiB workspace. Shift a size_t: `1UL << 32` is undefined behaviour where
    // unsigned long is only 32 bits wide (e.g. 64-bit Windows).
    config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE,
                               static_cast<size_t>(1) << 32);

    if (quant == "fp16" or quant == "int8") { // fp16 (also used as int8 fallback)
        if (builder->platformHasFastFp16() == false) {
            cout << "fp16 is set, but platform does not support, so we ignore this\n";
        } else {
            config->setFlag(nvinfer1::BuilderFlag::kFP16);
        }
    }
    if (quant == "bf16") { // bf16
        config->setFlag(nvinfer1::BuilderFlag::kBF16);
    }
    if (quant == "fp8") { // fp8
        config->setFlag(nvinfer1::BuilderFlag::kFP8);
    }

    // The calibrator is only consumed later, inside buildSerializedNetwork(),
    // so the calibration stream and the table name it is constructed from must
    // live at function scope. They are declared before `calibrator` so the
    // calibrator is destroyed first.
    const string cal_table_name = "calibrate_int8";
    std::unique_ptr<BatchStream> calibration_stream;
    std::unique_ptr<IInt8Calibrator> calibrator;
    if (quant == "int8") { // int8
        if (builder->platformHasFastInt8() == false) {
            cout << "int8 is set, but platform does not support, so we ignore this\n";
        } else {
            int batchsize = 32;
            int n_cal_batches = -1;
            calibration_stream.reset(new BatchStream(
                batchsize, n_cal_batches, in_dims,
                data_root, data_file));
            config->setFlag(nvinfer1::BuilderFlag::kINT8);
            calibrator.reset(new Int8EntropyCalibrator2<BatchStream>(
                *calibration_stream, 0, cal_table_name.c_str(), input_name.c_str(), false));
            config->setInt8Calibrator(calibrator.get());
        }
    }
    // output->setType(nvinfer1::DataType::kINT32);
    // output->setType(nvinfer1::DataType::kFLOAT);

    cout << "start to build \n";
    auto plan = TrtUnqPtr<IHostMemory>(builder->buildSerializedNetwork(*network, *config));
    CHECK(static_cast<bool>(plan), "build serialized engine failed");
    runtime.reset(nvinfer1::createInferRuntime(gLogger));
    CHECK(static_cast<bool>(runtime), "create runtime failed");
    engine.reset(runtime->deserializeCudaEngine(plan->data(), plan->size()));
    CHECK(static_cast<bool>(engine), "deserialize engine failed");
    cout << "done build engine \n";
}
// Set `opt_bsize`, which parse_to_engine() uses as the kOPT batch of the
// optimization profile and test_speed_fps() uses as its batch size.
// Accepted range is [1, 32]; the upper bound matches the profile's kMAX batch.
// Terminates via CHECK when `bs` is out of range.
void SemanticSegmentTrt::set_opt_batch_size(int bs) {
    CHECK(bs > 0 and bs < 33, "batch size should be in range [1, 32]");
    opt_bsize = bs;
}
// Serialize the current engine and write the plan bytes to `save_path`
// (binary, truncating any existing file). The file can be loaded back with
// deserialize(). Terminates via CHECK if there is no engine, serialization
// fails, or the file cannot be opened/written.
void SemanticSegmentTrt::serialize(string save_path) {
    CHECK(static_cast<bool>(engine), "no engine to serialize");
    auto trt_stream = TrtUnqPtr<IHostMemory>(engine->serialize());
    CHECK(static_cast<bool>(trt_stream), "serialize engine failed");

    ofstream ofile(save_path, ios::out | ios::binary);
    CHECK(static_cast<bool>(ofile), "open file for writing failed: " + save_path);
    ofile.write(static_cast<const char*>(trt_stream->data()),
                static_cast<std::streamsize>(trt_stream->size()));
    // close() flushes; a failed write or flush leaves the failbit set.
    ofile.close();
    CHECK(!ofile.fail(), "write serialized engine failed: " + save_path);
}
// Load a serialized engine (e.g. one written by serialize()) from `serpth`,
// create `runtime` and `engine`, and set input_name / output_name.
//
// The first IO tensor whose mode is kINPUT becomes input_name, and the first
// whose mode is kOUTPUT becomes output_name. No particular IO index order is
// assumed. Prints the file size. Terminates via CHECK on any failure: the
// file cannot be read, the file is empty, runtime/engine creation fails, or
// no input/output tensor is found.
void SemanticSegmentTrt::deserialize(string serpth) {
    ifstream ifile(serpth, ios::in | ios::binary);
    CHECK(static_cast<bool>(ifile), "read serialized file failed");

    // tellg() yields -1 on failure; keep full width rather than truncating to int.
    ifile.seekg(0, ios::end);
    const std::streamsize mdsize = static_cast<std::streamsize>(ifile.tellg());
    CHECK(mdsize > 0, "serialized file is empty or its size cannot be determined");
    ifile.clear();
    ifile.seekg(0, ios::beg);

    vector<char> buf(static_cast<size_t>(mdsize));
    ifile.read(buf.data(), mdsize);
    CHECK(ifile.gcount() == mdsize, "read serialized file failed");
    ifile.close();
    cout << "model size: " << mdsize << endl;

    runtime.reset(nvinfer1::createInferRuntime(gLogger));
    CHECK(static_cast<bool>(runtime), "create runtime failed");
    engine.reset(runtime->deserializeCudaEngine(buf.data(), buf.size()));
    CHECK(static_cast<bool>(engine), "deserialize engine failed");

    // Identify tensors by their IO mode instead of assuming index 0 is the
    // input and index 1 the output.
    input_name.clear();
    output_name.clear();
    const int32_t n_io = engine->getNbIOTensors();
    for (int32_t i = 0; i < n_io; ++i) {
        const char* name = engine->getIOTensorName(i);
        const auto mode = engine->getTensorIOMode(name);
        if (mode == nvinfer1::TensorIOMode::kINPUT && input_name.empty()) {
            input_name = name;
        } else if (mode == nvinfer1::TensorIOMode::kOUTPUT && output_name.empty()) {
            output_name = name;
        }
    }
    CHECK(!input_name.empty() && !output_name.empty(),
          "expect engine to have an input and an output tensor");
}
// Run the engine on one image (batch size 1) and return its int32 output.
//
// `data` must contain at least C*H*W floats, where C/H/W are dimensions 1..3
// of the engine's input shape; only the first C*H*W values are used. The
// returned vector has out_dims.d[1] * out_dims.d[2] elements, taken from the
// engine's output shape. Device buffers are allocated and freed on every call.
// The call blocks on `*stream` until the result is back on the host.
// Terminates via CHECK on any CUDA / TensorRT failure.
vector<int32_t> SemanticSegmentTrt::inference(vector<float>& data) {
    Dims in_dims = engine->getTensorShape(input_name.c_str());
    Dims out_dims = engine->getTensorShape(output_name.c_str());
    const int64_t batchsize{1}, H{out_dims.d[1]}, W{out_dims.d[2]};
    // Exactly one image of input. Passing fewer values than this would make
    // TensorRT read past the end of the device input buffer.
    const int64_t in_size{batchsize * in_dims.d[1] * in_dims.d[2] * in_dims.d[3]};
    CHECK(in_size > 0 && static_cast<int64_t>(data.size()) >= in_size,
          "input data is smaller than one input image (C * H * W)");
    const int64_t out_size{batchsize * H * W};
    Dims4 in_shape(batchsize, in_dims.d[1], in_dims.d[2], in_dims.d[3]);

    vector<void*> buffs(2, nullptr);
    vector<int32_t> res(out_size);
    cudaError_t state;
    state = cudaMalloc(&buffs[0], in_size * sizeof(float));
    CHECK(state == cudaSuccess, "allocate memory failed");
    state = cudaMalloc(&buffs[1], out_size * sizeof(int32_t));
    CHECK(state == cudaSuccess, "allocate memory failed");
    state = cudaMemcpyAsync(
        buffs[0], data.data(), in_size * sizeof(float),
        cudaMemcpyHostToDevice, *stream);
    CHECK(state == cudaSuccess, "transmit to device failed");

    auto context = TrtUnqPtr<IExecutionContext>(engine->createExecutionContext());
    CHECK(static_cast<bool>(context), "create execution context failed");
    // Dynamic shape require this setInputShape
    bool success = context->setInputShape(input_name.c_str(), in_shape);
    CHECK(success, "set input shape failed");
    success = context->setInputTensorAddress(input_name.c_str(), buffs[0])
        && context->setOutputTensorAddress(output_name.c_str(), buffs[1]);
    CHECK(success, "set tensor address failed");
    CHECK(context->enqueueV3(*stream), "enqueue inference failed");

    state = cudaMemcpyAsync(
        res.data(), buffs[1], out_size * sizeof(int32_t),
        cudaMemcpyDeviceToHost, *stream);
    CHECK(state == cudaSuccess, "transmit back to host failed");
    // Asynchronous errors from the enqueued work surface here.
    state = cudaStreamSynchronize(*stream);
    CHECK(state == cudaSuccess, "synchronize stream failed");

    for (void* buf : buffs) {
        cudaFree(buf);
    }
    return res;
}
void SemanticSegmentTrt::test_speed_fps() {
Dims in_dims = engine->getTensorShape(input_name.c_str());
Dims out_dims = engine->getTensorShape(output_name.c_str());
const int64_t batchsize{opt_bsize};
const int64_t oH{out_dims.d[1]}, oW{out_dims.d[2]};
const int64_t iH{in_dims.d[2]}, iW{in_dims.d[3]};
const int64_t in_size{batchsize * 3 * iH * iW};
const int64_t out_size{batchsize * oH * oW};
Dims4 in_shape(batchsize, in_dims.d[1], in_dims.d[2], in_dims.d[3]);
vector<void*> buffs(2, nullptr);
cudaError_t state;
state = cudaMalloc(&buffs[0], in_size * sizeof(float));
CHECK(state == cudaSuccess, "allocate memory failed");
state = cudaMalloc(&buffs[1], out_size * sizeof(int32_t));
CHECK(state == cudaSuccess, "allocate memory failed");
auto context = TrtUnqPtr<IExecutionContext>(engine->createExecutionContext());
CHECK(static_cast<bool>(context), "create execution context failed");
bool success = context->setInputShape(input_name.c_str(), in_shape);
CHECK(success, "set input shape failed");
cout << "\ntest with cropsize of (" << iH << ", " << iW << "), "
<< "and batch size of " << batchsize << " ...\n";
context->executeV2(buffs.data()); // run one batch ahead
auto start = std::chrono::steady_clock::now();
const int n_loops{2000};
for (int i{0}; i < n_loops; ++i) {
context->executeV2(buffs.data());
}
auto end = std::chrono::steady_clock::now();
double duration = std::chrono::duration<double, std::milli>(end - start).count();
duration /= 1000.;
int n_frames = n_loops * batchsize;
cout << "running " << n_loops << " times, use time: "
<< duration << "s" << endl;
cout << "fps is: " << static_cast<double>(n_frames) / duration << endl;
for (auto buf : buffs) {
cudaFree(buf);
}
}
// Return the engine's shape for the `input_name` tensor, one int per
// dimension (nbDims entries), converted from TensorRT's Dims.
vector<int> SemanticSegmentTrt::get_input_shape() {
    const Dims shape = engine->getTensorShape(input_name.c_str());
    vector<int> dims;
    for (int32_t k = 0; k < shape.nbDims; ++k) {
        dims.push_back(static_cast<int>(shape.d[k]));
    }
    return dims;
}
// Return the engine's shape for the `output_name` tensor, one int per
// dimension (nbDims entries), converted from TensorRT's Dims.
vector<int> SemanticSegmentTrt::get_output_shape() {
    const Dims shape = engine->getTensorShape(output_name.c_str());
    vector<int> dims;
    for (int32_t k = 0; k < shape.nbDims; ++k) {
        dims.push_back(static_cast<int>(shape.d[k]));
    }
    return dims;
}