|
16 | 16 | namespace nvinfer1 {
|
17 | 17 |
|
18 | 18 | __global__ void batched_nms_kernel(
|
19 |
| - const int num_per_thread, const float threshold, const int num_detections, |
| 19 | + const float threshold, const int num_detections, |
20 | 20 | const int *indices, float *scores, const float *classes, const float4 *boxes) {
|
21 | 21 |
|
22 | 22 | // Go through detections by descending score
|
23 | 23 | for (int m = 0; m < num_detections; m++) {
|
24 |
| - for (int n = 0; n < num_per_thread; n++) { |
25 |
| - int i = threadIdx.x * num_per_thread + n; |
26 |
| - if (i < num_detections && m < i && scores[m] > 0.0f) { |
27 |
| - int idx = indices[i]; |
28 |
| - int max_idx = indices[m]; |
29 |
| - int icls = classes[idx]; |
30 |
| - int mcls = classes[max_idx]; |
31 |
| - if (mcls == icls) { |
32 |
| - float4 ibox = boxes[idx]; |
33 |
| - float4 mbox = boxes[max_idx]; |
34 |
| - float x1 = max(ibox.x, mbox.x); |
35 |
| - float y1 = max(ibox.y, mbox.y); |
36 |
| - float x2 = min(ibox.z, mbox.z); |
37 |
| - float y2 = min(ibox.w, mbox.w); |
38 |
| - float w = max(0.0f, x2 - x1); |
39 |
| - float h = max(0.0f, y2 - y1); |
40 |
| - float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y); |
41 |
| - float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y); |
42 |
| - float inter = w * h; |
43 |
| - float overlap = inter / (iarea + marea - inter); |
44 |
| - if (overlap > threshold) { |
45 |
| - scores[i] = 0.0f; |
46 |
| - } |
| 24 | + int i = blockIdx.x * blockDim.x + threadIdx.x; |
| 25 | + if (i < num_detections && m < i && scores[m] > 0.0f) { |
| 26 | + int idx = indices[i]; |
| 27 | + int max_idx = indices[m]; |
| 28 | + int icls = classes[idx]; |
| 29 | + int mcls = classes[max_idx]; |
| 30 | + if (mcls == icls) { |
| 31 | + float4 ibox = boxes[idx]; |
| 32 | + float4 mbox = boxes[max_idx]; |
| 33 | + float x1 = max(ibox.x, mbox.x); |
| 34 | + float y1 = max(ibox.y, mbox.y); |
| 35 | + float x2 = min(ibox.z, mbox.z); |
| 36 | + float y2 = min(ibox.w, mbox.w); |
| 37 | + float w = max(0.0f, x2 - x1); |
| 38 | + float h = max(0.0f, y2 - y1); |
| 39 | + float iarea = (ibox.z - ibox.x) * (ibox.w - ibox.y); |
| 40 | + float marea = (mbox.z - mbox.x) * (mbox.w - mbox.y); |
| 41 | + float inter = w * h; |
| 42 | + float overlap = inter / (iarea + marea - inter); |
| 43 | + if (overlap > threshold) { |
| 44 | + scores[i] = 0.0f; |
47 | 45 | }
|
48 | 46 | }
|
49 | 47 | }
|
@@ -104,7 +102,7 @@ int batchedNms(int batch_size,
|
104 | 102 | // TODO: the max threads per block differs between devices; query cudaDeviceProp::maxThreadsPerBlock instead of hard-coding it
|
105 | 103 | const int max_threads = 1024;
|
106 | 104 | int num_per_thread = ceil(static_cast<float>(num_detections) / max_threads);
|
107 |
| - batched_nms_kernel << <1, max_threads, 0, stream >> > (num_per_thread, nms_thresh, num_detections, |
| 105 | + batched_nms_kernel << <num_per_thread, max_threads, 0, stream >> > (nms_thresh, num_detections, |
108 | 106 | indices_sorted, scores_sorted, in_classes, in_boxes);
|
109 | 107 |
|
110 | 108 | // Re-sort with updated scores
|
|
0 commit comments