@@ -221,6 +221,76 @@ def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
221
221
let hasVerifier = 1;
222
222
}
223
223
224
/// Element types accepted for the vector operands and the result of
/// `xevm.dpas`: one of the listed integer or floating-point types.
def XeVM_MatrixElemType
    : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;
225
+
226
/// Enum attribute selecting the precision type of a matrix operand (used by
/// the `pa` / `pb` attributes of `xevm.dpas`). Each case lists the C++
/// enumerator name, its integer value, and its string form.
//
// NOTE(review): the signed-integer cases are named S8/S4/S2 but their string
// forms are "i8"/"i4"/"i2", whereas the unsigned cases use matching spellings
// (U8 -> "u8"). The strings are left untouched here so existing textual IR
// keeps parsing -- confirm that this spelling is intentional.
def XeVM_PrecisionTypeAttr
    : I32EnumAttr<"PrecisionType", "XeVM precision type", [
        I32EnumAttrCase<"UNUSED", 0, "unused">,

        // Unsigned integer precisions.
        I32EnumAttrCase<"U8", 1, "u8">,
        I32EnumAttrCase<"U4", 2, "u4">,
        I32EnumAttrCase<"U2", 3, "u2">,

        // Signed integer precisions.
        I32EnumAttrCase<"S8", 4, "i8">,
        I32EnumAttrCase<"S4", 5, "i4">,
        I32EnumAttrCase<"S2", 6, "i2">,

        // Floating-point precisions.
        I32EnumAttrCase<"BF8", 7, "bf8">,
        I32EnumAttrCase<"TF32", 8, "tf32">,
        I32EnumAttrCase<"BF16", 9, "bf16">,
        I32EnumAttrCase<"FP16", 10, "f16">
      ]> {
  let cppNamespace = "::mlir::xevm";
}
244
+
245
/// `xevm.dpas`: matrix multiply-accumulate, D = C + A x B.
///
/// Operands `c`, `a`, `b` and the result `d` are rank-1 fixed-length vectors
/// whose element type is one of XeVM_MatrixElemType. The result is
/// constrained to rank 1 like the operands, since D has the same shape as the
/// accumulator C. `pa` / `pb` give the precision of A and B, and `rc` is the
/// repeat count.
def XeVM_DPASOp : XeVM_Op<"dpas">,
    Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
    Arguments<(ins
      FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$c,
      FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
      FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
      XeVM_PrecisionTypeAttr:$pa,
      XeVM_PrecisionTypeAttr:$pb,
      I32Attr:$rc
    )> {

  let summary = "Matrix multiply-add";

  let description = [{
    The `xevm.dpas` operation is a matrix multiplication plus accumulation:

      D = C + A x B

    where the A, B, C input matrices and the result D have shapes:
      D : MxN
      C : MxN
      A : MxK
      B : KxN

    Shape restrictions:
      M : must be 1, 2, 4, or 8
      N : fixed execution size, must be 16
      K : systolic_depth * OPS_PER_CHAN
          OPS_PER_CHAN
            1 : for TF32
            2 : for 16-bit precision (BF, HF)
            4 : for 8-bit precision (FP8, UB, B)
            8 : for less-than 8 bit precision (U4/S4, U2/S2).

    If systolic_depth is 8, K would be 8, 16, 32, or 64 (based on OPS_PER_CHAN).
    $a, $b, $c, $d - matrix A, B, C, D, respectively
    $pa, $pb - precision of matrix A and B respectively
    $rc - repeat count

    Further restrictions as well as more details can be found here:
    https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
  }];

  let assemblyFormat = [{
    operands ` ` `{` `pa` `=` $pa `,` `pb` `=` $pb `,` `rc` `=` $rc `}` attr-dict `:` functional-type(operands, results)
  }];

  // The custom verifier is not enabled yet, so the shape and precision
  // restrictions listed in the description are not enforced by this
  // definition; only the operand/result type constraints above are checked.
  // let hasVerifier = 1;
}
224
294
225
295
def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
226
296
let description = [{
0 commit comments