[mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. #155517
[mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. #155517charithaintc merged 50 commits intollvm:mainfrom
Conversation
…r_and_SliceAttr' into vector_bitcast_distr
| // communication. So each lane must own the required number of elements to | ||
| // perform the bitcast locally without cross-lane communication. | ||
| int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth; | ||
| if (outInnerBitsPerLane < inElemTyBitWidth) { |
There was a problem hiding this comment.
check the condition
srcInnerBitsPerLane = inElemTypeBitWidth x sourceLayout.getLaneData
if (outInnerBitsPerLane != srcInnerBitsPerLane)
There was a problem hiding this comment.
I thought about this again. sourceLayout.getLaneData is not available to us because we are trying to decide this here. I think we can only detect narrowing case only.
Widening case will always be valid because at this point if result already have a valid layout. Otherwise it means that result was not assigned a correct layout. That must be concern of the layout conflict maybe.
In any case, I added a check to verify if the result layout is valid and can be distributed to lanes.
There was a problem hiding this comment.
I see. I would move the check after the sourceLaneData is assigned. See comments below also.
| shapeCast.emitWarning("Expecting result type to be 1D or 2D vector."); | ||
| return; | ||
| } | ||
| // For 2D -> 2D shape cast, propagate the result layout to the source. |
There was a problem hiding this comment.
consider the restriction for now:
- same rank shape cast not allowed,
- always expand the dim not squeeze the dim,
- The new dims must be 1, and the original dims must not change
There was a problem hiding this comment.
fixed I also added this condition for now.
- Result layout can not be a slice layout and it must have same rank as result.
adam-smnk
left a comment
There was a problem hiding this comment.
Usually smaller PRs make reviews go faster but I'll bite 😉
Overall logic looks good, only minor comments.
| for (int64_t idx : permutation) { | ||
| newLayout.layout.push_back(laneLayout.layout[idx]); | ||
| newData.layout.push_back(laneData.layout[idx]); | ||
| laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); |
There was a problem hiding this comment.
how about add one more utilit to layout attribute, like getTransposedLayout(), so that it can be reused by sg_layout, or lane_layout.
Potentially, the isTransposeOf can be simplified to doing a transpose of input and compare whether they are same?
There was a problem hiding this comment.
agree. I will add this in a separate PR and clean up.
| func.func @vector_shape_cast_2d_to_1d_dim0_distributed(%arg0: !xegpu.tensor_desc<16x1xf16>, %arg1: !xegpu.tensor_desc<16xf16>) { | ||
| %c0 = arith.constant 0 : index | ||
| %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x1xf16> -> vector<16x1xf16> | ||
| %2 = vector.shape_cast %3 : vector<16x1xf16> to vector<16xf16> |
There was a problem hiding this comment.
It seems contradict with the documentation
2) Shape cast must always expand the rank (e.g. 1D -> 2D).
Not sure why the code is passing. Maybe I missed something?
There was a problem hiding this comment.
sorry. I forgot to remove this test (CI was failing because of it). I removed this tests now.
There was a problem hiding this comment.
Shape cast must always expand the rank (e.g. 1D -> 2D).
If you refer to vector.shape_cast, a cast must preserve the same number of elements. Shape's rank can be freely changed up or down.
The two cases looked valid, it'd be good to understand why they failed.
If they can't be distributed, I'd leave them in as negative examples.
There was a problem hiding this comment.
@adam-smnk The restriction is there because we do not expect (for now) any narrowing shape casts. Shape cast is currently used to make the vector 2D after a 2D -> 1D reduction.
Adding back the tests as negative examples for now.
There was a problem hiding this comment.
my bad. pass is designed to fail if we can not assign a proper layout to ops. So I can not add the negative example in the same file AFAIK.
There was a problem hiding this comment.
Hmm, then it's sth rethink if it impacts testing.
A separate test file would be fine as this one's already pretty large. Not sure if verify-diagnostics can also test pass failures. TBD
There was a problem hiding this comment.
Not sure if verify-diagnostics can also test pass failures.
I think it can. challenge is doing it in same file. I did not find any examples. But I will give a try.
| return; | ||
| } | ||
| // Decide lane data based on whether the bitcast is narrowing or widening. | ||
| int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio |
There was a problem hiding this comment.
For narrowing bitcast, innerMostLaneData = outData[rank - 1] * bitCastRatio, instead of / bitCastRatio?
Put a TODO here?: check the layout conflict case here if ( innerMostLaneData * inElemTyBitWidth != outInnerBitsPerLane ).
There was a problem hiding this comment.
For narrowing bitcast, innerMostLaneData = outData[rank - 1] * bitCastRatio, instead of / bitCastRatio?
This is because in narrowing case source had higher bitwidth (e.g f32 -> f16)
Put a TODO here?: check the layout conflict case here if ( innerMostLaneData * inElemTyBitWidth != outInnerBitsPerLane ).
This is not required. At this point of layout propagation result layout is already a valid layout. We chose innerMostLaneData such that innerMostLaneData * inElemTyBitWidth == outInnerBitsPerLane.
| // communication. So each lane must own the required number of elements to | ||
| // perform the bitcast locally without cross-lane communication. | ||
| int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth; | ||
| if (outInnerBitsPerLane < inElemTyBitWidth) { |
There was a problem hiding this comment.
I see. I would move the check after the sourceLaneData is assigned. See comments below also.
|
@adam-smnk Can you take another look and/or approve? :-) |
This PR adds the features needed for supporting the GEMM with transpose B case.
Summary of changes.
1). Add distribution logic for
vector.bitcast,vector.transposeandmemref.extract_aligned_pointer_as_indexcases.2). Add layout propagation support for
vector.shape_cast,vector.broadcastandvector.bitcast3). Incorporate slice attribute and
DistributeLayoutAttrinterface with the core logic in layout prop.