Skip to content

Commit 58e5e78

Browse files
committed
feat(paddlejs-model): add detect model
1 parent 84334f0 commit 58e5e78

17 files changed

+1949
-4
lines changed

e2e/dist/assets/detect.json

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
[
2+
[
3+
1,
4+
0.9999813437461853,
5+
0.3285195231437683,
6+
0.6718374490737915,
7+
0.4688636064529419,
8+
0.8186416625976562
9+
],
10+
[
11+
1,
12+
0.9999707937240601,
13+
0.2805149555206299,
14+
0.48382753133773804,
15+
0.42529064416885376,
16+
0.635986864566803
17+
],
18+
[
19+
1,
20+
0.9983498454093933,
21+
0.5017392039299011,
22+
0.5820830464363098,
23+
0.5703544020652771,
24+
0.8374423384666443
25+
],
26+
[
27+
1,
28+
0.9950063824653625,
29+
0.4869859516620636,
30+
0.23868629336357117,
31+
0.5620469450950623,
32+
0.4887562096118927
33+
],
34+
[
35+
2,
36+
0.9992570281028748,
37+
0.36442363262176514,
38+
0.4332561790943146,
39+
0.43267953395843506,
40+
0.5595301389694214
41+
]
42+
]

e2e/dist/assets/imgs/detect.jpeg

819 KB
Loading

e2e/dist/index.html

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818
height: 640px;
1919
}
2020
</style>
21+
2122
</head>
2223
<body>
2324
<img id="car" src="./assets/imgs/car.webp"/>
2425
<img id="banana" src="./assets/imgs/banana.jpeg"/>
2526
<img id="ok" src="./assets/imgs/ok.jpeg"/>
2627
<img id="human" src="./assets/imgs/human.jpg"/>
2728
<img id="seg" src="./assets/imgs/seg.png"/>
28-
<img id="ocr" src="./assets/imgs/ocr.jpg" />
29+
<img id="ocr" src="./assets/imgs/ocr.jpg"/>
30+
<img id="detect" src="./assets/imgs/detect.jpeg"/>
2931
<canvas id="back_canvas"></canvas>
3032
<canvas id="seg_canvas"></canvas>
3133
</body>
@@ -35,4 +37,5 @@
3537
<script src="./gesture_bundle.js"></script>
3638
<script src="./humanseg_bundle.js"></script>
3739
<script src="./ocr_bundle.js"></script>
40+
<script src="./detect_bundle.js"></script>
3841
</html>

