Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL][SPIRV][DXIL] Implement WaveActiveMax intrinsic #123428

Merged
merged 6 commits into from
Jan 28, 2025

Conversation

adam-yang
Copy link
Contributor

      - link builtin in hlsl_intrinsics
      - add codegen for spirv intrinsic and two directx intrinsics to retain
        signedness information of the operands in CGBuiltin.cpp
      - add semantic analysis in SemaHLSL.cpp
      - add lowering of spirv intrinsic to spirv backend in
        SPIRVInstructionSelector.cpp
      - add lowering of directx intrinsics to WaveActiveOp dxil op in
    DXIL.td

      - add test cases to illustrate passespendent pr merges.

Resolves #99170

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

Copy link

github-actions bot commented Jan 18, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@adam-yang adam-yang marked this pull request as ready for review January 21, 2025 18:43
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen backend:DirectX HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Jan 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2025

@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-clang

Author: Adam Yang (adam-yang)

Changes
      - link builtin in hlsl_intrinsics
      - add codegen for spirv intrinsic and two directx intrinsics to retain
        signedness information of the operands in CGBuiltin.cpp
      - add semantic analysis in SemaHLSL.cpp
      - add lowering of spirv intrinsic to spirv backend in
        SPIRVInstructionSelector.cpp
      - add lowering of directx intrinsics to WaveActiveOp dxil op in
    DXIL.td

      - add test cases to illustrate passespendent pr merges.

Resolves #99170


Patch is 22.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123428.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+36)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+1)
  • (added) clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl (+46)
  • (added) clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl (+29)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+2)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+33)
  • (added) llvm/test/CodeGen/DirectX/WaveActiveMax.ll (+143)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveMax.ll (+55)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index bbf4886b5cf053..a7f17ec4c8ef8f 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4795,6 +4795,12 @@ def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int(bool)";
 }
 
+def HLSLWaveActiveMax : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_active_max"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void (...)";
+}
+
 def HLSLWaveActiveSum : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_sum"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b80833fd91884d..57d8eec24a5a0b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19203,6 +19203,25 @@ static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
   }
 }
 
+// Return wave active sum that corresponds to the QT scalar type
+static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
+                                               CGHLSLRuntime &RT, QualType QT) {
+  switch (Arch) {
+  case llvm::Triple::spirv:
+    if (QT->isUnsignedIntegerType())
+      return llvm::Intrinsic::spv_wave_reduce_umax;
+    return llvm::Intrinsic::spv_wave_reduce_max;
+  case llvm::Triple::dxil: {
+    if (QT->isUnsignedIntegerType())
+      return llvm::Intrinsic::dx_wave_reduce_umax;
+    return llvm::Intrinsic::dx_wave_reduce_max;
+  }
+  default:
+    llvm_unreachable("Intrinsic WaveActiveMax"
+                     " not supported by target architecture");
+  }
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -19532,6 +19551,23 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
                                                      /*AssumeConvergent=*/true),
                            ArrayRef{OpExpr}, "hlsl.wave.active.sum");
   }
+  case Builtin::BI__builtin_hlsl_wave_active_max: {
+    // Due to the use of variadic arguments, explicitly retreive argument
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    llvm::FunctionType *FT = llvm::FunctionType::get(
+        OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
+    Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
+        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
+        E->getArg(0)->getType());
+
+    // Get overloaded name
+    std::string Name =
+        Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
+                                                     /*Local=*/false,
+                                                     /*AssumeConvergent=*/true),
+                           ArrayRef{OpExpr}, "hlsl.wave.active.max");
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
     // defined in SPIRVBuiltins.td. So instead we manually get the matching name
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 5001883003ee2d..fd7435caeb000e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2183,6 +2183,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     TheCall->setType(ArgTyA);
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_active_max:
   case Builtin::BI__builtin_hlsl_wave_active_sum: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl
