|
3 | 3 | * @file reduce_mean
|
4 | 4 | */
|
5 | 5 |
|
| 6 | +function getGcd(a, b) { |
| 7 | + const max = Math.max(a, b); |
| 8 | + const min = Math.min(a, b); |
| 9 | + if (max % min === 0) { |
| 10 | + return min; |
| 11 | + } |
| 12 | + return getGcd(max % min, min); |
| 13 | +} |
| 14 | + |
| 15 | +function getLcm(a, b) { |
| 16 | + return a * b / getGcd(a, b); |
| 17 | +} |
6 | 18 | function mainFunc(
|
7 |
| - {}, |
8 |
| - { inputs_dim, dim } |
| 19 | + { origin }, |
| 20 | + { dim } |
9 | 21 | ) {
|
| 22 | + const { total_shape, height_shape, width_shape, channel } = origin; |
| 23 | + const batch_shape = total_shape / (width_shape * height_shape * channel); |
| 24 | + const shape = [batch_shape, channel, height_shape, width_shape]; |
| 25 | + let dimArr = []; |
| 26 | + if (dim instanceof Array) { |
| 27 | + dimArr = dim; |
| 28 | + } |
| 29 | + else { |
| 30 | + dimArr.push(dim); |
| 31 | + } |
| 32 | + const dimShape = dimArr.map(item => shape[item]); |
| 33 | + const totalDimShape = dimShape.reduce((prev, cur) => prev * cur); |
| 34 | + const arrGcd = dimShape.reduce((prev, cur) => getLcm(prev, cur)); |
| 35 | + const remainV = totalDimShape / arrGcd; |
| 36 | + |
| 37 | + let codeStr = 'float sum = 0.0;'; |
| 38 | + const strArr = [` |
| 39 | + sum += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a) / float(${arrGcd}); |
| 40 | + `]; |
| 41 | + for (let i = 0; i < dimArr.length; i++) { |
| 42 | + const curDim = dimArr[i]; |
| 43 | + const curDimShape = shape[dimArr[i]]; |
| 44 | + const vname = `i${i}`; |
| 45 | + strArr.unshift(` |
| 46 | + for (int ${vname} = 0; ${vname} < ${curDimShape}; ${vname}++) { |
| 47 | + oPos[${curDim}] = ${vname}; |
| 48 | + `); |
| 49 | + strArr.push( |
| 50 | + ` |
| 51 | + } |
| 52 | + ` |
| 53 | + ); |
| 54 | + } |
| 55 | + |
| 56 | + codeStr += strArr.join('\n'); |
| 57 | + codeStr += ` |
| 58 | + o = sum / float(${remainV}); |
| 59 | + `; |
| 60 | + |
10 | 61 | return `
|
11 | 62 | // start函数
|
12 | 63 | void main(void) {
|
13 | 64 | ivec4 oPos = getOutputTensorPos();
|
14 | 65 | // 输出坐标转换为输入坐标
|
15 | 66 | float o = 0.0;
|
16 |
| - for (int i = 0; i < ${inputs_dim}; i++) { |
17 |
| - oPos[${dim}] = i; |
18 |
| - o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); |
19 |
| - } |
20 |
| - o = o / float(${inputs_dim}); |
| 67 | + ${codeStr} |
21 | 68 | setOutput(o);
|
22 | 69 | }
|
23 | 70 | `;
|
24 | 71 | }
|
25 | 72 | export default {
|
26 | 73 | mainFunc,
|
| 74 | + params: [ |
| 75 | + 'dim' |
| 76 | + ], |
27 | 77 | textureFuncConf: {
|
28 | 78 | origin: ['getValueFromTensorPos']
|
29 | 79 | },
|
30 | 80 | behaviors: [
|
31 |
| - 'normalizeDim' |
| 81 | + |
32 | 82 | ]
|
33 | 83 | };
|
0 commit comments