Skip to content

Commit 9dd4e60

Browse files
authored
[XeVM] Add initial set of operations and conversion tests #421
This PR adds XeVM load/store/prefetch operations and their conversion patterns to LLVM. We follow the Triton approach with slight changes (e.g., the cache attribute is handled differently).
1 parent b7129d8 commit 9dd4e60

File tree

11 files changed

+1060
-12
lines changed

11 files changed

+1060
-12
lines changed

include/gc/Conversion/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef GC_CONVERSION_PASSES_H
1010
#define GC_CONVERSION_PASSES_H
1111

12-
#include "gc/Conversion/XeVMToLLVM.h"
12+
#include "gc/Conversion/XeVMToLLVM/XeVMToLLVM.h"
1313

1414
namespace mlir {
1515

include/gc/Conversion/XeVMToLLVM/XeVMToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class RewritePatternSet;
1717
class Pass;
1818

1919
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
20-
#include "mlir/Conversion/Passes.h.inc"
20+
#include "gc/Conversion/Passes.h.inc"
2121

2222
void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
2323

include/gc/Dialect/LLVMIR/XeVMOps.td

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ def XeVM_Dialect : Dialect {
1919
let name = "xevm";
2020
let cppNamespace = "::mlir::xevm";
2121
let dependentDialects = ["LLVM::LLVMDialect"];
22+
23+
let extraClassDeclaration = [{
24+
/// Get the name for the attribute used to specify cache control
25+
/// decorations.
26+
static constexpr ::llvm::StringRef getCacheControlsAttrName() {
27+
return ::llvm::StringLiteral("xevm.DecorationCacheControl");
28+
}
29+
}];
30+
2231
let useDefaultAttributePrinterParser = 1;
2332
}
2433

@@ -97,6 +106,8 @@ def XeVM_BlockLoad2dOp : XeVM_Op<"blockload2d">,
97106
$cache_control - an enumerator that sets the L1 and L3 cache behaviour
98107

99108
Notes:
109+
- pitch is the physical stride between the first columns of the current row and the subsequent row;
110+
this may include (possibly implicit) padding, alignment, or other factors.
100111
- the $transpose and $vnni_transform parameters are mutually exclusive
101112
- transposing the tile loaded is typically used for the B matrix operand
102113
(D = C + A * B), where A has row-major layout in registers and B should have column-major layout.
@@ -148,6 +159,8 @@ def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">,
148159
$stored_val - the tile to store
149160

150161
Notes:
162+
- pitch is the physical stride between the first columns of the current row and the subsequent row;
163+
this may include (possibly implicit) padding, alignment, or other factors.
151164
- coordinate is provided in elements, while width and pitch are provided in bytes.
152165
}];
153166

@@ -161,6 +174,54 @@ def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">,
161174
let hasVerifier = 1;
162175
}
163176

177+
def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
178+
Arguments<(ins
179+
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
180+
I32:$base_width,
181+
I32:$base_height,
182+
I32:$base_pitch,
183+
I32:$x,
184+
I32:$y,
185+
I32Attr:$elem_size_in_bits,
186+
I32Attr:$tile_width,
187+
I32Attr:$tile_height,
188+
I32Attr:$v_blocks,
189+
DefaultValuedAttr<XeVM_L1LoadCacheControl, "::mlir::xevm::L1LoadCacheControl::DEFAULT">:$l1_cache_control,
190+
DefaultValuedAttr<XeVM_L3LoadCacheControl, "::mlir::xevm::L3LoadCacheControl::DEFAULT">:$l3_cache_control
191+
)> {
192+
193+
let summary = "2D block prefetch";
194+
195+
let description = [{
196+
The `xevm.blockprefetch2d` operation prefetches a two dimensional tile
197+
from a larger matrix residing in memory. The parameters are:
198+
$ptr - the base address of the matrix containing the tile to prefetch
199+
$base_width, $base_height, $base_pitch - the shape of the matrix
200+
$x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to prefetch
201+
$elem_size_in_bits - the size in bits of the matrix element
202+
- 32 for f32, tf32
203+
- 16 for f16, int16, bf16
204+
- 8 for int8, int4, int2
205+
$v_blocks - number of tiles to prefetch
206+
$l1_cache_control, $l3_cache_control - enumerators that set the L1 and L3 cache behaviour
207+
208+
Notes:
209+
- pitch is the physical stride between the first columns of the current row and the subsequent row;
210+
this may include (possibly implicit) padding, alignment, or other factors.
211+
- coordinate is provided in elements, while width and pitch are provided in bytes.
212+
}];
213+
214+
let assemblyFormat = [{
215+
operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,`
216+
`tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `l1_cache_control` `=` $l1_cache_control `,`
217+
`l3_cache_control` `=` $l3_cache_control `}`
218+
attr-dict `:` `(` type(operands) `)`
219+
}];
220+
221+
let hasVerifier = 1;
222+
}
223+
224+
164225
def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
165226
let description = [{
166227
GPU target attribute for controlling compilation of targets. All

0 commit comments

Comments
 (0)