new file mode 100644
index 00000000000000..7891cfc1989afc
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.max.i32([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.max.i32([[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveActiveMax(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.max.i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_uint64_t
+uint64_t test_uint64_t(uint64_t expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.umax.i64([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.umax.i64([[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveActiveMax(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr) {
+  // CHECK-SPIRV:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.reduce.max.v4f32([[TY1]] %[[#]]
+  // CHECK-DXIL:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.reduce.max.v4f32([[TY1]] %[[#]])
+  // CHECK:  ret [[TY1]] %[[RET1]]
+  return WaveActiveMax(expr);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
+// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
+
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl
new file mode 100644
index 00000000000000..e077a40ba5165b
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+  return __builtin_hlsl_wave_active_max();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_active_max(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool test_expr_bool_type_check(bool p0) {
+  return __builtin_hlsl_wave_active_max(p0);
+  // expected-error@-1 {{invalid operand of type 'bool'}}
+}
+
+bool2 test_expr_bool_vec_type_check(bool2 p0) {
+  return __builtin_hlsl_wave_active_max(p0);
+  // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+  return __builtin_hlsl_wave_active_max(p0);
+  // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
+}
+
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f21948697c8a6e..beed84b144cec4 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -105,6 +105,8 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index be337dbccaf8a9..38910ee263ee36 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -91,6 +91,8 @@ let TargetPrefix = "spv" in {
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+  def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
+  def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 4b20a64cb07226..887fdb5b8bb21e 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1008,6 +1008,12 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
     IntrinSelect<
         int_dx_wave_reduce_usum,
         [ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>, IntrinArgI8<SignedOpKind_Unsigned> ]>,
+    IntrinSelect<
+        int_dx_wave_reduce_max,
+        [ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Max>, IntrinArgI8<SignedOpKind_Signed> ]>,
+    IntrinSelect<
+        int_dx_wave_reduce_umax,
+        [ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Max>, IntrinArgI8<SignedOpKind_Unsigned> ]>,
   ];
 
   let arguments = [OverloadTy, Int8Ty, Int8Ty];
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 4e6e01bc5edbc7..ba656dc7371407 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -40,6 +40,8 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
+  case Intrinsic::dx_wave_reduce_max:
+  case Intrinsic::dx_wave_reduce_umax:
   case Intrinsic::dx_wave_reduce_sum:
   case Intrinsic::dx_wave_reduce_usum:
   case Intrinsic::dx_wave_readlane:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f5409c27d6ea3d..6cdbff45d41769 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -215,6 +215,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
                                     MachineInstr &I) const;
 
+  bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
+                           MachineInstr &I, bool IsUnsigned) const;
+
   bool selectWaveReduceSum(Register ResVReg, const SPIRVType *ResType,
                            MachineInstr &I) const;
 
@@ -2132,6 +2135,32 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I, bool IsUnsigned) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+  Register InputRegister = I.getOperand(2).getReg();
+  SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+
+  if (!InputType)
+    report_fatal_error("Input Type could not be determined.");
+
+  SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+  // Retreive the operation to use based on input type
+  bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat);
+  auto IntegerOpcodeType = IsUnsigned ? SPIRV::OpGroupNonUniformUMax : SPIRV::OpGroupNonUniformSMax;
+  auto Opcode =
+      IsFloatTy ? SPIRV::OpGroupNonUniformFMax : IntegerOpcodeType;
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+      .addImm(SPIRV::GroupOperation::Reduce)
+      .addUse(I.getOperand(2).getReg());
+}
+
 bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
                                                    const SPIRVType *ResType,
                                                    MachineInstr &I) const {
@@ -3086,6 +3115,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
   case Intrinsic::spv_wave_is_first_lane:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
+  case Intrinsic::spv_wave_reduce_umax:
+    return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/true);
+  case Intrinsic::spv_wave_reduce_max:
+    return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/false);
   case Intrinsic::spv_wave_reduce_sum:
     return selectWaveReduceSum(ResVReg, ResType, I);
   case Intrinsic::spv_wave_readlane:
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveMax.ll b/llvm/test/CodeGen/DirectX/WaveActiveMax.ll
new file mode 100644
index 00000000000000..0e84018865f3f6
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveMax.ll
@@ -0,0 +1,143 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
+
+; Test that for scalar values, WaveActiveMax maps down to the DirectX op
+
+define noundef half @wave_active_max_half(half noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr, i8 3, i8 0)
+  %ret = call half @llvm.dx.wave.reduce.max.f16(half %expr)
+  ret half %ret
+}
+
+define noundef float @wave_active_max_float(float noundef %expr) {
+entry:
+; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %expr, i8 3, i8 0)
+  %ret = call float @llvm.dx.wave.reduce.max.f32(float %expr)
+  ret float %ret
+}
+
+define noundef double @wave_active_max_double(double noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr, i8 3, i8 0)
+  %ret = call double @llvm.dx.wave.reduce.max.f64(double %expr)
+  ret double %ret
+}
+
+define noundef i16 @wave_active_max_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 3, i8 0)
+  %ret = call i16 @llvm.dx.wave.reduce.max.i16(i16 %expr)
+  ret i16 %ret
+}
+
+define noundef i32 @wave_active_max_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 3, i8 0)
+  %ret = call i32 @llvm.dx.wave.reduce.max.i32(i32 %expr)
+  ret i32 %ret
+}
+
+define noundef i64 @wave_active_max_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 0)
+  %ret = call i64 @llvm.dx.wave.reduce.max.i64(i64 %expr)
+  ret i64 %ret
+}
+
+define noundef i16 @wave_active_umax_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 3, i8 1)
+  %ret = call i16 @llvm.dx.wave.reduce.umax.i16(i16 %expr)
+  ret i16 %ret
+}
+
+define noundef i32 @wave_active_umax_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 3, i8 1)
+  %ret = call i32 @llvm.dx.wave.reduce.umax.i32(i32 %expr)
+  ret i32 %ret
+}
+
+define noundef i64 @wave_active_umax_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 1)
+  %ret = call i64 @llvm.dx.wave.reduce.umax.i64(i64 %expr)
+  ret i64 %ret
+}
+
+declare half @llvm.dx.wave.reduce.max.f16(half)
+declare float @llvm.dx.wave.reduce.max.f32(float)
+declare double @llvm.dx.wave.reduce.max.f64(double)
+
+declare i16 @llvm.dx.wave.reduce.max.i16(i16)
+declare i32 @llvm.dx.wave.reduce.max.i32(i32)
+declare i64 @llvm.dx.wave.reduce.max.i64(i64)
+
+declare i16 @llvm.dx.wave.reduce.umax.i16(i16)
+declare i32 @llvm.dx.wave.reduce.umax.i32(i32)
+declare i64 @llvm.dx.wave.reduce.umax.i64(i64)
+
+; Test that for vector values, WaveActiveMax scalarizes and maps down to the
+; DirectX op
+
+define noundef <2 x half> @wave_active_max_v2half(<2 x half> noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i0, i8 3, i8 0)
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i1, i8 3, i8 0)
+  %ret = call <2 x half> @llvm.dx.wave.reduce.max.v2f16(<2 x half> %expr)
+  ret <2 x half> %ret
+}
+
+define noundef <3 x i32> @wave_active_max_v3i32(<3 x i32> noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 3, i8 0)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 3, i8 0)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 3, i8 0)
+  %ret = call <3 x i32> @llvm.dx.wave.reduce.max.v3i32(<3 x i32> %expr)
+  ret <3 x i32> %ret
+}
+
+define noundef <4 x double> @wave_active_max_v4f64(<4 x double> noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i0, i8 3, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i1, i8 3, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i2, i8 3, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i3, i8 3, i8 0)
+  %ret = call <4 x double> @llvm.dx.wave.reduce.max.v4f64(<4 x double> %expr)
+  ret <4 x double> %ret
+}
+
+declare <2 x half> @llvm.dx.wave.reduce.max.v2f16(<2 x half>)
+declare <3 x i32> @llvm.dx.wave.reduce.max.v3i32(<3 x i32>)
+declare <4 x double> @llvm.dx.wave.reduce.max.v4f64(<4 x double>)
+
+define noundef <2 x i16> @wave_active_umax_v2i16(<2 x i16> noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr.i0, i8 3, i8 1)
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr.i1, i8 3, i8 1)
+  %ret = call <2 x i16> @llvm.dx.wave.reduce.umax.v2f16(<2 x i16> %expr)
+  ret <2 x i16> %ret
+}
+
+define noundef <3 x i32> @wave_active_umax_v3i32(<3 x i32> noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 3, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 3, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 3, i8 1)
+  %ret = call <3 x i32> @llvm.dx.wave.reduce.umax.v3i32(<3 x i32> %expr)
+  ret <3 x i32> %ret
+}
+
+define noundef <4 x i64> @wave_active_umax_v4f64(<4 x i64> noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i0, i8 3, i8 1)
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i1, i8 3, i8 1)
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i2, i8 3, i8 1)
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i3, i8 3, i8 1)
+  %ret = call <4 x i64> @llvm.dx.wave.reduce.umax.v4f64(<4 x i64> %expr)
+  ret <4 x i64> %ret
+}
+
+declare <2 x i16> @llvm.dx.wave.reduce.umax.v2f16(<2 x i16>)
+declare <3 x i32> @llvm.dx.wave.reduce.umax.v3i32(<3 x i32>)
+declare <4 x i64> @llvm.dx.wave.reduce.umax.v4f64(<4 x i64>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveMax.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveMax.ll
new file mode 100644
i...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Adam Yang (adam-yang)

Changes
      - link builtin in hlsl_intrinsics
      - add codegen for spirv intrinsic and two directx intrinsics to retain
        signedness information of the operands in CGBuiltin.cpp
      - add semantic analysis in SemaHLSL.cpp
      - add lowering of spirv intrinsic to spirv backend in
        SPIRVInstructionSelector.cpp
      - add lowering of directx intrinsics to WaveActiveOp dxil op in
    DXIL.td

      - add test cases to illustrate passespendent pr merges.

Resolves #99170


Patch is 22.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123428.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+36)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+1)
  • (added) clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl (+46)
  • (added) clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl (+29)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+2)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+2)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+33)
  • (added) llvm/test/CodeGen/DirectX/WaveActiveMax.ll (+143)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveMax.ll (+55)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index bbf4886b5cf053..a7f17ec4c8ef8f 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4795,6 +4795,12 @@ def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int(bool)";
 }
 
