Skip to content

Commit 7053bca

Browse files
committed
Fix AxisInfo rank mismatch for poison tensor pointers
Signed-off-by: Witold Dziurdz <[email protected]>
1 parent 5158f2d commit 7053bca

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,19 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
293293
AxisInfo
294294
getAxisInfo(ub::PoisonOp op,
295295
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
296-
constexpr int64_t largePowerOf2 = int64_t(1) << 32;
297-
// Poison values are never accessed, thus assume optimistic values.
298-
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
299-
unsigned rank = shape.getRank();
300-
return AxisInfo(
301-
/*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),
302-
/*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2),
303-
/*constancy=*/AxisInfo::DimVectorT(shape.getShape()));
296+
unsigned rank = 1;
297+
constexpr int64_t kMaxDivisor = highestPowOf2Divisor<int64_t>(0);
298+
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
299+
rank = shape.getRank();
300+
else if (auto ptrTy = dyn_cast<PointerType>(op.getType())) {
301+
if (auto tensorType = dyn_cast<RankedTensorType>(ptrTy.getPointeeType()))
302+
rank = tensorType.getRank();
304303
}
305304

306-
return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2},
307-
/*constancy=*/{1});
305+
// Poison values are never accessed, thus assume optimistic values.
306+
return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor),
307+
AxisInfo::DimVectorT(rank, kMaxDivisor),
308+
AxisInfo::DimVectorT(rank, kMaxDivisor));
308309
}
309310
};
310311

@@ -1270,7 +1271,6 @@ void AxisInfoAnalysis::visitForOpInductionVar(
12701271
ModuleAxisInfoAnalysis::ModuleAxisInfoAnalysis(ModuleOp moduleOp)
12711272
: triton::ModuleAxisInfoAnalysis(moduleOp) {
12721273
funcMap.clear();
1273-
12741274
SmallVector<FunctionOpInterface> funcs;
12751275
for (auto root : getRoots()) {
12761276
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(

0 commit comments

Comments
 (0)