@@ -96,11 +96,20 @@ def get_vmem_limit(
9696 """Calculate VMEM limit for the kernel."""
9797
9898 # Calculate in/out VMEM size.
99- x_size = batch_block_size * in_block_size * dtypes .bit_width (x_dtype )
100- x_abs_max_size = batch_block_size * dtypes .bit_width (scale_dtype )
101- w_q_size = out_block_size * in_block_size * dtypes .bit_width (w_q_dtype )
102- w_scale_size = out_block_size * dtypes .bit_width (scale_dtype )
103- out_size = batch_block_size * out_block_size * dtypes .bit_width (out_dtype )
99+ x_size = (batch_block_size *
100+ in_block_size * (dtypes .bit_width (x_dtype ) if hasattr (
101+ dtypes , "bit_width" ) else dtypes .itemsize_bits (x_dtype )))
102+ x_abs_max_size = (
103+ batch_block_size * (dtypes .bit_width (scale_dtype ) if hasattr (
104+ dtypes , "bit_width" ) else dtypes .itemsize_bits (scale_dtype )))
105+ w_q_size = (out_block_size *
106+ in_block_size * (dtypes .bit_width (w_q_dtype ) if hasattr (
107+ dtypes , "bit_width" ) else dtypes .itemsize_bits (w_q_dtype )))
108+ w_scale_size = (out_block_size * (dtypes .bit_width (scale_dtype ) if hasattr (
109+ dtypes , "bit_width" ) else dtypes .itemsize_bits (scale_dtype )))
110+ out_size = (batch_block_size *
111+ out_block_size * (dtypes .bit_width (out_dtype ) if hasattr (
112+ dtypes , "bit_width" ) else dtypes .itemsize_bits (out_dtype )))
104113
105114 vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
106115 vmem_in_out *= 2 # Account for compute and vreg spills.
@@ -114,9 +123,15 @@ def get_vmem_limit(
114123 vmem_in_out += out_size if (n_batch > 1 or n_out > 1 ) else 0
115124
116125 # Calculate scratch VMEM size.
117- acc_size = batch_block_size * out_block_size * dtypes .bit_width (acc_dtype )
118- x_q_size = batch_block_size * in_block_size * dtypes .bit_width (x_q_dtype )
119- x_scale_size = batch_block_size * dtypes .bit_width (scale_dtype )
126+ acc_size = (batch_block_size *
127+ out_block_size * (dtypes .bit_width (acc_dtype ) if hasattr (
128+ dtypes , "bit_width" ) else dtypes .itemsize_bits (acc_dtype )))
129+ x_q_size = (batch_block_size *
130+ in_block_size * (dtypes .bit_width (x_q_dtype ) if hasattr (
131+ dtypes , "bit_width" ) else dtypes .itemsize_bits (x_q_dtype )))
132+ x_scale_size = (
133+ batch_block_size * (dtypes .bit_width (scale_dtype ) if hasattr (
134+ dtypes , "bit_width" ) else dtypes .itemsize_bits (scale_dtype )))
120135
121136 vmem_scratch = acc_size if save_acc else 0
122137 vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
0 commit comments