Skip to content

Commit 2d09ab8

Browse files
authored
Revert vulkan changes from D76646172 fixup patch
Differential Revision: D76737404 Pull Request resolved: #11727
1 parent 5f789c0 commit 2d09ab8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+3169
-152
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 103 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -56,52 +56,97 @@
5656
TYPE_MAPPINGS: Dict[str, Any] = {
5757
"IMAGE_T": {
5858
3: {
59+
"double": "image3D",
5960
"float": "image3D",
6061
"half": "image3D",
61-
"int": "iimage3D",
62-
"uint": "uimage3D",
62+
# integer dtypes
6363
"int8": "iimage3D",
6464
"uint8": "uimage3D",
65+
"int16": "iimage3D",
66+
"uint16": "uimage3D",
67+
"int32": "iimage3D",
68+
"uint32": "uimage3D",
69+
"int64": "iimage3D",
70+
"uint64": "uimage3D",
71+
# common dtype aliases
6572
"bool": "uimage3D",
73+
"int": "iimage3D",
74+
"uint": "uimage3D",
6675
},
6776
2: {
77+
"double": "image2D",
6878
"float": "image2D",
6979
"half": "image2D",
70-
"int": "iimage2D",
71-
"uint": "uimage2D",
80+
# integer dtypes
7281
"int8": "iimage2D",
7382
"uint8": "uimage2D",
83+
"int16": "iimage2D",
84+
"uint16": "uimage2D",
85+
"int32": "iimage2D",
86+
"uint32": "uimage2D",
87+
"int64": "iimage2D",
88+
"uint64": "uimage2D",
89+
# common dtype aliases
7490
"bool": "uimage2D",
91+
"int": "iimage2D",
92+
"uint": "uimage2D",
7593
},
7694
},
7795
"SAMPLER_T": {
7896
3: {
97+
"double": "sampler3D",
7998
"float": "sampler3D",
8099
"half": "sampler3D",
81-
"int": "isampler3D",
82-
"uint": "usampler3D",
100+
# integer dtypes
83101
"int8": "isampler3D",
84102
"uint8": "usampler3D",
103+
"int16": "isampler3D",
104+
"uint16": "usampler3D",
105+
"int32": "isampler3D",
106+
"uint32": "usampler3D",
107+
"int64": "isampler3D",
108+
"uint64": "usampler3D",
109+
# common dtype aliases
85110
"bool": "usampler3D",
111+
"int": "isampler3D",
112+
"uint": "usampler3D",
86113
},
87114
2: {
115+
"double": "sampler2D",
88116
"float": "sampler2D",
89117
"half": "sampler2D",
90-
"int": "isampler2D",
91-
"uint": "usampler2D",
118+
# integer dtypes
92119
"int8": "isampler2D",
93120
"uint8": "usampler2D",
121+
"int16": "isampler2D",
122+
"uint16": "usampler2D",
123+
"int32": "isampler2D",
124+
"uint32": "usampler2D",
125+
"int64": "isampler2D",
126+
"uint64": "usampler2D",
127+
# common dtype aliases
94128
"bool": "usampler2D",
129+
"int": "isampler2D",
130+
"uint": "usampler2D",
95131
},
96132
},
97133
"IMAGE_FORMAT": {
134+
"double": "rgba32f",
98135
"float": "rgba32f",
99136
"half": "rgba16f",
100-
"int": "rgba32i",
101-
"uint": "rgba32ui",
137+
# integer dtypes
102138
"int8": "rgba8i",
103139
"uint8": "rgba8ui",
140+
"int16": "rgba16i",
141+
"uint16": "rgba16ui",
142+
"int32": "rgba32i",
143+
"uint32": "rgba32ui",
144+
"int64": "rgba32i",
145+
"uint64": "rgba32ui",
146+
# common dtype aliases
104147
"bool": "rgba8ui",
148+
"int": "rgba32i",
149+
"uint": "rgba32ui",
105150
},
106151
}
107152

@@ -118,33 +163,47 @@ def define_variable(name: str) -> str:
118163
def buffer_scalar_type(dtype: str) -> str:
119164
if dtype == "half":
120165
return "float16_t"
121-
elif dtype[-1] == "8":
122-
return dtype + "_t"
166+
elif dtype == "float":
167+
return "float"
168+
elif dtype == "double":
169+
return "float64_t"
170+
# integer dtype alias conversion
123171
elif dtype == "bool":
124172
return "uint8_t"
173+
# we don't want to append _t for int32 or uint32 as int is already 32bit
174+
elif dtype == "int32" or dtype == "uint32":
175+
return "int" if dtype == "int32" else "uint"
176+
elif dtype[-1].isdigit():
177+
return dtype + "_t"
125178
return dtype
126179

127180

128181
def buffer_gvec_type(dtype: str, n: int) -> str:
129182
if n == 1:
130183
return buffer_scalar_type(dtype)
131184

