|
| 1 | + |
| 2 | +# 分类器训练 |
| 3 | + |
| 4 | +`R-CNN`在完成卷积模型的微调后,额外使用了线性`SVM`分类器,采用负样本挖掘方法进行训练,参考[Hard Negative Mining](https://blog.zhujian.life/posts/bc29003.html) |
| 5 | + |
| 6 | +## 线性SVM |
| 7 | + |
| 8 | +参考: |
| 9 | + |
| 10 | +[线性SVM分类器](https://blog.zhujian.life/posts/ebe205e.html) |
| 11 | + |
| 12 | +[线性SVM分类器-PyTorch实现](https://blog.zhujian.life/posts/4d25cbab.html) |
| 13 | + |
| 14 | +线性`SVM`分类器包含了线性回归+折页损失,其中自定`义PyTorch`的折页损失实现: |
| 15 | + |
| 16 | +``` |
| 17 | +def hinge_loss(outputs, labels): |
| 18 | + """ |
| 19 | + 折页损失计算 |
| 20 | + :param outputs: 大小为(N, num_classes) |
| 21 | + :param labels: 大小为(N) |
| 22 | + :return: 损失值 |
| 23 | + """ |
| 24 | + num_labels = len(labels) |
| 25 | + corrects = outputs[range(num_labels), labels].unsqueeze(0).T |
| 26 | +
|
| 27 | + # 最大间隔 |
| 28 | + margin = 1.0 |
| 29 | + margins = outputs - corrects + margin |
| 30 | + loss = torch.sum(torch.max(margins, 1)[0]) / len(labels) |
| 31 | +
|
| 32 | + # # 正则化强度 |
| 33 | + # reg = 1e-3 |
| 34 | + # loss += reg * torch.sum(weight ** 2) |
| 35 | +
|
| 36 | + return loss |
| 37 | +``` |
| 38 | + |
| 39 | +## 负样本挖掘 |
| 40 | + |
| 41 | +实现流程如下: |
| 42 | + |
| 43 | +1. 设置初始训练集,正负样本数比值为`1:1`(以正样本数目为基准) |
| 44 | +2. 每轮训练完成后,使用分类器对剩余负样本进行检测,如果检测为正,则加入到训练集中 |
| 45 | +3. 重新训练分类器,重复第二步,直到检测精度开始收敛 |
| 46 | + |
| 47 | +## 训练参数 |
| 48 | + |
| 49 | +1. 学习率:`1e-4` |
| 50 | +2. 动量:`0.9` |
| 51 | +3. 随步长衰减:每隔`4`轮衰减一次,参数因子`α=0.1` |
| 52 | +4. 迭代次数:`10` |
| 53 | +5. 批量处理:每次训练`128`个图像,其中`32`个正样本,`96`个负样本 |
| 54 | + |
| 55 | +## 训练结果 |
| 56 | + |
| 57 | +``` |
| 58 | +$ python linear_svm.py |
| 59 | +Epoch 0/9 |
| 60 | +---------- |
| 61 | +train - positive_num: 625 - negative_num: 625 - data size: 1152 |
| 62 | +train Loss: 1.1406 Acc: 0.6424 |
| 63 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 64 | +val Loss: 1.0560 Acc: 0.8080 |
| 65 | +remiam negative size: 365403, acc: 0.8821 |
| 66 | +Epoch 1/9 |
| 67 | +---------- |
| 68 | +train - positive_num: 625 - negative_num: 43397 - data size: 43904 |
| 69 | +train Loss: 1.0180 Acc: 0.9410 |
| 70 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 71 | +val Loss: 1.0426 Acc: 0.9232 |
| 72 | +remiam negative size: 365403, acc: 0.9965 |
| 73 | +Epoch 2/9 |
| 74 | +---------- |
| 75 | +train - positive_num: 625 - negative_num: 43606 - data size: 44160 |
| 76 | +train Loss: 1.0063 Acc: 0.9716 |
| 77 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 78 | +val Loss: 1.0414 Acc: 0.9241 |
| 79 | +remiam negative size: 365403, acc: 0.9972 |
| 80 | +Epoch 3/9 |
| 81 | +---------- |
| 82 | +train - positive_num: 625 - negative_num: 43731 - data size: 44288 |
| 83 | +train Loss: 1.0047 Acc: 0.9767 |
| 84 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 85 | +val Loss: 1.0429 Acc: 0.9234 |
| 86 | +remiam negative size: 365403, acc: 0.9980 |
| 87 | +Epoch 4/9 |
| 88 | +---------- |
| 89 | +train - positive_num: 625 - negative_num: 43773 - data size: 44288 |
| 90 | +train Loss: 1.0039 Acc: 0.9788 |
| 91 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 92 | +val Loss: 1.0421 Acc: 0.9240 |
| 93 | +remiam negative size: 365403, acc: 0.9980 |
| 94 | +Epoch 5/9 |
| 95 | +---------- |
| 96 | +train - positive_num: 625 - negative_num: 43795 - data size: 44416 |
| 97 | +train Loss: 1.0039 Acc: 0.9795 |
| 98 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 99 | +val Loss: 1.0427 Acc: 0.9234 |
| 100 | +remiam negative size: 365403, acc: 0.9982 |
| 101 | +Epoch 6/9 |
| 102 | +---------- |
| 103 | +train - positive_num: 625 - negative_num: 43801 - data size: 44416 |
| 104 | +train Loss: 1.0040 Acc: 0.9786 |
| 105 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 106 | +val Loss: 1.0428 Acc: 0.9242 |
| 107 | +remiam negative size: 365403, acc: 0.9981 |
| 108 | +Epoch 7/9 |
| 109 | +---------- |
| 110 | +train - positive_num: 625 - negative_num: 43808 - data size: 44416 |
| 111 | +train Loss: 1.0036 Acc: 0.9799 |
| 112 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 113 | +val Loss: 1.0429 Acc: 0.9242 |
| 114 | +remiam negative size: 365403, acc: 0.9981 |
| 115 | +Epoch 8/9 |
| 116 | +---------- |
| 117 | +train - positive_num: 625 - negative_num: 43814 - data size: 44416 |
| 118 | +train Loss: 1.0035 Acc: 0.9812 |
| 119 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 120 | +val Loss: 1.0425 Acc: 0.9242 |
| 121 | +remiam negative size: 365403, acc: 0.9981 |
| 122 | +Epoch 9/9 |
| 123 | +---------- |
| 124 | +train - positive_num: 625 - negative_num: 43817 - data size: 44416 |
| 125 | +train Loss: 1.0036 Acc: 0.9802 |
| 126 | +val - positive_num: 625 - negative_num: 321474 - data size: 322048 |
| 127 | +val Loss: 1.0424 Acc: 0.9246 |
| 128 | +remiam negative size: 365403, acc: 0.9981 |
| 129 | +Training complete in 55m 50s |
| 130 | +Best val Acc: 0.924645 |
| 131 | +``` |
0 commit comments