+def HLSLWaveActiveMax : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_active_max"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void (...)";
+}
+
 def HLSLWaveActiveSum : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_sum"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b80833fd91884d..57d8eec24a5a0b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19203,6 +19203,25 @@ static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
   }
 }
 
+// Return wave active sum that corresponds to the QT scalar type
+static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
+                                               CGHLSLRuntime &RT, QualType QT) {
+  switch (Arch) {
+  case llvm::Triple::spirv:
+    if (QT->isUnsignedIntegerType())
+      return llvm::Intrinsic::spv_wave_reduce_umax;
+    return llvm::Intrinsic::spv_wave_reduce_max;
+  case llvm::Triple::dxil: {
+    if (QT->isUnsignedIntegerType())
+      return llvm::Intrinsic::dx_wave_reduce_umax;
+    return llvm::Intrinsic::dx_wave_reduce_max;
+  }
+  default:
+    llvm_unreachable("Intrinsic WaveActiveMax"
+                     " not supported by target architecture");
+  }
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -19532,6 +19551,23 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
                                                      /*AssumeConvergent=*/true),
                            ArrayRef{OpExpr}, "hlsl.wave.active.sum");
   }
+  case Builtin::BI__builtin_hlsl_wave_active_max: {
+    // Due to the use of variadic arguments, explicitly retreive argument
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    llvm::FunctionType *FT = llvm::FunctionType::get(
+        OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
+    Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
+        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
+        E->getArg(0)->getType());
+
+    // Get overloaded name
+    std::string Name =
+        Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
+                                                     /*Local=*/false,
+                                                     /*AssumeConvergent=*/true),
+                           ArrayRef{OpExpr}, "hlsl.wave.active.max");
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
     // defined in SPIRVBuiltins.td. So instead we manually get the matching name
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 5001883003ee2d..fd7435caeb000e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2183,6 +2183,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     TheCall->setType(ArgTyA);
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_active_max:
   case Builtin::BI__builtin_hlsl_wave_active_sum: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl
