Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
getAxisInfo(ub::PoisonOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
unsigned rank = 1;
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType()))
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
rank = shape.getRank();

// Poison values are never accessed, thus assume optimistic values.
Expand Down Expand Up @@ -1229,6 +1229,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
return rhs;
if (rhs.getRank() == 0)
return lhs;
assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks");
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;
Expand Down
17 changes: 17 additions & 0 deletions test/TritonGPU/pipeline-assign-latencies.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1182,4 +1182,21 @@ tt.func @tc_gen5_mma_alloc_block_arg(%lb : index, %ub : index, %step : index,
}
tt.return
}

// -----

// Test that ub.poison producing a memdesc does not get treated like a tensor
// value in AxisInfo analysis.
// Layout aliases for the shared-memory descriptor type used throughout the test.
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-warps" = 4 : i32} {
// Regression test: a ub.poison of !ttg.memdesc (not a ranked tensor) flows into
// a loop iter_arg, so AxisInfo analysis must not query tensor rank/shape on it.
tt.func public @minimal_crash(%lb: i32, %ub: i32) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> {
%c1 = arith.constant 1 : i32
// Poison memdesc seeds the loop-carried value; it is never actually read.
%poison = ub.poison : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
%normal = ttg.local_alloc : () -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
// First iteration carries %poison; every yield replaces it with %normal.
%result = scf.for %i = %lb to %ub step %c1 iter_args(%current = %poison) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> : i32 {
scf.yield %normal : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
}
tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
}
}
22 changes: 11 additions & 11 deletions third_party/intel/lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,19 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
AxisInfo
getAxisInfo(ub::PoisonOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
constexpr int64_t largePowerOf2 = int64_t(1) << 32;
// Poison values are never accessed, thus assume optimistic values.
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
unsigned rank = shape.getRank();
return AxisInfo(
/*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),
/*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2),
/*constancy=*/AxisInfo::DimVectorT(shape.getShape()));
unsigned rank = 1;
constexpr int64_t kMaxDivisor = highestPowOf2Divisor<int64_t>(0);
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
rank = shape.getRank();
else if (auto ptrTy = dyn_cast<PointerType>(op.getType())) {
if (auto tensorType = dyn_cast<RankedTensorType>(ptrTy.getPointeeType()))
rank = tensorType.getRank();
}

return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2},
/*constancy=*/{1});
// Poison values are never accessed, thus assume optimistic values.
return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor),
AxisInfo::DimVectorT(rank, kMaxDivisor),
AxisInfo::DimVectorT(rank, kMaxDivisor));
}
};

Expand Down Expand Up @@ -1270,7 +1271,6 @@ void AxisInfoAnalysis::visitForOpInductionVar(
ModuleAxisInfoAnalysis::ModuleAxisInfoAnalysis(ModuleOp moduleOp)
: triton::ModuleAxisInfoAnalysis(moduleOp) {
funcMap.clear();

SmallVector<FunctionOpInterface> funcs;
for (auto root : getRoots()) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
Expand Down
Loading