Skip to content

[mlir][spirv] Add conversion pass to rewrite splat constant composite… #148910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
"and replacing with supported ones";
}

def SPIRVReplicatedConstantCompositePass
: Pass<"spirv-promote-to-replicated-constants", "spirv::ModuleOp"> {
let summary = "Convert splat composite constants and spec constants to "
"corresponding replicated constant composite ops defined by "
"SPV_EXT_replicated_composites";
}

#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CanonicalizeGLPass.cpp
ConvertToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
Expand Down Expand Up @@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion

add_mlir_dialect_library(MLIRSPIRVTransforms
CanonicalizeGLPass.cpp
ConvertToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert a splat composite spirv.Constant and
// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
// spirv.EXT.SpecConstantCompositeReplicate respectively.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir::spirv {
#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"

namespace {

static Type getArrayElemType(Attribute attr) {
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
return typedAttr.getType();
}

if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size());
}

return nullptr;
}

static std::pair<Attribute, uint32_t>
getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) {
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(valueType);
if (!compositeType)
return {nullptr, 1};

if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
}

if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
if (llvm::all_equal(arrayAttr)) {
Attribute attr = arrayAttr[0];
uint32_t numElements = arrayAttr.size();

// Find the inner-most splat value for array of composites
auto [newAttr, newNumElements] =
getSplatAttrAndNumElements(attr, getArrayElemType(attr));
if (newAttr) {
attr = newAttr;
numElements *= newNumElements;
}
return {attr, numElements};
}
}

return {nullptr, 1};
}

struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::ConstantOp op,
PatternRewriter &rewriter) const override {
auto [attr, numElements] =
getSplatAttrAndNumElements(op.getValue(), op.getType());
if (!attr)
return rewriter.notifyMatchFailure(op, "composite is not splat");

if (numElements == 1)
return rewriter.notifyMatchFailure(op,
"composite has only one constituent");

rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
op, op.getType(), attr);
return success();
}
};

struct SpecConstantCompositeOpConversion final
: OpRewritePattern<spirv::SpecConstantCompositeOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
PatternRewriter &rewriter) const override {
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
if (!compositeType)
return rewriter.notifyMatchFailure(op, "not a composite constant");

ArrayAttr constituents = op.getConstituents();
if (constituents.size() == 1)
return rewriter.notifyMatchFailure(op,
"composite has only one consituent");

if (!llvm::all_equal(constituents))
return rewriter.notifyMatchFailure(op, "composite is not splat");

auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
if (!splatConstituent)
return rewriter.notifyMatchFailure(
op, "expected flat symbol reference for splat constituent");

rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);

return success();
}
};

struct ConvertToReplicatedConstantCompositePass final
: spirv::impl::SPIRVReplicatedConstantCompositePassBase<
ConvertToReplicatedConstantCompositePass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(
context);
walkAndApplyPatterns(getOperation(), std::move(patterns));
}
};

} // namespace
} // namespace mlir::spirv
Loading