diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 31831e5623..55ed46196c 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -276,7 +276,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { getAxisInfo(ub::PoisonOp op, ArrayRef *> operands) override { unsigned rank = 1; - if (auto shape = dyn_cast(op.getType())) + if (auto shape = dyn_cast(op.getType())) rank = shape.getRank(); // Poison values are never accessed, thus assume optimistic values. @@ -1229,6 +1229,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { return rhs; if (rhs.getRank() == 0) return lhs; + assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks"); DimVectorT contiguity; DimVectorT divisibility; DimVectorT constancy; diff --git a/test/TritonGPU/pipeline-assign-latencies.mlir b/test/TritonGPU/pipeline-assign-latencies.mlir index 3a38722e8e..be09ccab26 100644 --- a/test/TritonGPU/pipeline-assign-latencies.mlir +++ b/test/TritonGPU/pipeline-assign-latencies.mlir @@ -1182,4 +1182,21 @@ tt.func @tc_gen5_mma_alloc_block_arg(%lb : index, %ub : index, %step : index, } tt.return } + +// ----- + +// Test that ub.poison producing a memdesc does not get treated like a tensor +// value in AxisInfo analysis. +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @minimal_crash(%lb: i32, %ub: i32) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> { + %c1 = arith.constant 1 : i32 + %poison = ub.poison : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + %normal = ttg.local_alloc : () -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + %result = scf.for %i = %lb to %ub step %c1 iter_args(%current = %poison) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> : i32 { + scf.yield %normal : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + } + tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + } } diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 0e45ee0502..59c684dbf9 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -293,18 +293,19 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { AxisInfo getAxisInfo(ub::PoisonOp op, ArrayRef *> operands) override { - constexpr int64_t largePowerOf2 = int64_t(1) << 32; - // Poison values are never accessed, thus assume optimistic values. - if (auto shape = dyn_cast(op.getType())) { - unsigned rank = shape.getRank(); - return AxisInfo( - /*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2), - /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2), - /*constancy=*/AxisInfo::DimVectorT(shape.getShape())); + unsigned rank = 1; + constexpr int64_t kMaxDivisor = highestPowOf2Divisor(0); + if (auto shape = dyn_cast(op.getType())) + rank = shape.getRank(); + else if (auto ptrTy = dyn_cast(op.getType())) { + if (auto tensorType = dyn_cast(ptrTy.getPointeeType())) + rank = tensorType.getRank(); } - return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2}, - /*constancy=*/{1}); + // Poison values are never accessed, thus assume optimistic values. + return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor), + AxisInfo::DimVectorT(rank, kMaxDivisor), + AxisInfo::DimVectorT(rank, kMaxDivisor)); } }; @@ -1270,7 +1271,6 @@ void AxisInfoAnalysis::visitForOpInductionVar( ModuleAxisInfoAnalysis::ModuleAxisInfoAnalysis(ModuleOp moduleOp) : triton::ModuleAxisInfoAnalysis(moduleOp) { funcMap.clear(); - SmallVector funcs; for (auto root : getRoots()) { walk(