Skip to content

Commit baead80

Browse files
committed
fix
Signed-off-by: Yaoyao Ding <[email protected]>
1 parent ab7bc61 commit baead80

File tree

5 files changed

+34
-18
lines changed

5 files changed

+34
-18
lines changed

python/tilus/ir/layout/inference/inference_rules/transform_shared.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ def inference(ctx: LayoutInferenceContext, inst: SliceSharedInst) -> dict[Shared
3838
outer_shape = []
3939
for i in range(len(a.shape)):
4040
outer_shape.append(a.shape[i] // b_layout.shape[i])
41-
return {a: shared_compose(shared_row_major(*outer_shape), b_layout).apply_swizzle(b.layout.swizzle)}
41+
layout = shared_compose(shared_row_major(*outer_shape), b_layout)
42+
if b.layout.optional_swizzle is not None:
43+
layout = layout.apply_swizzle(b.layout.swizzle)
44+
return {a: layout}
4245
else:
4346
return {}
4447

python/tilus/ir/layout/ops/shared_ops.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def shared_row_major(*shape: int) -> SharedLayout:
6161
"""
6262
mode_shape = shape
6363
mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(range(len(mode_shape))))
64-
return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None)
64+
return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None)
6565

6666

6767
def shared_column_major(*shape: int) -> SharedLayout:
@@ -79,7 +79,7 @@ def shared_column_major(*shape: int) -> SharedLayout:
7979
"""
8080
mode_shape = shape
8181
mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(reversed(range(len(mode_shape)))))
82-
return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None)
82+
return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None)
8383

8484

8585
def shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout:
@@ -118,7 +118,7 @@ def shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout:
118118
mode_strides.extend([stride * rhs_size for stride in (lhs.mode_strides[i] for i in lhs_group)])
119119
mode_strides.extend([rhs.mode_strides[i] for i in rhs_group])
120120

121-
return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=None)
121+
return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None)
122122

123123

124124
def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout:
@@ -151,7 +151,9 @@ def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout:
151151
mode_shape.extend([layout.mode_shape[i] for i in layout_mode_groups[d]])
152152
mode_strides.extend([layout.mode_strides[i] for i in layout_mode_groups[d]])
153153

154-
return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=layout.swizzle)
154+
return shared_layout(
155+
shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=layout.optional_swizzle
156+
)
155157

156158

157159
def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayout:
@@ -184,7 +186,7 @@ def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayo
184186
shape=shape,
185187
mode_shape=mode_shape,
186188
mode_strides=mode_strides,
187-
swizzle=layout.swizzle,
189+
optional_swizzle=layout.optional_swizzle,
188190
)
189191

190192

@@ -211,7 +213,7 @@ def shared_unsqueeze(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout:
211213
shape=shape,
212214
mode_shape=layout.mode_shape,
213215
mode_strides=layout.mode_strides,
214-
swizzle=layout.swizzle,
216+
optional_swizzle=layout.optional_swizzle,
215217
)
216218

217219

@@ -373,7 +375,7 @@ def shared_row_major_swizzle(shape: Sequence[int], dtype_nbytes: int) -> SharedL
373375
shape=shape,
374376
mode_shape=mode_shape,
375377
mode_strides=mode_strides,
376-
swizzle=swizzle,
378+
optional_swizzle=swizzle,
377379
)
378380

379381

python/tilus/ir/layout/shared_layout.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,10 @@ def swizzle(self) -> Swizzle:
126126

127127
@staticmethod
128128
def create(
129-
shape: Sequence[int], mode_shape: Sequence[int], mode_strides: Sequence[int], swizzle: Optional[Swizzle]
129+
shape: Sequence[int],
130+
mode_shape: Sequence[int],
131+
mode_strides: Sequence[int],
132+
optional_swizzle: Optional[Swizzle],
130133
) -> SharedLayout:
131134
"""
132135
Create a SharedLayout from shape, mode_shape, and mode_strides.
@@ -154,7 +157,10 @@ def create(
154157
if prod(mode_shape) != prod(shape):
155158
raise ValueError("The product of mode_shape must equal to the product of shape.")
156159
return SharedLayout(
157-
shape=tuple(shape), mode_shape=tuple(mode_shape), mode_strides=tuple(mode_strides), optional_swizzle=swizzle
160+
shape=tuple(shape),
161+
mode_shape=tuple(mode_shape),
162+
mode_strides=tuple(mode_strides),
163+
optional_swizzle=optional_swizzle,
158164
)
159165

160166
def as_numpy_grid(self) -> np.ndarray:
@@ -195,7 +201,7 @@ def apply_swizzle(self, swizzle: Swizzle) -> SharedLayout:
195201
shape=self.shape,
196202
mode_shape=self.mode_shape,
197203
mode_strides=self.mode_strides,
198-
swizzle=swizzle,
204+
optional_swizzle=swizzle,
199205
)
200206

201207
def prepend_dim(self, extent: int) -> SharedLayout:
@@ -211,7 +217,7 @@ def prepend_dim(self, extent: int) -> SharedLayout:
211217
shape=shape,
212218
mode_shape=mode_shape,
213219
mode_strides=mode_strides,
214-
swizzle=self.optional_swizzle,
220+
optional_swizzle=self.optional_swizzle,
215221
)
216222

217223
def transpose(self) -> SharedLayout:
@@ -238,7 +244,7 @@ def shared_layout(
238244
shape: Sequence[int],
239245
mode_shape: Sequence[int],
240246
mode_strides: Sequence[int],
241-
swizzle: Optional[Swizzle] = None,
247+
optional_swizzle: Optional[Swizzle] = None,
242248
) -> SharedLayout:
243249
"""Create a SharedLayout from shape, mode_shape, and mode_strides.
244250
@@ -270,7 +276,9 @@ def shared_layout(
270276
mode_strides = updated_mode_strides
271277

272278
# canonicalize swizzle: if swizzle has 0 bits, set it to None (both mean no swizzle)
273-
if swizzle is not None and swizzle.bits == 0:
274-
swizzle = None
279+
if optional_swizzle is not None and optional_swizzle.bits == 0:
280+
optional_swizzle = None
275281

276-
return SharedLayout.create(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=swizzle)
282+
return SharedLayout.create(
283+
shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=optional_swizzle
284+
)

python/tilus/ir/layout/utils/cute.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,10 @@ def flatten_int_tuple(t: IntTuple) -> list[Int]:
156156
mode_strides = [int(s) for s in flat_strides]
157157

158158
return shared_layout(
159-
shape=tensor_shape, mode_shape=mode_shape, mode_strides=mode_strides, swizzle=self.swizzle.as_swizzle()
159+
shape=tensor_shape,
160+
mode_shape=mode_shape,
161+
mode_strides=mode_strides,
162+
optional_swizzle=self.swizzle.as_swizzle(),
160163
)
161164

162165

python/tilus/ir/tools/printer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def visit_SharedLayout(self, node: SharedLayout) -> Doc:
452452
"shape=[" + self(node.shape) + "]",
453453
"mode_shape=[" + self(node.mode_shape) + "]",
454454
"mode_strides=[" + self(node.mode_strides) + "]",
455-
"swizzle=" + (str(node.swizzle) if node.swizzle is not None else "None"),
455+
"swizzle=" + (str(node.swizzle) if node.optional_swizzle is not None else "None"),
456456
]
457457
doc = Text("SharedLayout(") + doc_join(items, ", ") + ")"
458458
return self.add_key_comment("shared_layout", doc)

0 commit comments

Comments
 (0)