Skip to content

Commit cf1d916

Browse files
committed
[naga spv] Split workgroup and subgroup memory semantics in Control Barriers
1 parent 8e60726 commit cf1d916

11 files changed

+153
-42
lines changed

naga/src/back/spv/writer.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,8 @@ impl Writer {
16441644
pub(super) fn write_control_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
16451645
let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
16461646
spirv::Scope::Device
1647+
} else if flags.contains(crate::Barrier::SUB_GROUP) {
1648+
spirv::Scope::Subgroup
16471649
} else {
16481650
spirv::Scope::Workgroup
16491651
};
@@ -1656,6 +1658,10 @@ impl Writer {
16561658
spirv::MemorySemantics::WORKGROUP_MEMORY,
16571659
flags.contains(crate::Barrier::WORK_GROUP),
16581660
);
1661+
semantics.set(
1662+
spirv::MemorySemantics::SUBGROUP_MEMORY,
1663+
flags.contains(crate::Barrier::SUB_GROUP),
1664+
);
16591665
semantics.set(
16601666
spirv::MemorySemantics::IMAGE_MEMORY,
16611667
flags.contains(crate::Barrier::TEXTURE),

naga/src/front/spv/mod.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -3850,19 +3850,21 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
38503850
let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
38513851
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;
38523852

3853-
if exec_scope == spirv::Scope::Workgroup as u32 {
3853+
if exec_scope == spirv::Scope::Workgroup as u32
3854+
|| exec_scope == spirv::Scope::Subgroup as u32
3855+
{
38543856
let mut flags = crate::Barrier::empty();
38553857
flags.set(
38563858
crate::Barrier::STORAGE,
38573859
semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0,
38583860
);
38593861
flags.set(
38603862
crate::Barrier::WORK_GROUP,
3861-
semantics
3862-
& (spirv::MemorySemantics::SUBGROUP_MEMORY
3863-
| spirv::MemorySemantics::WORKGROUP_MEMORY)
3864-
.bits()
3865-
!= 0,
3863+
semantics & (spirv::MemorySemantics::WORKGROUP_MEMORY).bits() != 0,
3864+
);
3865+
flags.set(
3866+
crate::Barrier::SUB_GROUP,
3867+
semantics & spirv::MemorySemantics::SUBGROUP_MEMORY.bits() != 0,
38663868
);
38673869
flags.set(
38683870
crate::Barrier::TEXTURE,
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; SPIR-V
2+
; Version: 1.5
3+
; Generator: Google rspirv; 0
4+
; Bound: 14
5+
; Schema: 0
6+
OpCapability Shader
7+
OpMemoryModel Logical Simple
8+
OpEntryPoint GLCompute %1 "main"
9+
OpExecutionMode %1 LocalSize 64 1 1
10+
%void = OpTypeVoid
11+
%6 = OpTypeFunction %void
12+
%uint = OpTypeInt 32 0
13+
%uint_3 = OpConstant %uint 3
14+
%uint_136 = OpConstant %uint 136
15+
%1 = OpFunction %void None %6
16+
%13 = OpLabel
17+
OpMemoryBarrier %uint_3 %uint_136
18+
OpControlBarrier %uint_3 %uint_3 %uint_136
19+
OpReturn
20+
OpFunctionEnd
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
god_mode = true
2+
targets = "WGSL | SPIRV | GLSL | METAL"
3+
4+
[msl]
5+
lang_version = [2, 0]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#version 310 es
2+
3+
precision highp float;
4+
precision highp int;
5+
6+
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
7+
8+
9+
void function() {
10+
subgroupMemoryBarrier();
11+
barrier();
12+
subgroupMemoryBarrier();
13+
barrier();
14+
return;
15+
}
16+
17+
void main() {
18+
function();
19+
}
20+

naga/tests/out/glsl/spv-subgroup-operations-s.main.Compute.glsl

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ uint global_3 = 0u;
2020
void function() {
2121
uint _e5 = global_2;
2222
uint _e6 = global_3;
23+
barrier();
2324
uvec4 _e9 = subgroupBallot(((_e6 & 1u) == 1u));
2425
uvec4 _e10 = subgroupBallot(true);
2526
bool _e12 = subgroupAll((_e6 != 0u));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// language: metal2.0
2+
#include <metal_stdlib>
3+
#include <simd/simd.h>
4+
5+
using metal::uint;
6+
7+
8+
void function(
9+
) {
10+
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
11+
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
12+
return;
13+
}
14+
15+
kernel void main_(
16+
) {
17+
function();
18+
}

naga/tests/out/msl/spv-subgroup-operations-s.msl

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ void function(
1111
) {
1212
uint _e5 = global_2;
1313
uint _e6 = global_3;
14+
metal::threadgroup_barrier(metal::mem_flags::mem_none);
1415
metal::uint4 unnamed = metal::uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0);
1516
metal::uint4 unnamed_1 = metal::uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
1617
bool unnamed_2 = metal::simd_all(_e6 != 0u);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; SPIR-V
2+
; Version: 1.1
3+
; Generator: rspirv
4+
; Bound: 14
5+
OpCapability Shader
6+
%1 = OpExtInstImport "GLSL.std.450"
7+
OpMemoryModel Logical GLSL450
8+
OpEntryPoint GLCompute %11 "main"
9+
OpExecutionMode %11 LocalSize 64 1 1
10+
%2 = OpTypeVoid
11+
%5 = OpTypeFunction %2
12+
%8 = OpTypeInt 32 0
13+
%7 = OpConstant %8 3
14+
%9 = OpConstant %8 136
15+
%4 = OpFunction %2 None %5
16+
%3 = OpLabel
17+
OpBranch %6
18+
%6 = OpLabel
19+
OpMemoryBarrier %7 %9
20+
OpControlBarrier %7 %7 %9
21+
OpReturn
22+
OpFunctionEnd
23+
%11 = OpFunction %2 None %5
24+
%10 = OpLabel
25+
OpBranch %12
26+
%12 = OpLabel
27+
%13 = OpFunctionCall %2 %4
28+
OpReturn
29+
OpFunctionEnd
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; SPIR-V
22
; Version: 1.3
33
; Generator: rspirv
4-
; Bound: 58
4+
; Bound: 57
55
OpCapability Shader
66
OpCapability GroupNonUniform
77
OpCapability GroupNonUniformBallot
@@ -33,10 +33,9 @@ OpDecorate %15 BuiltIn SubgroupLocalInvocationId
3333
%20 = OpConstant %3 0
3434
%21 = OpConstant %3 4
3535
%23 = OpConstant %3 3
36-
%24 = OpConstant %3 2
37-
%25 = OpConstant %3 8
38-
%28 = OpTypeVector %3 4
39-
%30 = OpConstantTrue %5
36+
%24 = OpConstant %3 136
37+
%27 = OpTypeVector %3 4
38+
%29 = OpConstantTrue %5
4039
%17 = OpFunction %2 None %18
4140
%6 = OpLabel
4241
%10 = OpLoad %3 %8
@@ -46,36 +45,36 @@ OpDecorate %15 BuiltIn SubgroupLocalInvocationId
4645
%16 = OpLoad %3 %15
4746
OpBranch %22
4847
%22 = OpLabel
49-
OpControlBarrier %23 %24 %25
50-
%26 = OpBitwiseAnd %3 %16 %19
51-
%27 = OpIEqual %5 %26 %19
52-
%29 = OpGroupNonUniformBallot %28 %23 %27
53-
%31 = OpGroupNonUniformBallot %28 %23 %30
54-
%32 = OpINotEqual %5 %16 %20
55-
%33 = OpGroupNonUniformAll %5 %23 %32
56-
%34 = OpIEqual %5 %16 %20
57-
%35 = OpGroupNonUniformAny %5 %23 %34
58-
%36 = OpGroupNonUniformIAdd %3 %23 Reduce %16
59-
%37 = OpGroupNonUniformIMul %3 %23 Reduce %16
60-
%38 = OpGroupNonUniformUMin %3 %23 Reduce %16
61-
%39 = OpGroupNonUniformUMax %3 %23 Reduce %16
62-
%40 = OpGroupNonUniformBitwiseAnd %3 %23 Reduce %16
63-
%41 = OpGroupNonUniformBitwiseOr %3 %23 Reduce %16
64-
%42 = OpGroupNonUniformBitwiseXor %3 %23 Reduce %16
65-
%43 = OpGroupNonUniformIAdd %3 %23 ExclusiveScan %16
66-
%44 = OpGroupNonUniformIMul %3 %23 ExclusiveScan %16
67-
%45 = OpGroupNonUniformIAdd %3 %23 InclusiveScan %16
68-
%46 = OpGroupNonUniformIMul %3 %23 InclusiveScan %16
69-
%47 = OpGroupNonUniformBroadcastFirst %3 %23 %16
70-
%48 = OpGroupNonUniformShuffle %3 %23 %16 %21
71-
%49 = OpCompositeExtract %3 %7 1
72-
%50 = OpISub %3 %49 %19
73-
%51 = OpISub %3 %50 %16
74-
%52 = OpGroupNonUniformShuffle %3 %23 %16 %51
75-
%53 = OpGroupNonUniformShuffleDown %3 %23 %16 %19
76-
%54 = OpGroupNonUniformShuffleUp %3 %23 %16 %19
77-
%55 = OpCompositeExtract %3 %7 1
78-
%56 = OpISub %3 %55 %19
79-
%57 = OpGroupNonUniformShuffleXor %3 %23 %16 %56
48+
OpControlBarrier %23 %23 %24
49+
%25 = OpBitwiseAnd %3 %16 %19
50+
%26 = OpIEqual %5 %25 %19
51+
%28 = OpGroupNonUniformBallot %27 %23 %26
52+
%30 = OpGroupNonUniformBallot %27 %23 %29
53+
%31 = OpINotEqual %5 %16 %20
54+
%32 = OpGroupNonUniformAll %5 %23 %31
55+
%33 = OpIEqual %5 %16 %20
56+
%34 = OpGroupNonUniformAny %5 %23 %33
57+
%35 = OpGroupNonUniformIAdd %3 %23 Reduce %16
58+
%36 = OpGroupNonUniformIMul %3 %23 Reduce %16
59+
%37 = OpGroupNonUniformUMin %3 %23 Reduce %16
60+
%38 = OpGroupNonUniformUMax %3 %23 Reduce %16
61+
%39 = OpGroupNonUniformBitwiseAnd %3 %23 Reduce %16
62+
%40 = OpGroupNonUniformBitwiseOr %3 %23 Reduce %16
63+
%41 = OpGroupNonUniformBitwiseXor %3 %23 Reduce %16
64+
%42 = OpGroupNonUniformIAdd %3 %23 ExclusiveScan %16
65+
%43 = OpGroupNonUniformIMul %3 %23 ExclusiveScan %16
66+
%44 = OpGroupNonUniformIAdd %3 %23 InclusiveScan %16
67+
%45 = OpGroupNonUniformIMul %3 %23 InclusiveScan %16
68+
%46 = OpGroupNonUniformBroadcastFirst %3 %23 %16
69+
%47 = OpGroupNonUniformShuffle %3 %23 %16 %21
70+
%48 = OpCompositeExtract %3 %7 1
71+
%49 = OpISub %3 %48 %19
72+
%50 = OpISub %3 %49 %16
73+
%51 = OpGroupNonUniformShuffle %3 %23 %16 %50
74+
%52 = OpGroupNonUniformShuffleDown %3 %23 %16 %19
75+
%53 = OpGroupNonUniformShuffleUp %3 %23 %16 %19
76+
%54 = OpCompositeExtract %3 %7 1
77+
%55 = OpISub %3 %54 %19
78+
%56 = OpGroupNonUniformShuffleXor %3 %23 %16 %55
8079
OpReturn
8180
OpFunctionEnd
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
fn function() {
2+
subgroupBarrier();
3+
subgroupBarrier();
4+
return;
5+
}
6+
7+
@compute @workgroup_size(64, 1, 1)
8+
fn main() {
9+
function();
10+
}

0 commit comments

Comments
 (0)