
Commit 2a10dab

Replacing bit_width() with itemsize_bits(). (#1264)
Signed-off-by: Aman Gupta <[email protected]>
1 parent 9a990c4 commit 2a10dab

File tree

10 files changed: +54 −22 lines changed

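Every hunk below applies the same compatibility shim: keep calling dtypes.bit_width() when the installed JAX still exposes it, and fall back to dtypes.itemsize_bits() otherwise. A minimal sketch of the pattern, assuming the kernels' `dtypes` alias resolves to jax._src.dtypes (the helper name _dtype_bits is hypothetical, not part of the diff):

import jax.numpy as jnp
from jax._src import dtypes  # assumption: the module the kernels alias as `dtypes`


def _dtype_bits(dtype) -> int:
    # Prefer bit_width() where this JAX version provides it; otherwise use
    # the newer itemsize_bits() spelling. Both return bits per element.
    if hasattr(dtypes, "bit_width"):
        return dtypes.bit_width(dtype)
    return dtypes.itemsize_bits(dtype)


# e.g. _dtype_bits(jnp.dtype(jnp.bfloat16)) == 16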

tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py

Lines changed: 3 additions & 1 deletion
@@ -176,7 +176,9 @@ def gen_random(shape, dtype):
     )
     output = output[:cu_q_lens[distribution[-1]]]
 
-    dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+    dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(
+            jnp.dtype(kv_dtype)))
     tols = {
         32: 0.15,
         16: 0.2,

tests/kernels/ragged_paged_attention_kernel_v3_test.py

Lines changed: 3 additions & 1 deletion
@@ -162,7 +162,9 @@ def gen_random(shape, dtype):
     )
     output = output[:cu_q_lens[distribution[-1]]]
 
-    dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+    dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(
+            jnp.dtype(kv_dtype)))
     tols = {
         32: 0.15,
         16: 0.2,

tpu_inference/kernels/collectives/all_gather_matmul.py

Lines changed: 12 additions & 6 deletions
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
     """Returns the total vmem bytes used by the kernel."""
     m_per_device = m // tp_size
     n_per_device = n // tp_size
-    y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype) // 8
+    y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
     total_bytes = (
-        2 * m_per_device * k * dtypes.bit_width(x_dtype) //
-        8  # x_vmem_scratch_ref
+        2 * m_per_device * k *
+        (dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
+         dtypes.itemsize_bits(x_dtype)) // 8  # x_vmem_scratch_ref
         + y_vmem_bytes  # y_vmem_scratch_ref
-        + 2 * m * bn * dtypes.bit_width(out_dtype) // 8  # o_vmem_scratch_ref
+        + 2 * m * bn *
+        (dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
+         dtypes.itemsize_bits(out_dtype)) // 8  # o_vmem_scratch_ref
         + acc_bytes  # acc_vmem_scratch_ref, jnp.float32
     )
     return total_bytes

@@ -639,8 +643,10 @@ def all_gather_matmul(
     # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
     if grid_k == 1:
         acc_shape = (8, 128)
-        acc_bytes = acc_shape[0] * acc_shape[1] * dtypes.bit_width(
-            jnp.float32) // 8
+        acc_bytes = (
+            acc_shape[0] *
+            acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
+                dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
     y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
     estimated_vmem_bytes = get_vmem_estimate_bytes(
         m,
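
For intuition, every term in the estimate converts an element count to bytes as elements * bits // 8; a worked example with made-up shapes (not the kernel's defaults):

# bfloat16 y buffer: 16 bits per element
n_per_device, k = 1024, 4096
y_vmem_bytes = n_per_device * k * 16 // 8
assert y_vmem_bytes == 8 * 1024 * 1024  # 8 MiB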

tpu_inference/kernels/fused_moe/v1/kernel.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ def align_to(x, a):
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 
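get_dtype_packing() here is the number of elements that fit in one 32-bit word. A self-contained check of the math (the _bits helper is hypothetical, mirroring the shim above):

import jax.numpy as jnp
from jax._src import dtypes  # assumption, as in the sketch above


def _bits(dtype):
    return (dtypes.bit_width(dtype)
            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))


assert 32 // _bits(jnp.dtype(jnp.float32)) == 1  # 1 float32 per 32-bit word
assert 32 // _bits(jnp.dtype(jnp.bfloat16)) == 2  # 2 bfloat16 per word
assert 32 // _bits(jnp.dtype(jnp.int8)) == 4  # 4 int8 per word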

tpu_inference/kernels/quantized_matmul/kernel.py

Lines changed: 23 additions & 8 deletions
@@ -96,11 +96,20 @@ def get_vmem_limit(
     """Calculate VMEM limit for the kernel."""
 
     # Calculate in/out VMEM size.
-    x_size = batch_block_size * in_block_size * dtypes.bit_width(x_dtype)
-    x_abs_max_size = batch_block_size * dtypes.bit_width(scale_dtype)
-    w_q_size = out_block_size * in_block_size * dtypes.bit_width(w_q_dtype)
-    w_scale_size = out_block_size * dtypes.bit_width(scale_dtype)
-    out_size = batch_block_size * out_block_size * dtypes.bit_width(out_dtype)
+    x_size = (batch_block_size *
+              in_block_size * (dtypes.bit_width(x_dtype) if hasattr(
+                  dtypes, "bit_width") else dtypes.itemsize_bits(x_dtype)))
+    x_abs_max_size = (
+        batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
+    w_q_size = (out_block_size *
+                in_block_size * (dtypes.bit_width(w_q_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(w_q_dtype)))
+    w_scale_size = (out_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
+    out_size = (batch_block_size *
+                out_block_size * (dtypes.bit_width(out_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(out_dtype)))
 
     vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
     vmem_in_out *= 2  # Account for compute and vreg spills.

@@ -114,9 +123,15 @@ def get_vmem_limit(
     vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
 
     # Calculate scratch VMEM size.
-    acc_size = batch_block_size * out_block_size * dtypes.bit_width(acc_dtype)
-    x_q_size = batch_block_size * in_block_size * dtypes.bit_width(x_q_dtype)
-    x_scale_size = batch_block_size * dtypes.bit_width(scale_dtype)
+    acc_size = (batch_block_size *
+                out_block_size * (dtypes.bit_width(acc_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(acc_dtype)))
+    x_q_size = (batch_block_size *
+                in_block_size * (dtypes.bit_width(x_q_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(x_q_dtype)))
+    x_scale_size = (
+        batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
 
     vmem_scratch = acc_size if save_acc else 0
     vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
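
Each size above follows the same elements * bits-per-element accounting before the in/out total is doubled for compute and vreg spills. A minimal sketch with illustrative block sizes (not the kernel's tuned defaults):

import jax.numpy as jnp
from jax._src import dtypes  # assumption, as in the sketch above


def _bits(dtype):
    return (dtypes.bit_width(dtype)
            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))


batch_block_size, in_block_size, out_block_size = 256, 512, 512
x_dtype, w_q_dtype = jnp.dtype(jnp.bfloat16), jnp.dtype(jnp.int8)
scale_dtype, out_dtype = jnp.dtype(jnp.float32), jnp.dtype(jnp.bfloat16)

x_size = batch_block_size * in_block_size * _bits(x_dtype)
x_abs_max_size = batch_block_size * _bits(scale_dtype)
w_q_size = out_block_size * in_block_size * _bits(w_q_dtype)
w_scale_size = out_block_size * _bits(scale_dtype)
out_size = batch_block_size * out_block_size * _bits(out_dtype)

vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
vmem_in_out *= 2  # doubled for compute and vreg spills, as in the hunk above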

tpu_inference/kernels/ragged_paged_attention/v2/kernel.py

Lines changed: 2 additions & 1 deletion
@@ -655,7 +655,8 @@ def cdiv(a, b):
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 

tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py

Lines changed: 2 additions & 1 deletion
@@ -200,7 +200,8 @@ def _prev_power_of_2(n: int) -> int:
 def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
                          head_size: int, kv_cache_dtype) -> int:
     """Returns the size in bytes of one page of the KV cache."""
-    kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
+    kv_cache_dtype_bit_size = (dtypes.bit_width(kv_cache_dtype) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(kv_cache_dtype))
     padded_head_size = _ceil_div(
         head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
 
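The padding step in the context lines rounds head_size up to the next multiple of TPU_HEAD_SIZE_ALIGNMENT before the byte count is computed. A self-contained illustration (the _ceil_div body and the alignment value 128 are assumed for the example, not read from the source):

def _ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


TPU_HEAD_SIZE_ALIGNMENT = 128  # assumption for illustration only
head_size = 80
padded_head_size = _ceil_div(
    head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
assert padded_head_size == 128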

tpu_inference/kernels/ragged_paged_attention/v3/util.py

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ def align_to(x, a):
 
 
 def get_dtype_bitwidth(dtype):
-    return dtypes.bit_width(dtype)
+    return (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
 
 
 def get_dtype_packing(dtype):

tpu_inference/runner/kv_cache.py

Lines changed: 3 additions & 1 deletion
@@ -131,8 +131,10 @@ def get_attention_page_size_bytes(mesh: Mesh,
     assert isinstance(kv_cache_spec, AttentionSpec)
 
     dtype = t2j_dtype(kv_cache_spec.dtype)
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
+            dtypes.itemsize_bits(dtype))
     use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
+
     kv_cache_shape = get_kv_cache_shape_with_mesh(
         mesh=mesh,
         total_num_pages=1,  # Pass 1 to get shape of a single page.

tpu_inference/utils.py

Lines changed: 2 additions & 1 deletion
@@ -190,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 
