From 5158f2d4588c745224921a66f90d17c7fea3001c Mon Sep 17 00:00:00 2001 From: neildhar Date: Thu, 23 Oct 2025 22:24:04 -0700 Subject: [PATCH 1/2] Fix AxisInfo handling of PoisonOp producing MemDesc (#8489) AxisInfo analysis currently retrieves the rank from any `ShapedType`-producing `PoisonOp`. This is a problem if the `PoisonOp` actually produces a `MemDesc`, since the value produced by the `PoisonOp` may flow into the same value as some other `MemDesc`-producing operation, which will have been assigned the "pessimistic state" and have rank 1. When we attempt to join the two, the ranks will not match, potentially resulting in a crash. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) 
--- lib/Analysis/AxisInfo.cpp | 3 ++- test/TritonGPU/pipeline-assign-latencies.mlir | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 31831e5623..55ed46196c 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -276,7 +276,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { getAxisInfo(ub::PoisonOp op, ArrayRef *> operands) override { unsigned rank = 1; - if (auto shape = dyn_cast(op.getType())) + if (auto shape = dyn_cast(op.getType())) rank = shape.getRank(); // Poison values are never accessed, thus assume optimistic values. @@ -1229,6 +1229,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { return rhs; if (rhs.getRank() == 0) return lhs; + assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks"); DimVectorT contiguity; DimVectorT divisibility; DimVectorT constancy; diff --git a/test/TritonGPU/pipeline-assign-latencies.mlir b/test/TritonGPU/pipeline-assign-latencies.mlir index 3a38722e8e..be09ccab26 100644 --- a/test/TritonGPU/pipeline-assign-latencies.mlir +++ b/test/TritonGPU/pipeline-assign-latencies.mlir @@ -1182,4 +1182,21 @@ tt.func @tc_gen5_mma_alloc_block_arg(%lb : index, %ub : index, %step : index, } tt.return } + +// ----- + +// Test that ub.poison producing a memdesc does not get treated like a tensor +// value in AxisInfo analysis. 
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @minimal_crash(%lb: i32, %ub: i32) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> { + %c1 = arith.constant 1 : i32 + %poison = ub.poison : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + %normal = ttg.local_alloc : () -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + %result = scf.for %i = %lb to %ub step %c1 iter_args(%current = %poison) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> : i32 { + scf.yield %normal : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + } + tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + } } From 7053bca4f6f57feb6884a44455ef7d9cab62dfab Mon Sep 17 00:00:00 2001 From: Witold Dziurdz Date: Sat, 22 Nov 2025 17:06:17 +0000 Subject: [PATCH 2/2] Fix AxisInfo rank mismatch for poison tensor pointers Signed-off-by: Witold Dziurdz --- third_party/intel/lib/Analysis/AxisInfo.cpp | 22 ++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 0e45ee0502..59c684dbf9 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -293,18 +293,19 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { AxisInfo getAxisInfo(ub::PoisonOp op, ArrayRef *> operands) override { - constexpr int64_t largePowerOf2 = int64_t(1) << 32; - // Poison values are never accessed, thus assume optimistic values. 
- if (auto shape = dyn_cast(op.getType())) { - unsigned rank = shape.getRank(); - return AxisInfo( - /*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2), - /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2), - /*constancy=*/AxisInfo::DimVectorT(shape.getShape())); + unsigned rank = 1; + constexpr int64_t kMaxDivisor = highestPowOf2Divisor(0); + if (auto shape = dyn_cast(op.getType())) + rank = shape.getRank(); + else if (auto ptrTy = dyn_cast(op.getType())) { + if (auto tensorType = dyn_cast(ptrTy.getPointeeType())) + rank = tensorType.getRank(); } - return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2}, - /*constancy=*/{1}); + // Poison values are never accessed, thus assume optimistic values. + return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor), + AxisInfo::DimVectorT(rank, kMaxDivisor), + AxisInfo::DimVectorT(rank, kMaxDivisor)); } }; @@ -1270,7 +1271,6 @@ void AxisInfoAnalysis::visitForOpInductionVar( ModuleAxisInfoAnalysis::ModuleAxisInfoAnalysis(ModuleOp moduleOp) : triton::ModuleAxisInfoAnalysis(moduleOp) { funcMap.clear(); - SmallVector funcs; for (auto root : getRoots()) { walk(