@@ -44,7 +44,7 @@ namespace trtx {
44
44
builder -> destroy ();
45
45
46
46
// write serialized engine to file
47
- std::ofstream trtFile (mParams .trtEngineFile );
47
+ std::ofstream trtFile (mParams .trtEngineFile , std::ios::binary );
48
48
if (!trtFile){
49
49
std::cerr << " Unable to open engine file." << std::endl;
50
50
return false ;
@@ -193,26 +193,26 @@ namespace trtx {
193
193
// Engine requires exactly IEngine::getNbBindings() number of buffers.
194
194
assert (mEngine -> getNbBindings () == 2 );
195
195
void * buffers[2 ];
196
-
196
+
197
197
// In order to bind the buffers, we need to know the names of the input and output tensors.
198
198
// Note that indices are guaranteed to be less than IEngine::getNbBindings()
199
199
const int inputIndex = mEngine ->getBindingIndex (mParams .inputTensorName );
200
200
const int outputIndex = mEngine ->getBindingIndex (mParams .outputTensorName );
201
-
201
+
202
202
// Create GPU buffers on device
203
203
CUDA_CHECK (cudaMalloc (&buffers[inputIndex], batchSize * 3 * mParams .inputH * mParams .inputW * sizeof (float )));
204
204
CUDA_CHECK (cudaMalloc (&buffers[outputIndex], batchSize * 1000 * sizeof (float )));
205
-
205
+
206
206
// Create stream
207
207
cudaStream_t stream;
208
208
CUDA_CHECK (cudaStreamCreate (&stream));
209
-
209
+
210
210
// DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
211
211
CUDA_CHECK (cudaMemcpyAsync (buffers[inputIndex], input, batchSize * 3 * mParams .inputH * mParams .inputW * sizeof (float ), cudaMemcpyHostToDevice, stream));
212
212
mContext ->enqueue (batchSize, buffers, stream, nullptr );
213
213
CUDA_CHECK (cudaMemcpyAsync (output, buffers[outputIndex], batchSize * 1000 * sizeof (float ), cudaMemcpyDeviceToHost, stream));
214
214
cudaStreamSynchronize (stream);
215
-
215
+
216
216
// Release stream and buffers
217
217
cudaStreamDestroy (stream);
218
218
CUDA_CHECK (cudaFree (buffers[inputIndex]));
@@ -232,4 +232,5 @@ namespace trtx {
232
232
233
233
return true ;
234
234
}
235
- }
235
+ }
236
+
0 commit comments