new file mode 100644
index 00000000000000..7891cfc1989afc
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.max.i32([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.max.i32([[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveActiveMax(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.max.i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_uint64_t
+uint64_t test_uint64_t(uint64_t expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.umax.i64([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.umax.i64([[TY]] %[[#]])
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveActiveMax(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr) {
+  // CHECK-SPIRV:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func [[TY1:.*]] @llvm.spv.wave.reduce.max.v4f32([[TY1]] %[[#]]
+  // CHECK-DXIL:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn [[TY1:.*]] @llvm.dx.wave.reduce.max.v4f32([[TY1]] %[[#]])
+  // CHECK:  ret [[TY1]] %[[RET1]]
+  return WaveActiveMax(expr);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
+// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
+
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl
new file mode 100644
index 00000000000000..e077a40ba5165b
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveMax-errors.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+  return __builtin_hlsl_wave_active_max();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_active_max(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool test_expr_bool_type_check(bool p0) {
+  return __builtin_hlsl_wave_active_max(p0);
+  // expected-error@-1 {{invalid operand of type 'bool'}}
+}
+
+bool2 test_expr_bool_vec_type_check(bool2 p0) {
+  return __builtin_hlsl_wave_active_max(p0);
+  // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+  return __builtin_hlsl_wave_active_max(p0);
+  // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
+}
+
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f21948697c8a6e..beed84b144cec4 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -105,6 +105,8 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index be337dbccaf8a9..38910ee263ee36 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -91,6 +91,8 @@ let TargetPrefix = "spv" in {
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+  def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
+  def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 4b20a64cb07226..887fdb5b8bb21e 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1008,6 +1008,12 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
     IntrinSelect<
         int_dx_wave_reduce_usum,
         [ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Sum>, IntrinArgI8<SignedOpKind_Unsigned> ]>,
+    IntrinSelect<
+        int_dx_wave_reduce_max,
+        [ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Max>, IntrinArgI8<SignedOpKind_Signed> ]>,
+    IntrinSelect<
+        int_dx_wave_reduce_umax,
+        [ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Max>, IntrinArgI8<SignedOpKind_Unsigned> ]>,
   ];
 
   let arguments = [OverloadTy, Int8Ty, Int8Ty];
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 4e6e01bc5edbc7..ba656dc7371407 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -40,6 +40,8 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   switch (ID) {
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
+  case Intrinsic::dx_wave_reduce_max:
+  case Intrinsic::dx_wave_reduce_umax:
   case Intrinsic::dx_wave_reduce_sum:
   case Intrinsic::dx_wave_reduce_usum:
   case Intrinsic::dx_wave_readlane:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f5409c27d6ea3d..6cdbff45d41769 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -215,6 +215,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
                                     MachineInstr &I) const;
 
+  bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
+                           MachineInstr &I, bool IsUnsigned) const;
+
   bool selectWaveReduceSum(Register ResVReg, const SPIRVType *ResType,
                            MachineInstr &I) const;
 
@@ -2132,6 +2135,32 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I, bool IsUnsigned) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+  Register InputRegister = I.getOperand(2).getReg();
+  SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+
+  if (!InputType)
+    report_fatal_error("Input Type could not be determined.");
+
+  SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+  // Retreive the operation to use based on input type
+  bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat);
+  auto IntegerOpcodeType = IsUnsigned ? SPIRV::OpGroupNonUniformUMax : SPIRV::OpGroupNonUniformSMax;
+  auto Opcode =
+      IsFloatTy ? SPIRV::OpGroupNonUniformFMax : IntegerOpcodeType;
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+      .addImm(SPIRV::GroupOperation::Reduce)
+      .addUse(I.getOperand(2).getReg());
+}
+
 bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
                                                    const SPIRVType *ResType,
                                                    MachineInstr &I) const {
@@ -3086,6 +3115,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
   case Intrinsic::spv_wave_is_first_lane:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
+  case Intrinsic::spv_wave_reduce_umax:
+    return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/true);
+  case Intrinsic::spv_wave_reduce_max:
+    return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/false);
   case Intrinsic::spv_wave_reduce_sum:
     return selectWaveReduceSum(ResVReg, ResType, I);
   case Intrinsic::spv_wave_readlane:
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveMax.ll b/llvm/test/CodeGen/DirectX/WaveActiveMax.ll
new file mode 100644
index 00000000000000..0e84018865f3f6
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveMax.ll
@@ -0,0 +1,143 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
+
+; Test that for scalar values, WaveActiveMax maps down to the DirectX op
+
+define noundef half @wave_active_max_half(half noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr, i8 3, i8 0)
+  %ret = call half @llvm.dx.wave.reduce.max.f16(half %expr)
+  ret half %ret
+}
+
+define noundef float @wave_active_max_float(float noundef %expr) {
+entry:
+; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %expr, i8 3, i8 0)
+  %ret = call float @llvm.dx.wave.reduce.max.f32(float %expr)
+  ret float %ret
+}
+
+define noundef double @wave_active_max_double(double noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr, i8 3, i8 0)
+  %ret = call double @llvm.dx.wave.reduce.max.f64(double %expr)
+  ret double %ret
+}
+
+define noundef i16 @wave_active_max_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 3, i8 0)
+  %ret = call i16 @llvm.dx.wave.reduce.max.i16(i16 %expr)
+  ret i16 %ret
+}
+
+define noundef i32 @wave_active_max_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 3, i8 0)
+  %ret = call i32 @llvm.dx.wave.reduce.max.i32(i32 %expr)
+  ret i32 %ret
+}
+
+define noundef i64 @wave_active_max_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 0)
+  %ret = call i64 @llvm.dx.wave.reduce.max.i64(i64 %expr)
+  ret i64 %ret
+}
+
+define noundef i16 @wave_active_umax_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 3, i8 1)
+  %ret = call i16 @llvm.dx.wave.reduce.umax.i16(i16 %expr)
+  ret i16 %ret
+}
+
+define noundef i32 @wave_active_umax_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 3, i8 1)
+  %ret = call i32 @llvm.dx.wave.reduce.umax.i32(i32 %expr)
+  ret i32 %ret
+}
+
+define noundef i64 @wave_active_umax_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 1)
+  %ret = call i64 @llvm.dx.wave.reduce.umax.i64(i64 %expr)
+  ret i64 %ret
+}
+
+declare half @llvm.dx.wave.reduce.max.f16(half)
+declare float @llvm.dx.wave.reduce.max.f32(float)
+declare double @llvm.dx.wave.reduce.max.f64(double)
+
+declare i16 @llvm.dx.wave.reduce.max.i16(i16)
+declare i32 @llvm.dx.wave.reduce.max.i32(i32)
+declare i64 @llvm.dx.wave.reduce.max.i64(i64)
+
+declare i16 @llvm.dx.wave.reduce.umax.i16(i16)
+declare i32 @llvm.dx.wave.reduce.umax.i32(i32)
+declare i64 @llvm.dx.wave.reduce.umax.i64(i64)
+
+; Test that for vector values, WaveActiveMax scalarizes and maps down to the
+; DirectX op
+
+define noundef <2 x half> @wave_active_max_v2half(<2 x half> noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i0, i8 3, i8 0)
+; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i1, i8 3, i8 0)
+  %ret = call <2 x half> @llvm.dx.wave.reduce.max.v2f16(<2 x half> %expr)
+  ret <2 x half> %ret
+}
+
+define noundef <3 x i32> @wave_active_max_v3i32(<3 x i32> noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 3, i8 0)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 3, i8 0)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 3, i8 0)
+  %ret = call <3 x i32> @llvm.dx.wave.reduce.max.v3i32(<3 x i32> %expr)
+  ret <3 x i32> %ret
+}
+
+define noundef <4 x double> @wave_active_max_v4f64(<4 x double> noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i0, i8 3, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i1, i8 3, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i2, i8 3, i8 0)
+; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i3, i8 3, i8 0)
+  %ret = call <4 x double> @llvm.dx.wave.reduce.max.v4f64(<4 x double> %expr)
+  ret <4 x double> %ret
+}
+
+declare <2 x half> @llvm.dx.wave.reduce.max.v2f16(<2 x half>)
+declare <3 x i32> @llvm.dx.wave.reduce.max.v3i32(<3 x i32>)
+declare <4 x double> @llvm.dx.wave.reduce.max.v4f64(<4 x double>)
+
+define noundef <2 x i16> @wave_active_umax_v2i16(<2 x i16> noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr.i0, i8 3, i8 1)
+; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr.i1, i8 3, i8 1)
+  %ret = call <2 x i16> @llvm.dx.wave.reduce.umax.v2f16(<2 x i16> %expr)
+  ret <2 x i16> %ret
+}
+
+define noundef <3 x i32> @wave_active_umax_v3i32(<3 x i32> noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 3, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 3, i8 1)
+; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 3, i8 1)
+  %ret = call <3 x i32> @llvm.dx.wave.reduce.umax.v3i32(<3 x i32> %expr)
+  ret <3 x i32> %ret
+}
+
+define noundef <4 x i64> @wave_active_umax_v4f64(<4 x i64> noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i0, i8 3, i8 1)
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i1, i8 3, i8 1)
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i2, i8 3, i8 1)
+; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr.i3, i8 3, i8 1)
+  %ret = call <4 x i64> @llvm.dx.wave.reduce.umax.v4f64(<4 x i64> %expr)
+  ret <4 x i64> %ret
+}
+
+declare <2 x i16> @llvm.dx.wave.reduce.umax.v2f16(<2 x i16>)
+declare <3 x i32> @llvm.dx.wave.reduce.umax.v3i32(<3 x i32>)
+declare <4 x i64> @llvm.dx.wave.reduce.umax.v4f64(<4 x i64>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveMax.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveMax.ll
new file mode 100644
i...
[truncated]

@llvmbot llvmbot added backend:X86 clang:headers Headers provided by Clang, e.g. for intrinsics labels Jan 21, 2025
Copy link
Contributor

@inbelic inbelic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a testcase touch up and an optional consideration.

llvm/test/CodeGen/DirectX/WaveActiveMax.ll Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGBuiltin.cpp Show resolved Hide resolved
Copy link
Contributor

@s-perron s-perron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a couple small touchups.

@adam-yang adam-yang merged commit aab25f2 into llvm:main Jan 28, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:SPIR-V backend:X86 clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

Implement the WaveActiveMax HLSL Function
5 participants