-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][spirv] Add conversion pass to rewrite splat constant composite… #148910
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Add conversion pass to rewrite splat constant composite… #148910
Conversation
…s to replicated form This adds a new SPIR-V dialect-level conversion pass `ConversionToReplicatedConstantCompositePass`. This pass looks for splat composite `spirv.Constant` or `spirv.SpecConstantComposite` and rewrites them into `spirv.EXT.ConstantCompositeReplicate` or `spirv.EXT.SpecConstantCompositeReplicate`, respectively. Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@llvm/pr-subscribers-mlir Author: Mohammadreza Ameri Mahabadian (mahabadm) Changes…s to replicated form This adds a new SPIR-V dialect-level conversion pass Full diff: https://github.com/llvm/llvm-project/pull/148910.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 2d9befe78001d..3c04db8396367 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
"and replacing with supported ones";
}
+def SPIRVReplicatedConstantCompositePass
+ : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> {
+ let summary = "Convert splat composite constants and spec constants to"
+ "corresponding replicated constant composite ops defined by"
+ "SPV_EXT_replicated_composites";
+}
+
#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 68e0206e30a59..c675af9d048cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CanonicalizeGLPass.cpp
+ ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
@@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
add_mlir_dialect_library(MLIRSPIRVTransforms
CanonicalizeGLPass.cpp
+ ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
new file mode 100644
index 0000000000000..530a0f4aa67f5
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
@@ -0,0 +1,135 @@
+//===- ConversionToReplicatedConstantCompositePass.cpp
+//---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a splat composite spirv.Constant and
+// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
+// spirv.EXT.SpecConstantCompositeReplicate respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace spirv {
+#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+} // namespace spirv
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+Attribute getSplatAttribute(Attribute valueAttr, uint32_t splatCount) {
+ Attribute attr;
+ if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
+ if (denseAttr.isSplat()) {
+ attr = denseAttr.getSplatValue<Attribute>();
+ splatCount = denseAttr.size();
+ }
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+ if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
+ std::not_equal_to<>()) == arrayAttr.end()) {
+ attr = arrayAttr[0];
+ splatCount = arrayAttr.size();
+ }
+ }
+
+ if (attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ if (isa<spirv::CompositeType>(typedAttr.getType()))
+ if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+ attr = newAttr;
+ } else if (isa<ArrayAttr>(attr)) {
+ if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+ attr = newAttr;
+ }
+ }
+
+ return attr;
+}
+
+} // namespace
+
+namespace {
+class ConversionToReplicatedConstantCompositePass
+ : public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
+ ConversionToReplicatedConstantCompositePass> {
+public:
+ void runOnOperation() override;
+};
+
+class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
+ using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::ConstantOp op,
+ PatternRewriter &rewriter) const override {
+ auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+ if (!compositeType)
+ return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+ uint32_t splatCount = 0;
+ Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
+ if (!splatAttr)
+ return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+ if (splatCount == 1)
+ return rewriter.notifyMatchFailure(op,
+ "composite has only one consituent");
+
+ rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
+ op, op.getType(), splatAttr);
+
+ return success();
+ }
+};
+
+class SpecConstantCompositeOpConversion
+ : public OpRewritePattern<spirv::SpecConstantCompositeOp> {
+ using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
+ PatternRewriter &rewriter) const override {
+ auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+ if (!compositeType)
+ return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+ auto constituents = op.getConstituents();
+ if (constituents.size() == 1)
+ return rewriter.notifyMatchFailure(op,
+ "composite has only one consituent");
+
+ if (!(std::adjacent_find(constituents.begin(), constituents.end(),
+ std::not_equal_to<>()) == constituents.end()))
+ return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+ auto splatConstituent =
+ dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
+
+ rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
+ op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
+
+ return success();
+ }
+};
+
+void ConversionToReplicatedConstantCompositePass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<ConstantOpConversion>(context);
+ patterns.add<SpecConstantCompositeOpConversion>(context);
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ signalPassFailure();
+ }
+}
+
+} // namespace
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
new file mode 100644
index 0000000000000..f8cd4bb256bfe
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
+ %0 = spirv.Constant dense<2> : vector<3xi32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
+ spirv.ReturnValue %0 : vector<3xi32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+ %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+ spirv.ReturnValue %0 : !spirv.array<3 x i32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+ %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+ %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
+ %0 = spirv.Constant dense<2.0> : vector<3xf32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
+ spirv.ReturnValue %0 : vector<3xf32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+ %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+ spirv.ReturnValue %0 : !spirv.array<3 x f32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+ %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+ %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+ spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+ spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+ spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+ spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32>
+ spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+ spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+ spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+ spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+ spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
+ spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+}
\ No newline at end of file
|
@llvm/pr-subscribers-mlir-spirv Author: Mohammadreza Ameri Mahabadian (mahabadm) Changes…s to replicated form This adds a new SPIR-V dialect-level conversion pass Full diff: https://github.com/llvm/llvm-project/pull/148910.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 2d9befe78001d..3c04db8396367 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
"and replacing with supported ones";
}
+def SPIRVReplicatedConstantCompositePass
+ : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> {
+ let summary = "Convert splat composite constants and spec constants to"
+ "corresponding replicated constant composite ops defined by"
+ "SPV_EXT_replicated_composites";
+}
+
#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 68e0206e30a59..c675af9d048cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CanonicalizeGLPass.cpp
+ ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
@@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
add_mlir_dialect_library(MLIRSPIRVTransforms
CanonicalizeGLPass.cpp
+ ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
new file mode 100644
index 0000000000000..530a0f4aa67f5
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
@@ -0,0 +1,135 @@
+//===- ConversionToReplicatedConstantCompositePass.cpp
+//---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a splat composite spirv.Constant and
+// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
+// spirv.EXT.SpecConstantCompositeReplicate respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace spirv {
+#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+} // namespace spirv
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+Attribute getSplatAttribute(Attribute valueAttr, uint32_t splatCount) {
+ Attribute attr;
+ if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
+ if (denseAttr.isSplat()) {
+ attr = denseAttr.getSplatValue<Attribute>();
+ splatCount = denseAttr.size();
+ }
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+ if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
+ std::not_equal_to<>()) == arrayAttr.end()) {
+ attr = arrayAttr[0];
+ splatCount = arrayAttr.size();
+ }
+ }
+
+ if (attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ if (isa<spirv::CompositeType>(typedAttr.getType()))
+ if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+ attr = newAttr;
+ } else if (isa<ArrayAttr>(attr)) {
+ if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+ attr = newAttr;
+ }
+ }
+
+ return attr;
+}
+
+} // namespace
+
+namespace {
+class ConversionToReplicatedConstantCompositePass
+ : public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
+ ConversionToReplicatedConstantCompositePass> {
+public:
+ void runOnOperation() override;
+};
+
+class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
+ using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::ConstantOp op,
+ PatternRewriter &rewriter) const override {
+ auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+ if (!compositeType)
+ return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+ uint32_t splatCount = 0;
+ Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
+ if (!splatAttr)
+ return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+ if (splatCount == 1)
+ return rewriter.notifyMatchFailure(op,
+ "composite has only one consituent");
+
+ rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
+ op, op.getType(), splatAttr);
+
+ return success();
+ }
+};
+
+class SpecConstantCompositeOpConversion
+ : public OpRewritePattern<spirv::SpecConstantCompositeOp> {
+ using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
+ PatternRewriter &rewriter) const override {
+ auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+ if (!compositeType)
+ return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+ auto constituents = op.getConstituents();
+ if (constituents.size() == 1)
+ return rewriter.notifyMatchFailure(op,
+ "composite has only one consituent");
+
+ if (!(std::adjacent_find(constituents.begin(), constituents.end(),
+ std::not_equal_to<>()) == constituents.end()))
+ return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+ auto splatConstituent =
+ dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
+
+ rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
+ op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
+
+ return success();
+ }
+};
+
+void ConversionToReplicatedConstantCompositePass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<ConstantOpConversion>(context);
+ patterns.add<SpecConstantCompositeOpConversion>(context);
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ signalPassFailure();
+ }
+}
+
+} // namespace
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
new file mode 100644
index 0000000000000..f8cd4bb256bfe
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
+ %0 = spirv.Constant dense<2> : vector<3xi32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
+ spirv.ReturnValue %0 : vector<3xi32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+ %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+ spirv.ReturnValue %0 : !spirv.array<3 x i32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+ %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+ %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+ %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
+ %0 = spirv.Constant dense<2.0> : vector<3xf32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
+ spirv.ReturnValue %0 : vector<3xf32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+ %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+ spirv.ReturnValue %0 : !spirv.array<3 x f32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+ %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+ %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+ // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+ %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+ spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+ spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+ spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+ spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32>
+ spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+ spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+ spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+ spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+ spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>
+
+ // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
+ spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+}
\ No newline at end of file
|
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two more nits, but LGTM in general. Thanks for addressing the feedback.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't looked at the logic yet, left some comments related to code organization first
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's all from me. Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall, just some remaining coding style issues
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
Outdated
Show resolved
Hide resolved
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@kuhar In hindsight, I realized that it is better to move composite type check into |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nit. Thanks for the changes.
mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
…s to replicated form
This adds a new SPIR-V dialect-level conversion pass
ConversionToReplicatedConstantCompositePass
. This pass looks for splat compositespirv.Constant
orspirv.SpecConstantComposite
and rewrites them intospirv.EXT.ConstantCompositeReplicate
orspirv.EXT.SpecConstantCompositeReplicate
, respectively.