[mlir][spirv] Add conversion pass to rewrite splat constant composite… #148910

Merged

mahabadm
Contributor

…s to replicated form

This adds a new SPIR-V dialect-level conversion pass ConversionToReplicatedConstantCompositePass. This pass looks for splat composite spirv.Constant or spirv.SpecConstantComposite and rewrites them into spirv.EXT.ConstantCompositeReplicate or spirv.EXT.SpecConstantCompositeReplicate, respectively.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
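The "splat" test described above amounts to checking that no two neighbouring constituents differ. A minimal standalone sketch of that idiom, using plain standard-library containers instead of MLIR attributes, might look like the following. The helper name `isReplicable` and the rule that fewer than two constituents is never worth replicating are assumptions made for illustration; they are not taken from the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical helper, not part of the PR: a sequence counts as a splat when
// it has at least two elements and no adjacent pair of elements differs.
template <typename T>
bool isReplicable(const std::vector<T> &constituents) {
  if (constituents.size() < 2)
    return false; // Zero or one constituent: nothing to replicate.
  // adjacent_find returns end() when no neighbouring pair satisfies the
  // predicate, i.e. when every element equals the one before it.
  return std::adjacent_find(constituents.begin(), constituents.end(),
                            std::not_equal_to<>()) == constituents.end();
}

// Usage sketch: ints stand in for constant values, strings for the symbol
// names referenced by a spec constant composite.
inline void checkReplicableExamples() {
  assert(isReplicable(std::vector<int>{1, 1, 1}));
  assert(!isReplicable(std::vector<int>{1, 2, 1}));
  assert(!isReplicable(std::vector<int>{7}));
  assert(isReplicable(
      std::vector<std::string>{"sc_i32_1", "sc_i32_1", "sc_i32_1"}));
}
```

Guarding the empty and single-element cases up front also keeps any later access to the first element in bounds.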
@llvmbot
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

…s to replicated form

This adds a new SPIR-V dialect-level conversion pass ConversionToReplicatedConstantCompositePass. This pass looks for splat composite spirv.Constant or spirv.SpecConstantComposite and rewrites them into spirv.EXT.ConstantCompositeReplicate or spirv.EXT.SpecConstantCompositeReplicate, respectively.


Full diff: https://github.com/llvm/llvm-project/pull/148910.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td (+7)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp (+135)
  • (added) mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir (+192)
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 2d9befe78001d..3c04db8396367 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
                 "and replacing with supported ones";
 }
 
+def SPIRVReplicatedConstantCompositePass
+    : Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> {
+  let summary = "Convert splat composite constants and spec constants to "
+                "corresponding replicated constant composite ops defined by "
+                "SPV_EXT_replicated_composites";
+}
+
 #endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 68e0206e30a59..c675af9d048cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   CanonicalizeGLPass.cpp
+  ConversionToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
@@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
 
 add_mlir_dialect_library(MLIRSPIRVTransforms
   CanonicalizeGLPass.cpp
+  ConversionToReplicatedConstantCompositePass.cpp
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
new file mode 100644
index 0000000000000..530a0f4aa67f5
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp
@@ -0,0 +1,135 @@
+//===- ConversionToReplicatedConstantCompositePass.cpp
+//---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a splat composite spirv.Constant and
+// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
+// spirv.EXT.SpecConstantCompositeReplicate respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace spirv {
+#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
+} // namespace spirv
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
+  Attribute attr;
+  if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
+    if (denseAttr.isSplat()) {
+      attr = denseAttr.getSplatValue<Attribute>();
+      splatCount = denseAttr.size();
+    }
+  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+    if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
+                           std::not_equal_to<>()) == arrayAttr.end()) {
+      attr = arrayAttr[0];
+      splatCount = arrayAttr.size();
+    }
+  }
+
+  if (attr) {
+    if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+      if (isa<spirv::CompositeType>(typedAttr.getType()))
+        if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+          attr = newAttr;
+    } else if (isa<ArrayAttr>(attr)) {
+      if (Attribute newAttr = getSplatAttribute(attr, splatCount))
+        attr = newAttr;
+    }
+  }
+
+  return attr;
+}
+
+} // namespace
+
+namespace {
+class ConversionToReplicatedConstantCompositePass
+    : public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
+          ConversionToReplicatedConstantCompositePass> {
+public:
+  void runOnOperation() override;
+};
+
+class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
+  using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::ConstantOp op,
+                                PatternRewriter &rewriter) const override {
+    auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+    if (!compositeType)
+      return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+    uint32_t splatCount = 0;
+    Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
+    if (!splatAttr)
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    if (splatCount == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one constituent");
+
+    rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
+        op, op.getType(), splatAttr);
+
+    return success();
+  }
+};
+
+class SpecConstantCompositeOpConversion
+    : public OpRewritePattern<spirv::SpecConstantCompositeOp> {
+  using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
+    if (!compositeType)
+      return rewriter.notifyMatchFailure(op, "not a composite constant");
+
+    auto constituents = op.getConstituents();
+    if (constituents.size() == 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "composite has only one constituent");
+
+    if (!(std::adjacent_find(constituents.begin(), constituents.end(),
+                             std::not_equal_to<>()) == constituents.end()))
+      return rewriter.notifyMatchFailure(op, "composite is not splat");
+
+    auto splatConstituent =
+        dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
+
+    rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
+        op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
+
+    return success();
+  }
+};
+
+void ConversionToReplicatedConstantCompositePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+  patterns.add<ConstantOpConversion>(context);
+  patterns.add<SpecConstantCompositeOpConversion>(context);
+
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+} // namespace
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
new file mode 100644
index 0000000000000..f8cd4bb256bfe
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
+    %0 = spirv.Constant dense<2> : vector<3xi32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
+    spirv.ReturnValue %0 : vector<3xi32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
+    %0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
+    spirv.ReturnValue %0 : !spirv.array<3 x i32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
+    %0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
+    %0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
+    %0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
+    %0 = spirv.Constant dense<2.0> : vector<3xf32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
+    spirv.ReturnValue %0 : vector<3xf32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
+    %0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
+    spirv.ReturnValue %0 : !spirv.array<3 x f32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
+    %0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
+    %0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
+    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
+    %0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
+
+  spirv.SpecConstant @sc_i32_1 = 1 : i32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
+  spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+  spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
+  spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32>
+  spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
+
+  spirv.SpecConstant @sc_f32_1 = 1.0 : f32
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+  spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+  spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+  spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
+  spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+}
\ No newline at end of file
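
The test expectations in this diff imply a nested reduction: `[[3, 3, 3], [3, 3, 3]]` collapses all the way to `[3 : i32]`, while `[[1, 2, 3], [1, 2, 3]]` stops at the repeated inner array. The sketch below models that behaviour on a toy value type so it can be compiled without MLIR. Everything in it (`Value`, `Splat`, `reduceSplat`, and the choice to report the outermost constituent count in the return value) is a hypothetical illustration, not the pass's actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a constant attribute: a scalar or a list of
// nested values. None of these names come from MLIR.
struct Value {
  bool isScalar = true;
  int scalar = 0;
  std::vector<Value> elements;

  static Value leaf(int v) {
    Value r;
    r.scalar = v;
    return r;
  }
  static Value list(std::vector<Value> es) {
    Value r;
    r.isScalar = false;
    r.elements = std::move(es);
    return r;
  }
};

inline bool operator==(const Value &a, const Value &b) {
  if (a.isScalar != b.isScalar)
    return false;
  return a.isScalar ? a.scalar == b.scalar : a.elements == b.elements;
}

struct Splat {
  Value value;       // Innermost value that is replicated.
  std::size_t count; // Constituent count of the outermost level.
};

// Returns the replicated value when every element of `v` is identical,
// descending further while the repeated element is itself a splat.
inline std::optional<Splat> reduceSplat(const Value &v) {
  if (v.isScalar || v.elements.empty())
    return std::nullopt;
  for (const Value &e : v.elements)
    if (!(e == v.elements.front()))
      return std::nullopt;
  Splat result{v.elements.front(), v.elements.size()};
  if (auto inner = reduceSplat(result.value))
    result.value = inner->value; // Keep the outer count, take the inner value.
  return result;
}

// Usage sketch mirroring three of the test cases in this diff.
inline void checkReduceSplatExamples() {
  const Value row333 =
      Value::list({Value::leaf(3), Value::leaf(3), Value::leaf(3)});
  const Value row123 =
      Value::list({Value::leaf(1), Value::leaf(2), Value::leaf(3)});

  auto full = reduceSplat(Value::list({row333, row333}));
  assert(full && full->count == 2 && full->value == Value::leaf(3));

  auto partial = reduceSplat(Value::list({row123, row123}));
  assert(partial && partial->count == 2 && partial->value == row123);

  assert(!reduceSplat(row123));
}
```

Returning the count alongside the value, instead of writing it through a parameter, makes it hard for a caller to read a stale count; a caller can then reject composites whose `count` is 1 before rewriting.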

@llvmbot
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir-spirv

+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
+  spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+  spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
+  spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>
+
+  // CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
+  spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
+}
\ No newline at end of file

@kuhar kuhar requested review from Hardcode84 and IgWod-IMG July 15, 2025 18:01
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Contributor

@IgWod-IMG IgWod-IMG left a comment

Two more nits, but LGTM in general. Thanks for addressing the feedback.

Member

@kuhar kuhar left a comment

I haven't looked at the logic yet; I left some comments on code organization first.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
mahabadm added 2 commits July 17, 2025 11:22
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Contributor

@IgWod-IMG IgWod-IMG left a comment

That's all from me. Thank you!

Member

@kuhar kuhar left a comment

Looks good overall; just some remaining coding style issues.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
mahabadm added 2 commits July 17, 2025 22:09
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@mahabadm
Contributor Author

@kuhar In hindsight, I realized it is better to move the composite type check into `getSplatAttrAndNumElements`, as you suggested earlier. While doing that, I noticed an issue with array element type detection, which is now fixed. To cover it, I have added several more tests (including negative tests). Thank you.
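
For readers unfamiliar with the splat check discussed above, the following is a minimal, self-contained sketch of the idea in plain C++. It does not use MLIR and is not the code from this PR: `Constant`, `scalar`, `composite`, and `getReplicatedValue` are hypothetical names chosen for illustration. The sketch mirrors what the tests in the diff expect: a splat of splats collapses to the innermost scalar, a splat of non-splat arrays collapses to the repeated inner array, and a non-splat composite is left alone. One assumption to note: a one-element composite counts as a splat here, which may differ from the actual pass.

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a constant attribute: a scalar leaf or a
// composite with one or more elements. Not an MLIR type.
struct Constant {
  double value = 0.0;             // Meaningful only for scalar leaves.
  std::vector<Constant> elements; // Non-empty for composites.

  bool isComposite() const { return !elements.empty(); }
};

// Builds a scalar leaf.
Constant scalar(double value) { return Constant{value, {}}; }

// Builds a composite; empty composites are not modeled in this sketch.
Constant composite(std::vector<Constant> elements) {
  assert(!elements.empty() && "composite needs at least one element");
  return Constant{0.0, std::move(elements)};
}

// Structural equality: same shape and same leaf values.
bool operator==(const Constant &a, const Constant &b) {
  if (a.isComposite() != b.isComposite())
    return false;
  if (!a.isComposite())
    return a.value == b.value;
  return a.elements == b.elements;
}

// Returns the value that would be replicated if `c` is a splat composite,
// or std::nullopt if `c` is a scalar or its elements differ.
std::optional<Constant> getReplicatedValue(const Constant &c) {
  if (!c.isComposite())
    return std::nullopt;

  const Constant &first = c.elements.front();
  for (const Constant &element : c.elements)
    if (!(element == first))
      return std::nullopt;

  // If the repeated element is itself a splat, keep collapsing so that
  // [[3, 3, 3], [3, 3, 3]] yields 3 rather than [3, 3, 3].
  if (std::optional<Constant> inner = getReplicatedValue(first))
    return inner;
  return first;
}
```

Calling `getReplicatedValue` on a 2 x 3 array filled with `3.0` yields the scalar `3.0`, while an array made of two copies of `[1.0, 2.0, 3.0]` yields that inner array, matching the shape of the `CHECK` lines in the nested-array tests.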

Member

@kuhar kuhar left a comment

LGTM % nit. Thanks for the changes.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@kuhar kuhar merged commit 10518c7 into llvm:main Jul 18, 2025
9 checks passed
4 participants