@@ -293,18 +293,19 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
293293 AxisInfo
294294 getAxisInfo (ub::PoisonOp op,
295295 ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
296- constexpr int64_t largePowerOf2 = int64_t (1 ) << 32 ;
297- // Poison values are never accessed, thus assume optimistic values.
298- if (auto shape = dyn_cast<mlir::ShapedType>(op.getType ())) {
299- unsigned rank = shape.getRank ();
300- return AxisInfo (
301- /* contiguity=*/ AxisInfo::DimVectorT (rank, largePowerOf2),
302- /* divisibility=*/ AxisInfo::DimVectorT (rank, largePowerOf2),
303- /* constancy=*/ AxisInfo::DimVectorT (shape.getShape ()));
296+ unsigned rank = 1 ;
297+ constexpr int64_t kMaxDivisor = highestPowOf2Divisor<int64_t >(0 );
298+ if (auto shape = dyn_cast<RankedTensorType>(op.getType ()))
299+ rank = shape.getRank ();
300+ else if (auto ptrTy = dyn_cast<PointerType>(op.getType ())) {
301+ if (auto tensorType = dyn_cast<RankedTensorType>(ptrTy.getPointeeType ()))
302+ rank = tensorType.getRank ();
304303 }
305304
306- return AxisInfo (/* contiguity=*/ {1 }, /* divisibility=*/ {largePowerOf2},
307- /* constancy=*/ {1 });
305+ // Poison values are never accessed, thus assume optimistic values.
306+ return AxisInfo (AxisInfo::DimVectorT (rank, kMaxDivisor ),
307+ AxisInfo::DimVectorT (rank, kMaxDivisor ),
308+ AxisInfo::DimVectorT (rank, kMaxDivisor ));
308309 }
309310};
310311
@@ -1270,7 +1271,6 @@ void AxisInfoAnalysis::visitForOpInductionVar(
12701271ModuleAxisInfoAnalysis::ModuleAxisInfoAnalysis (ModuleOp moduleOp)
12711272 : triton::ModuleAxisInfoAnalysis(moduleOp) {
12721273 funcMap.clear ();
1273-
12741274 SmallVector<FunctionOpInterface> funcs;
12751275 for (auto root : getRoots ()) {
12761276 walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
0 commit comments