e2e/server.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ const mobilenetWebpackConfig = require('../packages/paddlejs-models/mobilenet/we
88
const gestureWebpackConfig = require('../packages/paddlejs-models/gesture/webpack.prod');
99
const humansegWebpackConfig = require('../packages/paddlejs-models/humanseg/webpack.prod');
1010
const ocrWebpackConfig = require('../packages/paddlejs-models/ocr/webpack.prod');
11+
const detectWebpackConfig = require('../packages/paddlejs-models/detect/webpack.prod');
1112

1213
const DIST_DIR = path.join(__dirname, 'dist');
1314

@@ -33,9 +34,10 @@ const mobilenet = new ConfigInfo('mobilenet', mobilenetWebpackConfig, true);
3334
const gesture = new ConfigInfo('gesture', gestureWebpackConfig, true);
3435
const humanseg = new ConfigInfo('humanseg', humansegWebpackConfig, true);
3536
const ocr = new ConfigInfo('ocr', ocrWebpackConfig, true);
37+
const detect = new ConfigInfo('detect', detectWebpackConfig, true);
3638

3739
// edit webpack config
38-
[core, webgl, mobilenet, gesture, humanseg, ocr].forEach(instance => {
40+
[core, webgl, mobilenet, gesture, humanseg, ocr, detect].forEach(instance => {
3941
const config = instance.config;
4042
config.output.path = DIST_DIR;
4143
config.output.filename = `${instance.key}_bundle.js`;
@@ -64,6 +66,7 @@ const app = express()
6466
.use(middleware(gesture.compiler))
6567
.use(middleware(humanseg.compiler))
6668
.use(middleware(ocr.compiler))
69+
.use(middleware(detect.compiler))
6770
.use(express.static(DIST_DIR));
6871

6972
app.listen(port, () => {

e2e/tests/detect.test.ts

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
const detectResult = require('../dist/assets/detect.json');
2+
3+
/**
 * End-to-end check of the detect model: runs detection on the `#detect` image
 * served by the e2e page and compares every returned row against the expected
 * rows stored in `detect.json`.
 */
describe('e2e test detect model', () => {
    const { paddlejs } = require('./global.d.ts');
    const CUR_URL = 'http://localhost:9898/';

    beforeAll(async () => {
        await page.goto(CUR_URL);
    });

    it('detect predict', async () => {
        // Forward browser console output to the test runner's console.
        page.on('console', msg => console.log('PAGE LOG:', msg.text()));

        // Everything inside this callback executes in the browser page.
        const res = await page.evaluate(async () => {
            const imgNode = document.querySelector('#detect');
            const detector = paddlejs['detect'];
            await detector.init();
            return detector.detect(imgNode);
        });

        // Maximum allowed absolute difference for confidence and coordinates.
        const gap = 0.02;

        for (let index = 0; index < detectResult.length; index++) {
            const expected = detectResult[index];

            // Column 0 is the label index: it must match exactly.
            expect(expected[0]).toEqual(res[index][0]);

            // Column 1 is the confidence; columns 2..5 are left, top, right, bottom.
            // Each may differ from the expected value by at most `gap`.
            for (let column = 1; column <= 5; column++) {
                expect(Math.abs(expected[column] - res[index][column])).toBeLessThanOrEqual(gap);
            }
        }
    });
});

packages/paddlejs-core/src/mediaProcessor.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ export default class MediaProcessor {
147147
/**
148148
* 缩放成目标尺寸, keepRatio 为 true 则保持比例拉伸并居中,为 false 则变形拉伸为目标尺寸
149149
*/
150-
fitToTargetSize(image, imageDataInfo, opt) {
150+
fitToTargetSize(image, imageDataInfo, opt?) {
151151
const {
152152
keepRatio = true,
153153
inGPU = false,
@@ -244,4 +244,4 @@ export default class MediaProcessor {
244244
const scale = [sw / dw, sh / dh];
245245
return scale;
246246
}
247-
}
247+
}
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
[中文版](./README_cn.md)
2+
3+
# detect
4+
5+
The detect model locates labeled objects in an image. For each detection it returns the label index, a confidence score, and the coordinates of the bounding box.
6+
7+
# Usage
8+
9+
```js
10+
import * as det from '@paddlejs-models/detect';
11+
12+
// Load model
13+
await det.init();
14+
15+
// Get label index, confidence and coordinates
16+
const res = await det.detect(img);
17+
18+
res.forEach(item => {
19+
// Get label index
20+
console.log(item[0]);
21+
// Get label confidence
22+
console.log(item[1]);
23+
// Get the left edge of the bounding box (as a fraction of the image width)
24+
console.log(item[2]);
25+
// Get the top edge of the bounding box (as a fraction of the image height)
26+
console.log(item[3]);
27+
// Get the right edge of the bounding box (as a fraction of the image width)
28+
console.log(item[4]);
29+
// Get the bottom edge of the bounding box (as a fraction of the image height)
30+
console.log(item[5]);
31+
});
32+
```
33+
34+
# Effect
35+
![img.png](https://user-images.githubusercontent.com/43414102/153805288-80f289bf-ca92-4788-b1dd-44854681a03f.png)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
[English](./README.md)
2+
3+
# detect
4+
5+
detect模型用于检测图像中label框选位置。
6+
7+
# 使用
8+
9+
```js
10+
import * as det from '@paddlejs-models/detect';
11+
12+
// 模型加载
13+
await det.init();
14+
15+
// 获取label对应索引、置信度、检测框选坐标
16+
const res = await det.detect(img);
17+
18+
res.forEach(item => {
19+
// 获取label对应索引
20+
console.log(item[0]);
21+
// 获取label置信度
22+
console.log(item[1]);
23+
// 获取检测框选left顶点
24+
console.log(item[2]);
25+
// 获取检测框选top顶点
26+
console.log(item[3]);
27+
// 获取检测框选right顶点
28+
console.log(item[4]);
29+
// 获取检测框选bottom顶点
30+
console.log(item[5]);
31+
});
32+
```
33+
34+
# 效果
35+
![img.png](https://user-images.githubusercontent.com/43414102/153805288-80f289bf-ca92-4788-b1dd-44854681a03f.png)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
<!DOCTYPE html>
2+
<html>
3+
<head>
4+
<meta charset="utf-8">
5+
<title>paddle detection demo</title>
6+
<meta name="viewport" content="width=device-width,minimum-scale=1.0,maximum-scale=1.0,user-scalable=no">
7+
<style>
8+
body {
9+
margin: 0 auto;
10+
width: 100%;
11+
height: 100%;
12+
}
13+
.wrapper {
14+
position: relative;
15+
width: 100%;
16+
height: 100%;
17+
margin: 0 auto;
18+
}
19+
#isLoading {
20+
position: fixed;
21+
top: 0;
22+
left: 0;
23+
right: 0;
24+
bottom: 0;
25+
width: 100vw;
26+
height: 100vh;
27+
background-color: rgba(0, 0, 0, .5);
28+
}
29+
#isLoading .loading-text {
30+
color: #fff;
31+
font-size: 24px;
32+
text-align: center;
33+
line-height: 100vh;
34+
}
35+
</style>
36+
</head>
37+
<body>
38+
<div class="wrapper">
39+
<img id="image" src="https://m.baidu.com/se/static/img/iphone/logo.png">
40+
<div id="tool">
41+
<input type="file" id="uploadImg">
42+
</div>
43+
<canvas id="canvas"></canvas>
44+
</div>
45+
46+
<div id="isLoading">
47+
<p class="loading-text center">loading……</p>
48+
</div>
49+
50+
</body>
51+
</html>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import * as det from '../src/index';
2+
import label from './label.json';
3+
4+
const loading = document.getElementById('isLoading');
5+
const inputElement = document.getElementById('uploadImg');
6+
const imgElement = document.getElementById('image') as HTMLImageElement;
7+
const canvasOutput = document.getElementById('canvas') as HTMLCanvasElement;
8+
9+
let isPreheat = true;
10+
11+
// Start model initialisation as soon as the script runs. The returned promise
// is observed so that a failing `det.init()` is reported instead of becoming
// an unhandled rejection.
load().catch(err => {
    console.error('detect model init failed:', err);
});

/**
 * Initialises the detect model, then hides the loading overlay and clears the
 * `isPreheat` flag so that subsequent image loads run detection.
 *
 * @returns A promise that resolves once the model is ready; it rejects if
 *          `det.init()` rejects, in which case `isPreheat` stays `true`.
 */
async function load(): Promise<void> {
    await det.init();
    // `getElementById` returns null when the element is missing from the page.
    if (loading) {
        loading.style.display = 'none';
    }
    isPreheat = false;
}
18+
19+
// Display the image chosen in the file input by pointing `imgElement` at a
// blob URL created for the selected file.
inputElement.addEventListener('change', (e: Event) => {
    const files = (e.target as HTMLInputElement).files;
    const file = files && files[0];
    // Nothing selected: `files` is null or empty, and passing `undefined` to
    // `URL.createObjectURL` would throw.
    if (!file) {
        return;
    }
    const previousSrc = imgElement.src;
    imgElement.src = URL.createObjectURL(file);
    // Release the blob URL created for the previously chosen file, if any;
    // the image element no longer references it.
    if (previousSrc.startsWith('blob:')) {
        URL.revokeObjectURL(previousSrc);
    }
}, false);
22+
23+
/**
 * Runs detection whenever `imgElement` finishes loading and renders the
 * result onto `canvasOutput`: the image itself, a red rectangle per detection
 * and the label text looked up in `label.json`.
 *
 * Each detection row is read as: [0] label index, [2] left, [3] top,
 * [4] right, [5] bottom. The four coordinates are multiplied by the image
 * width/height, i.e. they are treated as fractions of the image size.
 */
imgElement.onload = async () => {
    // Ignore loads that happen while the model is still initialising.
    if (isPreheat) {
        return;
    }
    // Run detection on the image currently shown.
    const res = await det.detect(imgElement);
    const imgHeight = imgElement.height;
    const imgWidth = imgElement.width;
    canvasOutput.width = imgWidth;
    canvasOutput.height = imgHeight;
    const ctx = canvasOutput.getContext('2d');
    // `getContext` returns null when a 2D context cannot be obtained.
    if (!ctx) {
        return;
    }
    ctx.drawImage(imgElement, 0, 0, canvasOutput.width, canvasOutput.height);
    ctx.lineWidth = 1;
    ctx.strokeStyle = 'red';
    res.forEach(item => {
        // Convert the detection box to pixel coordinates.
        const left = Math.floor(item[2] * imgWidth);
        const top = Math.floor(item[3] * imgHeight);
        const right = Math.floor(item[4] * imgWidth);
        const bottom = Math.floor(item[5] * imgHeight);
        // Each box gets its own path so it is stroked independently.
        ctx.beginPath();
        ctx.rect(left, top, right - left, bottom - top);
        // Draw the label text just inside the top-left corner of the box.
        ctx.fillText(label[item[0]], left + 10, top + 10);
        ctx.stroke();
    });
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"0": "default",
3+
"1": "screw",
4+
"2": "nut"
5+
}

0 commit comments

Comments
 (0)