From 936435e95e1cab2e888eb8483bd3aeaf3f91857b Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Wed, 2 Jul 2025 09:00:45 +0100 Subject: [PATCH 01/10] [mlir][spirv] Add conversion pass to rewrite splat constant composites to replicated form This adds a new SPIR-V dialect-level conversion pass `ConversionToReplicatedConstantCompositePass`. This pass looks for splat composite `spirv.Constant` or `spirv.SpecConstantComposite` and rewrites them into `spirv.EXT.ConstantCompositeReplicate` or `spirv.EXT.SpecConstantCompositeReplicate`, respectively. Signed-off-by: Mohammadreza Ameri Mahabadian --- .../mlir/Dialect/SPIRV/Transforms/Passes.td | 7 + .../Dialect/SPIRV/Transforms/CMakeLists.txt | 2 + ...rsionToReplicatedConstantCompositePass.cpp | 135 ++++++++++++ .../replicated-const-composites.mlir | 192 ++++++++++++++++++ 4 files changed, 336 insertions(+) create mode 100644 mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp create mode 100644 mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td index 2d9befe78001d..3c04db8396367 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td @@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> { "and replacing with supported ones"; } +def SPIRVReplicatedConstantCompositePass + : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> { + let summary = "Convert splat composite constants and spec constants to" + "corresponding replicated constant composite ops defined by" + "SPV_EXT_replicated_composites"; +} + #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt index 68e0206e30a59..c675af9d048cc 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ set(LLVM_OPTIONAL_SOURCES CanonicalizeGLPass.cpp + ConversionToReplicatedConstantCompositePass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp @@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion add_mlir_dialect_library(MLIRSPIRVTransforms CanonicalizeGLPass.cpp + ConversionToReplicatedConstantCompositePass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp new file mode 100644 index 0000000000000..530a0f4aa67f5 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp @@ -0,0 +1,135 @@ +//===- ConversionToReplicatedConstantCompositePass.cpp +//---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert a splat composite spirv.Constant and +// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and +// spirv.EXT.SpecConstantCompositeReplicate respectively. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace spirv { +#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS +#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" +} // namespace spirv +} // namespace mlir + +using namespace mlir; + +namespace { + +Attribute getSplatAttribute(Attribute valueAttr, uint32_t splatCount) { + Attribute attr; + if (auto denseAttr = dyn_cast(valueAttr)) { + if (denseAttr.isSplat()) { + attr = denseAttr.getSplatValue(); + splatCount = denseAttr.size(); + } + } else if (auto arrayAttr = dyn_cast(valueAttr)) { + if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(), + std::not_equal_to<>()) == arrayAttr.end()) { + attr = arrayAttr[0]; + splatCount = arrayAttr.size(); + } + } + + if (attr) { + if (auto typedAttr = dyn_cast(attr)) { + if (isa(typedAttr.getType())) + if (Attribute newAttr = getSplatAttribute(attr, splatCount)) + attr = newAttr; + } else if (isa(attr)) { + if (Attribute newAttr = getSplatAttribute(attr, splatCount)) + attr = newAttr; + } + } + + return attr; +} + +} // namespace + +namespace { +class ConversionToReplicatedConstantCompositePass + : public spirv::impl::SPIRVReplicatedConstantCompositePassBase< + ConversionToReplicatedConstantCompositePass> { +public: + void runOnOperation() override; +}; + +class ConstantOpConversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spirv::ConstantOp op, + PatternRewriter &rewriter) const override { + auto compositeType = dyn_cast_or_null(op.getType()); + if (!compositeType) + return rewriter.notifyMatchFailure(op, "not a composite constant"); + + uint32_t splatCount = 0; + Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount); + if (!splatAttr) + return rewriter.notifyMatchFailure(op, "composite is not splat"); + + if (splatCount == 1) + return rewriter.notifyMatchFailure(op, + "composite has only one consituent"); + + rewriter.replaceOpWithNewOp( + op, op.getType(), splatAttr); + + return success(); + } +}; + +class SpecConstantCompositeOpConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op, + PatternRewriter &rewriter) const override { + auto compositeType = dyn_cast_or_null(op.getType()); + if (!compositeType) + return rewriter.notifyMatchFailure(op, "not a composite constant"); + + auto constituents = op.getConstituents(); + if (constituents.size() == 1) + return rewriter.notifyMatchFailure(op, + "composite has only one consituent"); + + if (!(std::adjacent_find(constituents.begin(), constituents.end(), + std::not_equal_to<>()) == constituents.end())) + return rewriter.notifyMatchFailure(op, "composite is not splat"); + + auto splatConstituent = + dyn_cast(op.getConstituents()[0]); + + rewriter.replaceOpWithNewOp( + op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent); + + return success(); + } +}; + +void ConversionToReplicatedConstantCompositePass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir new file mode 100644 index 0000000000000..f8cd4bb256bfe --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir @@ -0,0 +1,192 @@ +// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s + +spirv.module Logical GLSL450 { + spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" { + %0 = spirv.Constant dense<2> : vector<3xi32> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32> + spirv.ReturnValue %0 : vector<3xi32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" { + %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32> + spirv.ReturnValue %0 : !spirv.array<3 x i32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> + %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> + %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" { + %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" { + %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> + %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" { + %0 = spirv.Constant dense<2.0> : vector<3xf32> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32> + spirv.ReturnValue %0 : vector<3xf32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" { + %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32> + spirv.ReturnValue %0 : !spirv.array<3 x f32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> + %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> + %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" { + %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" { + %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> + %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + + spirv.SpecConstant @sc_i32_1 = 1 : i32 + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32> + spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)> + spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32> + spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32> + spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32> + + spirv.SpecConstant @sc_f32_1 = 1.0 : f32 + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32> + spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)> + spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32> + spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32> + + // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32> + spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32> +} \ No newline at end of file From 14b70ce5742cacc84b8d2833e62eface967c07c8 Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Wed, 16 Jul 2025 08:20:03 +0100 Subject: [PATCH 02/10] Minor bug fix and revision Signed-off-by: Mohammadreza Ameri Mahabadian --- .../Transforms/ConversionToReplicatedConstantCompositePass.cpp | 2 +- .../Dialect/SPIRV/Transforms/replicated-const-composites.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp index 530a0f4aa67f5..590fa6e9d684a 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp @@ -28,7 +28,7 @@ using namespace mlir; namespace { -Attribute getSplatAttribute(Attribute valueAttr, uint32_t splatCount) { +Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) { Attribute attr; if (auto denseAttr = dyn_cast(valueAttr)) { if (denseAttr.isSplat()) { diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir index f8cd4bb256bfe..a7a9ca25edc6a 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir @@ -160,7 +160,7 @@ spirv.module Logical GLSL450 { // ----- -spirv.module Logical GLSL450 requires #spirv.vce { +spirv.module Logical GLSL450 { spirv.SpecConstant @sc_i32_1 = 1 : i32 From 51ce7ac82a887a683010b54c31c95ec7d5e9bd54 Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Wed, 16 Jul 2025 15:31:21 +0100 Subject: [PATCH 03/10] Addressing code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- .../mlir/Dialect/SPIRV/Transforms/Passes.td | 2 +- ...rsionToReplicatedConstantCompositePass.cpp | 10 +- .../replicated-const-composites.mlir | 145 ++++++++++++++++-- 3 files changed, 135 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td index 3c04db8396367..bc1c1f075e09b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td @@ -78,7 +78,7 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> { } def SPIRVReplicatedConstantCompositePass - : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> { + : Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> { let summary = "Convert splat composite constants and spec constants to" "corresponding replicated constant composite ops defined by" "SPV_EXT_replicated_composites"; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp index 590fa6e9d684a..da1777ab42f97 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp @@ -1,5 +1,4 @@ -//===- ConversionToReplicatedConstantCompositePass.cpp -//---------------------------===// +//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -83,7 +82,7 @@ class ConstantOpConversion : public OpRewritePattern { if (splatCount == 1) return rewriter.notifyMatchFailure(op, - "composite has only one consituent"); + "composite has only one constituent"); rewriter.replaceOpWithNewOp( op, op.getType(), splatAttr); @@ -102,7 +101,7 @@ class SpecConstantCompositeOpConversion if (!compositeType) return rewriter.notifyMatchFailure(op, "not a composite constant"); - auto constituents = op.getConstituents(); + ArrayAttr constituents = op.getConstituents(); if (constituents.size() == 1) return rewriter.notifyMatchFailure(op, "composite has only one consituent"); @@ -113,6 +112,9 @@ class SpecConstantCompositeOpConversion auto splatConstituent = dyn_cast(op.getConstituents()[0]); + if (!splatConstituent) + return rewriter.notifyMatchFailure( + op, "expected flat symbol reference for splat constituent"); rewriter.replaceOpWithNewOp( op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent); diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir index a7a9ca25edc6a..4431f417635b8 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir @@ -1,9 +1,9 @@ -// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s -o - | FileCheck %s spirv.module Logical GLSL450 { spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" { - %0 = spirv.Constant dense<2> : vector<3xi32> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32> + %0 = spirv.Constant dense<2> : vector<3xi32> spirv.ReturnValue %0 : vector<3xi32> } } @@ -12,8 +12,8 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" { - %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32> + %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32> spirv.ReturnValue %0 : !spirv.array<3 x i32> } } @@ -22,7 +22,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> } @@ -32,7 +32,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> } @@ -42,8 +42,8 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" { - %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> } } @@ -52,8 +52,8 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" { - %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>> + %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> } } @@ -62,7 +62,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> } @@ -72,7 +72,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32> %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> } @@ -82,8 +82,8 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" { - %0 = spirv.Constant dense<2.0> : vector<3xf32> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32> + %0 = spirv.Constant dense<2.0> : vector<3xf32> spirv.ReturnValue %0 : vector<3xf32> } } @@ -92,8 +92,8 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" { - %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32> + %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32> spirv.ReturnValue %0 : !spirv.array<3 x f32> } } @@ -102,7 +102,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> } @@ -112,7 +112,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> } @@ -122,8 +122,8 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" { - %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> } } @@ -132,8 +132,8 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" { - %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>> + %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> } } @@ -142,7 +142,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> } @@ -152,7 +152,7 @@ spirv.module Logical GLSL450 { spirv.module Logical GLSL450 { spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { - // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32> + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32> %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32> spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> } @@ -160,6 +160,86 @@ spirv.module Logical GLSL450 { // ----- +spirv.module Logical GLSL450 { + spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32> + spirv.ReturnValue %0 : !spirv.array<1 x i32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32> + spirv.ReturnValue %0 : vector<3xi32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32> + spirv.ReturnValue %0 : !spirv.array<1 x f32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32> + spirv.ReturnValue %0 : vector<3xf32> + } +} + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" { + // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> + spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> + } +} + +// ----- + spirv.module Logical GLSL450 { spirv.SpecConstant @sc_i32_1 = 1 : i32 @@ -189,4 +269,35 @@ spirv.module Logical GLSL450 { // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32> spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32> + + spirv.SpecConstant @sc_i32_2 = 2 : i32 + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_array_of_one_i32 (@sc_i32_1) : !spirv.array<1 x i32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_arm_tensor_of_one_i32 (@sc_i32_1) : !spirv.arm.tensor<1xi32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_2) : vector<3 x i32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_i32 (@sc_i32_2, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32> + + spirv.SpecConstant @sc_f32_2 = 2.0 : f32 + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_array_of_one_f32 (@sc_f32_1) : !spirv.array<1 x f32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_arm_tensor_of_one_f32 (@sc_f32_1) : !spirv.arm.tensor<1xf32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_2) : vector<3 x f32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_f32 (@sc_f32_2, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32> + + // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate + spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)> } \ No newline at end of file From d065a5adb971cf7eca74d92d155647971f79c47f Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Thu, 17 Jul 2025 09:44:18 +0100 Subject: [PATCH 04/10] Addressing further code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- .../mlir/Dialect/SPIRV/Transforms/Passes.td | 4 +- .../Dialect/SPIRV/Transforms/CMakeLists.txt | 4 +- ...vertToReplicatedConstantCompositePass.cpp} | 86 +++++++++---------- .../replicated-const-composites.mlir | 86 +------------------ 4 files changed, 46 insertions(+), 134 deletions(-) rename mlir/lib/Dialect/SPIRV/Transforms/{ConversionToReplicatedConstantCompositePass.cpp => ConvertToReplicatedConstantCompositePass.cpp} (65%) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td index bc1c1f075e09b..a4418085b5ce5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td @@ -79,8 +79,8 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> { def SPIRVReplicatedConstantCompositePass : Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> { - let summary = "Convert splat composite constants and spec constants to" - "corresponding replicated constant composite ops defined by" + let summary = "Convert splat composite constants and spec constants to " + "corresponding replicated constant composite ops defined by " "SPV_EXT_replicated_composites"; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt index c675af9d048cc..b947447dad46a 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ set(LLVM_OPTIONAL_SOURCES CanonicalizeGLPass.cpp - ConversionToReplicatedConstantCompositePass.cpp + ConvertToReplicatedConstantCompositePass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp @@ -31,7 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion add_mlir_dialect_library(MLIRSPIRVTransforms CanonicalizeGLPass.cpp - ConversionToReplicatedConstantCompositePass.cpp + ConvertToReplicatedConstantCompositePass.cpp DecorateCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp RewriteInsertsPass.cpp diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp similarity index 65% rename from mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp rename to mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp index da1777ab42f97..acd66002746aa 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -1,4 +1,4 @@ -//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===// +//===- ConvertToReplicatedConstantCompositePass.cpp --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -14,21 +14,18 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" -namespace mlir { -namespace spirv { +namespace mlir::spirv { #define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" -} // namespace spirv -} // namespace mlir - -using namespace mlir; namespace { -Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) { +static std::pair +getSplatAttributeAndCount(Attribute valueAttr) { Attribute attr; + uint32_t splatCount = 0; if (auto denseAttr = dyn_cast(valueAttr)) { if (denseAttr.isSplat()) { attr = denseAttr.getSplatValue(); @@ -44,30 +41,27 @@ Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) { if (attr) { if (auto typedAttr = dyn_cast(attr)) { - if (isa(typedAttr.getType())) - if (Attribute newAttr = getSplatAttribute(attr, splatCount)) - attr = newAttr; + if (isa(typedAttr.getType())) { + std::pair newSplatAttrAndCount = + getSplatAttributeAndCount(attr); + if (newSplatAttrAndCount.first) { + return newSplatAttrAndCount; + } + } } else if (isa(attr)) { - if (Attribute newAttr = getSplatAttribute(attr, splatCount)) - attr = newAttr; + std::pair newSplatAttrAndCount = + getSplatAttributeAndCount(attr); + if (newSplatAttrAndCount.first) { + return newSplatAttrAndCount; + } } } - return attr; + return {attr, splatCount}; } -} // namespace - -namespace { -class ConversionToReplicatedConstantCompositePass - : public spirv::impl::SPIRVReplicatedConstantCompositePassBase< - ConversionToReplicatedConstantCompositePass> { -public: - void runOnOperation() override; -}; - -class ConstantOpConversion : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConstantOpConversion final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::ConstantOp op, PatternRewriter &rewriter) const override { @@ -75,25 +69,25 @@ class ConstantOpConversion : public OpRewritePattern { if (!compositeType) return rewriter.notifyMatchFailure(op, "not a composite constant"); - uint32_t splatCount = 0; - Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount); - if (!splatAttr) + std::pair splatAttrAndCount = + getSplatAttributeAndCount(op.getValue()); + if (!splatAttrAndCount.first) return rewriter.notifyMatchFailure(op, "composite is not splat"); - if (splatCount == 1) + if (splatAttrAndCount.second == 1) return rewriter.notifyMatchFailure(op, "composite has only one constituent"); rewriter.replaceOpWithNewOp( - op, op.getType(), splatAttr); + op, op.getType(), splatAttrAndCount.first); return success(); } }; -class SpecConstantCompositeOpConversion - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct SpecConstantCompositeOpConversion final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op, PatternRewriter &rewriter) const override { @@ -123,15 +117,17 @@ class SpecConstantCompositeOpConversion } }; -void ConversionToReplicatedConstantCompositePass::runOnOperation() { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - patterns.add(context); - patterns.add(context); - - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { - signalPassFailure(); +struct ConvertToReplicatedConstantCompositePass final + : spirv::impl::SPIRVReplicatedConstantCompositePassBase< + ConvertToReplicatedConstantCompositePass> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add( + context); + walkAndApplyPatterns(getOperation(), std::move(patterns)); } -} +}; -} // namespace \ No newline at end of file +} // namespace +} // namespace mlir::spirv diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir index 4431f417635b8..b343d0bc73b4f 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s | FileCheck %s spirv.module Logical GLSL450 { spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" { @@ -6,211 +6,127 @@ spirv.module Logical GLSL450 { %0 = spirv.Constant dense<2> : vector<3xi32> spirv.ReturnValue %0 : vector<3xi32> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32> %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32> spirv.ReturnValue %0 : !spirv.array<3 x i32> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>> %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>> %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32> %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32> %0 = spirv.Constant dense<2.0> : vector<3xf32> spirv.ReturnValue %0 : vector<3xf32> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32> %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32> spirv.ReturnValue %0 : !spirv.array<3 x f32> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>> %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>> %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32> %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32> spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32> spirv.ReturnValue %0 : !spirv.array<1 x i32> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32> spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32> spirv.ReturnValue %0 : vector<3xi32> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32> spirv.ReturnValue %0 : !spirv.array<1 x f32> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32> From b718e34d6a35f83c534121075d1459a5af5bd521 Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Thu, 17 Jul 2025 11:22:13 +0100 Subject: [PATCH 05/10] Addressing further code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- ...nvertToReplicatedConstantCompositePass.cpp | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp index acd66002746aa..1856f8017ab33 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -1,4 +1,4 @@ -//===- ConvertToReplicatedConstantCompositePass.cpp --------------------===// +//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -40,15 +40,9 @@ getSplatAttributeAndCount(Attribute valueAttr) { } if (attr) { - if (auto typedAttr = dyn_cast(attr)) { - if (isa(typedAttr.getType())) { - std::pair newSplatAttrAndCount = - getSplatAttributeAndCount(attr); - if (newSplatAttrAndCount.first) { - return newSplatAttrAndCount; - } - } - } else if (isa(attr)) { + auto typedAttr = dyn_cast(attr); + if ((typedAttr && isa(typedAttr.getType())) || + isa(attr)) { std::pair newSplatAttrAndCount = getSplatAttributeAndCount(attr); if (newSplatAttrAndCount.first) { @@ -69,17 +63,16 @@ struct ConstantOpConversion final : OpRewritePattern { if (!compositeType) return rewriter.notifyMatchFailure(op, "not a composite constant"); - std::pair splatAttrAndCount = - getSplatAttributeAndCount(op.getValue()); - if (!splatAttrAndCount.first) + auto [splattAttr, splatCount] = getSplatAttributeAndCount(op.getValue()); + if (!splattAttr) return rewriter.notifyMatchFailure(op, "composite is not splat"); - if (splatAttrAndCount.second == 1) + if (splatCount == 1) return rewriter.notifyMatchFailure(op, "composite has only one constituent"); rewriter.replaceOpWithNewOp( - op, op.getType(), splatAttrAndCount.first); + op, op.getType(), splattAttr); return success(); } @@ -104,8 +97,7 @@ struct SpecConstantCompositeOpConversion final std::not_equal_to<>()) == constituents.end())) return rewriter.notifyMatchFailure(op, "composite is not splat"); - auto splatConstituent = - dyn_cast(op.getConstituents()[0]); + auto splatConstituent = dyn_cast(constituents[0]); if (!splatConstituent) return rewriter.notifyMatchFailure( op, "expected flat symbol reference for splat constituent"); From 8234e682f0562f82e63ab02d00db31fc82117d8e Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Thu, 17 Jul 2025 11:30:06 +0100 Subject: [PATCH 06/10] Minor typo fix Signed-off-by: Mohammadreza Ameri Mahabadian --- .../Transforms/ConvertToReplicatedConstantCompositePass.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp index 1856f8017ab33..798a405b4df69 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -63,8 +63,8 @@ struct ConstantOpConversion final : OpRewritePattern { if (!compositeType) return rewriter.notifyMatchFailure(op, "not a composite constant"); - auto [splattAttr, splatCount] = getSplatAttributeAndCount(op.getValue()); - if (!splattAttr) + auto [splatAttr, splatCount] = getSplatAttributeAndCount(op.getValue()); + if (!splatAttr) return rewriter.notifyMatchFailure(op, "composite is not splat"); if (splatCount == 1) @@ -72,7 +72,7 @@ struct ConstantOpConversion final : OpRewritePattern { "composite has only one constituent"); rewriter.replaceOpWithNewOp( - op, op.getType(), splattAttr); + op, op.getType(), splatAttr); return success(); } From 568bccb8140222caa2e701f708fe0c613e805ef8 Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Thu, 17 Jul 2025 16:37:22 +0100 Subject: [PATCH 07/10] Addressing further code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- .../mlir/Dialect/SPIRV/Transforms/Passes.td | 2 +- ...nvertToReplicatedConstantCompositePass.cpp | 30 +++++++------------ .../replicated-const-composites.mlir | 16 +++------- 3 files changed, 16 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td index a4418085b5ce5..2016bea43fc8a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td @@ -78,7 +78,7 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> { } def SPIRVReplicatedConstantCompositePass - : Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> { + : Pass<"spirv-promote-to-replicated-constants", "spirv::ModuleOp"> { let summary = "Convert splat composite constants and spec constants to " "corresponding replicated constant composite ops defined by " "SPV_EXT_replicated_composites"; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp index 798a405b4df69..c4df57072a55d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -23,28 +23,22 @@ namespace mlir::spirv { namespace { static std::pair -getSplatAttributeAndCount(Attribute valueAttr) { +getSplatAttrAndNumElements(Attribute valueAttr) { Attribute attr; uint32_t splatCount = 0; - if (auto denseAttr = dyn_cast(valueAttr)) { - if (denseAttr.isSplat()) { - attr = denseAttr.getSplatValue(); - splatCount = denseAttr.size(); - } - } else if (auto arrayAttr = dyn_cast(valueAttr)) { - if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(), - std::not_equal_to<>()) == arrayAttr.end()) { + if (auto splatAttr = dyn_cast(valueAttr)) { + return {splatAttr.getSplatValue(), splatAttr.size()}; + } + if (auto arrayAttr = dyn_cast(valueAttr)) { + if (llvm::all_equal(arrayAttr)) { attr = arrayAttr[0]; splatCount = arrayAttr.size(); } - } - if (attr) { - auto typedAttr = dyn_cast(attr); - if ((typedAttr && isa(typedAttr.getType())) || - isa(attr)) { + if (attr) { + // Find the inner-most splat value for array of composites std::pair newSplatAttrAndCount = - getSplatAttributeAndCount(attr); + getSplatAttrAndNumElements(attr); if (newSplatAttrAndCount.first) { return newSplatAttrAndCount; } @@ -63,7 +57,7 @@ struct ConstantOpConversion final : OpRewritePattern { if (!compositeType) return rewriter.notifyMatchFailure(op, "not a composite constant"); - auto [splatAttr, splatCount] = getSplatAttributeAndCount(op.getValue()); + auto [splatAttr, splatCount] = getSplatAttrAndNumElements(op.getValue()); if (!splatAttr) return rewriter.notifyMatchFailure(op, "composite is not splat"); @@ -73,7 +67,6 @@ struct ConstantOpConversion final : OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getType(), splatAttr); - return success(); } }; @@ -93,8 +86,7 @@ struct SpecConstantCompositeOpConversion final return rewriter.notifyMatchFailure(op, "composite has only one consituent"); - if (!(std::adjacent_find(constituents.begin(), constituents.end(), - std::not_equal_to<>()) == constituents.end())) + if (!(llvm::all_equal(constituents))) return rewriter.notifyMatchFailure(op, "composite is not splat"); auto splatConstituent = dyn_cast(constituents[0]); diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir index b343d0bc73b4f..b3a8bd830c668 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --spirv-promote-to-replicated-constants --split-input-file %s | FileCheck %s -spirv.module Logical GLSL450 { +spirv.module Logical GLSL450 requires #spirv.vce { spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32> %0 = spirv.Constant dense<2> : vector<3xi32> @@ -132,21 +132,13 @@ spirv.module Logical GLSL450 { %0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32> spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32> } -} - -// ----- -spirv.module Logical GLSL450 { spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32> spirv.ReturnValue %0 : vector<3xf32> } -} -// ----- - -spirv.module Logical GLSL450 { spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> @@ -156,7 +148,7 @@ spirv.module Logical GLSL450 { // ----- -spirv.module Logical GLSL450 { +spirv.module Logical GLSL450 requires #spirv.vce { spirv.SpecConstant @sc_i32_1 = 1 : i32 @@ -216,4 +208,4 @@ spirv.module Logical GLSL450 { // CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)> -} \ No newline at end of file +} From ae1eb18744012ec90467c0d2d267f12c9650fb54 Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Thu, 17 Jul 2025 22:07:57 +0100 Subject: [PATCH 08/10] Addressing further code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- ...nvertToReplicatedConstantCompositePass.cpp | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp index c4df57072a55d..8ca615499404b 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -25,27 +25,24 @@ namespace { static std::pair getSplatAttrAndNumElements(Attribute valueAttr) { Attribute attr; - uint32_t splatCount = 0; + uint32_t numElements = 0; if (auto splatAttr = dyn_cast(valueAttr)) { return {splatAttr.getSplatValue(), splatAttr.size()}; } if (auto arrayAttr = dyn_cast(valueAttr)) { if (llvm::all_equal(arrayAttr)) { attr = arrayAttr[0]; - splatCount = arrayAttr.size(); - } + numElements = arrayAttr.size(); - if (attr) { // Find the inner-most splat value for array of composites - std::pair newSplatAttrAndCount = - getSplatAttrAndNumElements(attr); - if (newSplatAttrAndCount.first) { - return newSplatAttrAndCount; + auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr); + if (newAttr) { + return {newAttr, numElements * newNumElements}; } } } - return {attr, splatCount}; + return {attr, numElements}; } struct ConstantOpConversion final : OpRewritePattern { @@ -57,16 +54,16 @@ struct ConstantOpConversion final : OpRewritePattern { if (!compositeType) return rewriter.notifyMatchFailure(op, "not a composite constant"); - auto [splatAttr, splatCount] = getSplatAttrAndNumElements(op.getValue()); - if (!splatAttr) + auto [attr, numElements] = getSplatAttrAndNumElements(op.getValue()); + if (!attr) return rewriter.notifyMatchFailure(op, "composite is not splat"); - if (splatCount == 1) + if (numElements == 1) return rewriter.notifyMatchFailure(op, "composite has only one constituent"); rewriter.replaceOpWithNewOp( - op, op.getType(), splatAttr); + op, op.getType(), attr); return success(); } }; @@ -86,7 +83,7 @@ struct SpecConstantCompositeOpConversion final return rewriter.notifyMatchFailure(op, "composite has only one consituent"); - if (!(llvm::all_equal(constituents))) + if (!llvm::all_equal(constituents)) return rewriter.notifyMatchFailure(op, "composite is not splat"); auto splatConstituent = dyn_cast(constituents[0]); From 6ea73e386f002cbd77bd39a3671ab97e44d82c3a Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Fri, 18 Jul 2025 12:27:58 +0100 Subject: [PATCH 09/10] Slight change of logic for value type detection Signed-off-by: Mohammadreza Ameri Mahabadian --- ...nvertToReplicatedConstantCompositePass.cpp | 32 ++++++--- .../replicated-const-composites.mlir | 72 +++++++++++++++++++ 2 files changed, 96 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp index 8ca615499404b..faa0165271c60 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -22,20 +22,39 @@ namespace mlir::spirv { namespace { +static Type getArrayElemType(Attribute attr) { + if (auto typedAttr = dyn_cast(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast(attr)) { + return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + static std::pair -getSplatAttrAndNumElements(Attribute valueAttr) { +getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) { Attribute attr; - uint32_t numElements = 0; + uint32_t numElements = 1; + + auto compositeType = dyn_cast_or_null(valueType); + if (!compositeType) + return {nullptr, 1}; + if (auto splatAttr = dyn_cast(valueAttr)) { return {splatAttr.getSplatValue(), splatAttr.size()}; } + if (auto arrayAttr = dyn_cast(valueAttr)) { if (llvm::all_equal(arrayAttr)) { attr = arrayAttr[0]; numElements = arrayAttr.size(); // Find the inner-most splat value for array of composites - auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr); + auto [newAttr, newNumElements] = + getSplatAttrAndNumElements(attr, getArrayElemType(attr)); if (newAttr) { return {newAttr, numElements * newNumElements}; } @@ -50,11 +69,8 @@ struct ConstantOpConversion final : OpRewritePattern { LogicalResult matchAndRewrite(spirv::ConstantOp op, PatternRewriter &rewriter) const override { - auto compositeType = dyn_cast_or_null(op.getType()); - if (!compositeType) - return rewriter.notifyMatchFailure(op, "not a composite constant"); - - auto [attr, numElements] = getSplatAttrAndNumElements(op.getValue()); + auto [attr, numElements] = + getSplatAttrAndNumElements(op.getValue(), op.getType()); if (!attr) return rewriter.notifyMatchFailure(op, "composite is not splat"); diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir index b3a8bd830c668..56e26eee83ff9 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir @@ -49,6 +49,36 @@ spirv.module Logical GLSL450 requires #spirv.vce } + spirv.func @array_of_splat_array_of_non_splat_vectors_of_i32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xi32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>> + %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>> + } + + spirv.func @array_of_one_splat_array_of_vector_of_one_i32() -> !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<1 x !spirv.array<2 x vector<1xi32> + %cst = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> + spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> + } + + spirv.func @splat_array_of_array_of_one_vector_of_one_i32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xi32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>> + %0 = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>> + } + + spirv.func @array_of_one_array_of_one_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + %0 = spirv.Constant [[dense<1> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + } + + spirv.func @splat_array_of_splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + %0 = spirv.Constant [[[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]], [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + } + spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" { // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32> %0 = spirv.Constant dense<2.0> : vector<3xf32> @@ -97,6 +127,36 @@ spirv.module Logical GLSL450 requires #spirv.vce } + spirv.func @array_of_splat_array_of_non_splat_vectors_of_f32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xf32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>> + %0 = spirv.Constant [[dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>> + } + + spirv.func @array_of_one_splat_array_of_vector_of_one_f32() -> !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<1 x !spirv.array<2 x vector<1xf32> + %cst = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> + spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> + } + + spirv.func @splat_array_of_array_of_one_vector_of_one_f32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xf32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>> + %0 = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>> + } + + spirv.func @array_of_one_array_of_one_splat_vector_of_f32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xf32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>> + %0 = spirv.Constant [[dense<1.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>> + } + + spirv.func @splat_array_of_splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>) "None" { + // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + %0 = spirv.Constant [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]], [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + } + spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" { // CHECK-NOT: spirv.EXT.ConstantCompositeReplicate %0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32> @@ -144,6 +204,18 @@ spirv.module Logical GLSL450 requires #spirv.vce : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>> spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>> } + + spirv.func @array_of_one_array_of_one_non_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" { + // CHECK-NOT spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>> + } + + spirv.func @array_of_one_array_of_one_vector_of_one_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<1xi32>>>) "None" { + // CHECK-NOT spirv.EXT.ConstantCompositeReplicate + %0 = spirv.Constant [[dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>> + spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>> + } } // ----- From 1584f58d29067c2418e746f05864f0ff43b57964 Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Fri, 18 Jul 2025 17:04:52 +0100 Subject: [PATCH 10/10] Addressing further code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- .../ConvertToReplicatedConstantCompositePass.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp index faa0165271c60..dbbe23aa08b3c 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp @@ -36,9 +36,6 @@ static Type getArrayElemType(Attribute attr) { static std::pair getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) { - Attribute attr; - uint32_t numElements = 1; - auto compositeType = dyn_cast_or_null(valueType); if (!compositeType) return {nullptr, 1}; @@ -49,19 +46,21 @@ getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) { if (auto arrayAttr = dyn_cast(valueAttr)) { if (llvm::all_equal(arrayAttr)) { - attr = arrayAttr[0]; - numElements = arrayAttr.size(); + Attribute attr = arrayAttr[0]; + uint32_t numElements = arrayAttr.size(); // Find the inner-most splat value for array of composites auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr, getArrayElemType(attr)); if (newAttr) { - return {newAttr, numElements * newNumElements}; + attr = newAttr; + numElements *= newNumElements; } + return {attr, numElements}; } } - return {attr, numElements}; + return {nullptr, 1}; } struct ConstantOpConversion final : OpRewritePattern {