@@ -2239,8 +2239,8 @@ struct LoadOpConversion
22392239 auto opType = op.getType ();
22402240 // TODO: Override the OpType since conversion is still happening during Load
22412241 // lowering. Once we materialize ConvertLayoutOp this can be removed.
2242- if ( auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243- hasSubgroup2DBlockEncoding (tensorTy))
2242+ auto tensorTy = dyn_cast<RankedTensorType>(opType);
2243+ if (tensorTy && hasSubgroup2DBlockEncoding (tensorTy))
22442244 opType = getDpasTypeFromCVTOp (op.getResult ());
22452245
22462246 // Determine the vectorization size
@@ -2256,9 +2256,11 @@ struct LoadOpConversion
22562256
22572257 if (isTensorPointerType (ptr.getType ())) {
22582258 // fallback to gather load.
2259- auto tensorType = cast<RankedTensorType>(opType);
2259+ // Make sure we use the modified opType from above, "seeing through" any
2260+ // CVT that follows the subgroup 2D block encoding.
2261+ auto blockPtrTensorType = cast<RankedTensorType>(opType);
22602262 std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
2261- loc, adaptor.getPtr (), tensorType , valueElemTy, rewriter,
2263+ loc, adaptor.getPtr (), blockPtrTensorType , valueElemTy, rewriter,
22622264 op.getBoundaryCheck (), op.getPadding ());
22632265 } else {
22642266 Value other = op.getOther ();
0 commit comments