132-
if dtype == "float":
133-
return f"vec{n}"
134-
if dtype == "uint":
135-
return f"uvec{n}"
136-
elif dtype == "half":
137-
return f"f16vec{n}"
138-
elif dtype == "int":
139-
return f"ivec{n}"
140-
elif dtype == "int8":
141-
return f"i8vec{n}"
142-
elif dtype == "uint8":
143-
return f"u8vec{n}"
144-
elif dtype == "bool":
145-
return f"u8vec{n}"
146-
147-
raise AssertionError(f"Invalid dtype: {dtype}")
185+
dtype_map = {
186+
"half": f"f16vec{n}",
187+
"float": f"vec{n}",
188+
"double": f"vec{n}", # No 64bit image format support in GLSL
189+
"int8": f"i8vec{n}",
190+
"uint8": f"u8vec{n}",
191+
"int16": f"i16vec{n}",
192+
"uint16": f"u16vec{n}",
193+
"int32": f"ivec{n}",
194+
"int": f"ivec{n}",
195+
"uint32": f"uvec{n}",
196+
"uint": f"uvec{n}",
197+
"int64": f"ivec{n}", # No 64bit image format support in GLSL
198+
"uint64": f"uvec{n}", # No 64bit image format support in GLSL
199+
"bool": f"u8vec{n}",
200+
}
201+
202+
vector_type = dtype_map.get(dtype)
203+
if vector_type is None:
204+
raise AssertionError(f"Invalid dtype: {dtype}")
205+
206+
return vector_type
148207

149208

150209
def texel_type(dtype: str) -> str:
@@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
365424
if dtype == "half":
366425
nbit = "16bit"
367426
glsl_type = "float16"
368-
elif dtype == "int16" or dtype == "uint16":
369-
nbit = "16bit"
370-
glsl_type = "int16"
371-
elif dtype == "int8" or dtype == "uint8" or dtype == "bool":
427+
elif dtype == "double":
428+
# We only need to allow float64_t type usage
429+
glsl_type = "float64"
430+
elif dtype in ["int8", "uint8", "bool"]:
372431
nbit = "8bit"
373432
glsl_type = "int8"
433+
elif dtype in ["int16", "uint16"]:
434+
nbit = "16bit"
435+
glsl_type = "int16"
436+
elif dtype in ["int64", "uint64"]:
437+
# We only need to allow int64_t and uint64_t type usage
438+
glsl_type = "int64"
374439

375-
if nbit is not None and glsl_type is not None:
440+
if nbit is not None:
376441
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
442+
if glsl_type is not None:
377443
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
378444

379445
return out_str
@@ -629,6 +695,10 @@ def generateVariantCombinations(
629695

630696
elif "VALUE" in value:
631697
suffix = value.get("SUFFIX", value["VALUE"])
698+
if value["VALUE"] in ["int", "uint"]:
699+
raise ValueError(
700+
f"Use int32 or uint32 instead of {value['VALUE']}"
701+
)
632702
param_values.append((param_name, suffix, value["VALUE"]))
633703

634704
else:

backends/vulkan/runtime/graph/ops/glsl/arange.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
arange:
88
parameter_names_with_default_values:
99
NDIM: 3
10-
DTYPE: int
10+
DTYPE: int32
1111
STORAGE: texture3d
1212
PACKING: C_packed
1313
generate_variant_forall:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: int32
1818
shader_variants:
1919
- NAME: arange

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: int32
1717
shader_variants:
1818
- NAME: avg_pool2d

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ binary_op:
1717
DTYPE:
1818
- VALUE: half
1919
- VALUE: float
20-
- VALUE: int
20+
- VALUE: int32
2121
shader_variants:
2222
- NAME: binary_add
2323
- NAME: binary_sub

backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ buffer_to_buffer:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: double
1616
- VALUE: int8
1717
- VALUE: uint8
18+
- VALUE: int32
1819
shader_variants:
1920
- NAME: buffer_to_buffer

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ buffer_to_nchw:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: double
1717
- VALUE: int8
1818
- VALUE: uint8
19+
- VALUE: int32
1920
shader_variants:
2021
- NAME: buffer_to_nchw
2122
- NAME: buffer_to_nchw_no_pc

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_channel_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_channel_offset

backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ copy_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
- VALUE: int8
1212
- VALUE: uint8
1313
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_packed_dim_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_packed_dim_offset

backends/vulkan/runtime/graph/ops/glsl/embedding.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ embedding:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: embedding

backends/vulkan/runtime/graph/ops/glsl/flip.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ flip:
66
DTYPE:
77
- VALUE: half
88
- VALUE: float
9-
- VALUE: int
9+
- VALUE: double
1010
- VALUE: int8
1111
- VALUE: uint8
12+
- VALUE: int32
1213
shader_variants:
1314
- NAME: flip

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ image_to_nchw:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: double
1818
- VALUE: int8
1919
- VALUE: uint8
20+
- VALUE: int32
2021
shader_variants:
2122
- NAME: image_to_nchw_texture3d
2223
- NAME: image_to_nchw_texture2d

backends/vulkan/runtime/graph/ops/glsl/index_select.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ index_select:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: index_select

backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ index_select_channel:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: index_select_channel

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ nchw_to_buffer:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: double
1717
- VALUE: int8
1818
- VALUE: uint8
19+
- VALUE: int32
1920
shader_variants:
2021
- NAME: nchw_to_buffer
2122
- NAME: nchw_to_buffer_no_pc

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,9 @@ void main() {
8787
return;
8888
}
8989

90-
write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
90+
$if DTYPE == "double" and DTYPE == "int64":
91+
VEC4_T texel = read_texel(tidx);
92+
write_texel(t_out, lpos_to_pos(lpos, axis_map), texel);
93+
$else:
94+
write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
9195
}

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ nchw_to_image:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: double
1818
- VALUE: int8
1919
- VALUE: uint8
20+
- VALUE: int32
2021
shader_variants:
2122
- NAME: nchw_to_image_texture3d
2223
- NAME: nchw_to_image_texture2d

backends/vulkan/runtime/graph/ops/glsl/no_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ no_op:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: int32
1616
- VALUE: int8
1717
- VALUE: uint8
1818
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/permute.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ permute:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: permute

0 commit comments

Comments
 (0)