Skip to content

Commit ccdf0e6

Browse files
authored
update README (wang-xinyu#673)
add latency improve nms
1 parent a34527a commit ccdf0e6

File tree

3 files changed

+57
-50
lines changed

3 files changed

+57
-50
lines changed

rcnn/BatchedNms.cu

+23-25
Original file line numberDiff line numberDiff line change
@@ -16,34 +16,32 @@
1616
namespace nvinfer1 {
1717

1818
__global__ void batched_nms_kernel(
19-
const int num_per_thread, const float threshold, const int num_detections,
19+
const float threshold, const int num_detections,
2020
const int *indices, float *scores, const float *classes, const float4 *boxes) {
2121

2222
// Go through detections by descending score
2323
for (int m = 0; m < num_detections; m++) {
24-
for (int n = 0; n < num_per_thread; n++) {
25-
int i = threadIdx.x * num_per_thread + n;
26-
if (i < num_detections && m < i && scores[m] > 0.0f) {
27-
int idx = indices[i];
28-
int max_idx = indices[m];
29-
int icls = classes[idx];
30-
int mcls = classes[max_idx];
31-
if (mcls == icls) {
32-
float4 ibox = boxes[idx];
33-
float4 mbox = boxes[max_idx];
34-
float x1 = max(ibox.x, mbox.x);
35-
float y1 = max(ibox.y, mbox.y);
36-
float x2 = min(ibox.z, mbox.z);
37-
float y2 = min(ibox.w, mbox.w);
38-
float w = max(0.0f, x2 - x1);
39-
float h = max(0.0f, y2 - y1);
40-
float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y);
41-
float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y);
42-
float inter = w * h;
43-
float overlap = inter / (iarea + marea - inter);
44-
if (overlap > threshold) {
45-
scores[i] = 0.0f;
46-
}
24+
int i = blockIdx.x * blockDim.x + threadIdx.x;
25+
if (i < num_detections && m < i && scores[m] > 0.0f) {
26+
int idx = indices[i];
27+
int max_idx = indices[m];
28+
int icls = classes[idx];
29+
int mcls = classes[max_idx];
30+
if (mcls == icls) {
31+
float4 ibox = boxes[idx];
32+
float4 mbox = boxes[max_idx];
33+
float x1 = max(ibox.x, mbox.x);
34+
float y1 = max(ibox.y, mbox.y);
35+
float x2 = min(ibox.z, mbox.z);
36+
float y2 = min(ibox.w, mbox.w);
37+
float w = max(0.0f, x2 - x1);
38+
float h = max(0.0f, y2 - y1);
39+
float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y);
40+
float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y);
41+
float inter = w * h;
42+
float overlap = inter / (iarea + marea - inter);
43+
if (overlap > threshold) {
44+
scores[i] = 0.0f;
4745
}
4846
}
4947
}
@@ -104,7 +102,7 @@ int batchedNms(int batch_size,
104102
// TODO: different device has differnet max threads
105103
const int max_threads = 1024;
106104
int num_per_thread = ceil(static_cast<float>(num_detections) / max_threads);
107-
batched_nms_kernel << <1, max_threads, 0, stream >> > (num_per_thread, nms_thresh, num_detections,
105+
batched_nms_kernel << <num_per_thread, max_threads, 0, stream >> > (nms_thresh, num_detections,
108106
indices_sorted, scores_sorted, in_classes, in_boxes);
109107

110108
// Re-sort with updated scores

rcnn/README.md

+13-2
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ sudo ./rcnn -d faster.engine ../samples
101101
R101-faster: ./configs/COCO-Detection/faster_rcnn_R_101_C4_3x.yaml
102102
R50-mask: ./configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.yaml
103103
R101-mask: ./configs/COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml
104-
3.set BACKBONE_RESNETTYPE = R50(R101) rcnn.cpp line 13
104+
3.set BACKBONE_RESNETTYPE = R50(R101) rcnn.cpp line 14
105105
4.set STRIDE_IN_1X1=true in backbone.hpp
106106
5.follow how to run
107107
```
@@ -130,7 +130,18 @@ sudo ./rcnn -d faster.engine ../samples
130130

131131
1. quantizationType:fp32,fp16,int8. see BuildRcnnModel(rcnn.cpp line 345) for detail.
132132

133-
2. the usage of int8 is same with [tensorrtx/yolov5](../yolov5/README.md), but it has no improvement comparing to fp16.
133+
2. the usage of int8 is same with [tensorrtx/yolov5](../yolov5/README.md).
134+
135+
## Latency
136+
137+
average cost of doInference(in rcnn.cpp) from second time with batch=1 under the ubuntu environment above, input size: 640(w)*480(h)
138+
139+
| | fp32 | fp16 | int8 |
140+
| ------------- | ----- | ---- | ---- |
141+
| Faster-R50C4 | 138ms | 36ms | 30ms |
142+
| Faster-R101C4 | 146ms | 38ms | 32ms |
143+
| Mask-R50C4 | 153ms | 44ms | 33ms |
144+
| Mask-R101C4 | 168ms | 45ms | 35ms |
134145

135146
## Plugins
136147

rcnn/RpnNms.cu

+21-23
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,29 @@
1616
namespace nvinfer1 {
1717

1818
__global__ void rpn_nms_kernel(
19-
const int num_per_thread, const float threshold, const int num_detections,
19+
const float threshold, const int num_detections,
2020
const int *indices, float *scores, const float4 *boxes) {
2121
// Go through detections by descending score
2222
for (int m = 0; m < num_detections; m++) {
23-
for (int n = 0; n < num_per_thread; n++) {
24-
int i = threadIdx.x * num_per_thread + n;
25-
if (i < num_detections && m < i && scores[m] > -FLT_MAX) {
26-
int idx = indices[i];
27-
int max_idx = indices[m];
28-
29-
float4 ibox = boxes[idx];
30-
float4 mbox = boxes[max_idx];
31-
float x1 = max(ibox.x, mbox.x);
32-
float y1 = max(ibox.y, mbox.y);
33-
float x2 = min(ibox.z, mbox.z);
34-
float y2 = min(ibox.w, mbox.w);
35-
float w = max(0.0f, x2 - x1);
36-
float h = max(0.0f, y2 - y1);
37-
float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y);
38-
float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y);
39-
float inter = w * h;
40-
float overlap = inter / (iarea + marea - inter);
41-
if (overlap > threshold) {
42-
scores[i] = -FLT_MAX;
43-
}
23+
int i = blockIdx.x * blockDim.x + threadIdx.x;
24+
if (i < num_detections && m < i && scores[m] > -FLT_MAX) {
25+
int idx = indices[i];
26+
int max_idx = indices[m];
27+
28+
float4 ibox = boxes[idx];
29+
float4 mbox = boxes[max_idx];
30+
float x1 = max(ibox.x, mbox.x);
31+
float y1 = max(ibox.y, mbox.y);
32+
float x2 = min(ibox.z, mbox.z);
33+
float y2 = min(ibox.w, mbox.w);
34+
float w = max(0.0f, x2 - x1);
35+
float h = max(0.0f, y2 - y1);
36+
float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y);
37+
float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y);
38+
float inter = w * h;
39+
float overlap = inter / (iarea + marea - inter);
40+
if (overlap > threshold) {
41+
scores[i] = -FLT_MAX;
4442
}
4543
}
4644

@@ -98,7 +96,7 @@ namespace nvinfer1 {
9896
// TODO: different device has differnet max threads
9997
const int max_threads = 1024;
10098
int num_per_thread = ceil(static_cast<float>(num_detections) / max_threads);
101-
rpn_nms_kernel << <1, max_threads, 0, stream >> > (num_per_thread, nms_thresh, num_detections,
99+
rpn_nms_kernel << <num_per_thread, max_threads, 0, stream >> > (nms_thresh, num_detections,
102100
indices_sorted, scores_sorted, in_boxes);
103101

104102
// Re-sort with updated scores

0 commit comments

Comments
 (0)