Skip to content

Commit ae058dd

Browse files
committed
Bug 1933012 [wpt PR 49343] - webnn: Support gatherElements in tflite converter, a=testonly
Automatic update from web-platform-tests webnn: Support gatherElements in tflite converter There is no TFLite operator to map directly WebNN gatherElements, but it can be supported with tfl.gather_nd, the test cases[1] and the doc[2] show the gather_nd can gether not only slices but also elements, but gather_nd can't gather along the axis, so the indices used by gather_nd need to be calculated by unravelling the flat index and adjusting value with the axis. The coordinates are calculated by the stride of each dimension. For example a shape (3, 2) array can map to the following coordinates: // index row, col // [[0, 1,] [[0, 0], [0, 1], // [2, 3,] => [1, 0], [1, 1] // [4, 5,]] [2, 0], [2, 1]] Adjusting the coordinates with WebNN indices operand along the axis, // unravelled index WebNN indices axis = 0 TFLite indices // [[0, 0], [0, 1], [[1, 0], [[1 ,0], [0, 1], // [1, 0], [1, 1] [2, 1], => [2, 0], [1, 1], // [2, 0], [2, 1]] [0, 2]] [0, 0], [2, 1]] [1] https://www.tensorflow.org/mlir/tfl_ops#tflgather_nd_tflgatherndop [2] https://www.tensorflow.org/guide/tensor_slicing#insert_data_into_tensors Bug: 40206287 Change-Id: I61ad75585405039da10af05f3300799fe8f32855 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6000111 Reviewed-by: ningxin hu <ningxin.huintel.com> Reviewed-by: Reilly Grant <reillygchromium.org> Commit-Queue: Junwei Fu <junwei.fuintel.com> Cr-Commit-Position: refs/heads/main{#1387213} -- wpt-commits: 038b09dd678084a3fdc828316da9a2f2a03981cf wpt-pr: 49343 UltraBlame original commit: edbd32557d5a79f628cad8f6f7b2a85003e0f334
1 parent eb1641e commit ae058dd

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

testing/web-platform/tests/webnn/conformance_tests/gatherElements.https.any.js

+74
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,80 @@ const gatherElementsTests = [
9999
}
100100
}
101101
},
102+
,
103+
{
104+
'name': 'gatherElements float32 2D input and int32 indices options.axis=0',
105+
'graph': {
106+
'inputs': {
107+
'gatherElementsInput': {
108+
'data': [
109+
-66.05901336669922, -68.9197006225586, -77.02045440673828,
110+
-26.158037185668945, 89.0337142944336, -45.89653396606445,
111+
43.84803771972656, 48.81806945800781, 51.79948425292969
112+
],
113+
'descriptor': {shape: [3, 3], dataType: 'float32'}
114+
},
115+
'gatherElementsIndices': {
116+
'data': [1, 0, 2, 2, 1, 0],
117+
'descriptor': {shape: [2, 3], dataType: 'int32'},
118+
'constant': true
119+
}
120+
},
121+
'operators': [{
122+
'name': 'gatherElements',
123+
'arguments': [
124+
{'input': 'gatherElementsInput'},
125+
{'indices': 'gatherElementsIndices'}, {'options': {'axis': 0}}
126+
],
127+
'outputs': 'gatherElementsOutput'
128+
}],
129+
'expectedOutputs': {
130+
'gatherElementsOutput': {
131+
'data': [
132+
-26.158037185668945, -68.9197006225586, 51.79948425292969,
133+
43.84803771972656, 89.0337142944336, -77.02045440673828
134+
],
135+
'descriptor': {shape: [2, 3], dataType: 'float32'}
136+
}
137+
}
138+
}
139+
},
140+
{
141+
'name': 'gatherElements float32 3D input and int32 indices options.axis=0',
142+
'graph': {
143+
'inputs': {
144+
'gatherElementsInput': {
145+
'data': [
146+
-66.05901336669922, -68.9197006225586, -77.02045440673828,
147+
-26.158037185668945, 89.0337142944336, -45.89653396606445,
148+
43.84803771972656, 48.81806945800781
149+
],
150+
'descriptor': {shape: [2, 2, 2], dataType: 'float32'}
151+
},
152+
'gatherElementsIndices': {
153+
'data': [1, 0, 0, 1],
154+
'descriptor': {shape: [1, 2, 2], dataType: 'int32'},
155+
'constant': true
156+
}
157+
},
158+
'operators': [{
159+
'name': 'gatherElements',
160+
'arguments': [
161+
{'input': 'gatherElementsInput'}, {'indices': 'gatherElementsIndices'}
162+
],
163+
'outputs': 'gatherElementsOutput'
164+
}],
165+
'expectedOutputs': {
166+
'gatherElementsOutput': {
167+
'data': [
168+
89.0337142944336, -68.9197006225586, -77.02045440673828,
169+
48.81806945800781
170+
],
171+
'descriptor': {shape: [1, 2, 2], dataType: 'float32'}
172+
}
173+
}
174+
}
175+
},
102176
{
103177
'name': 'gatherElements float32 3D input and int32 negative indices',
104178
'graph': {

0 commit comments

Comments
 (0)