Skip to content

Commit 53b5efd

Browse files
authored
Fix Wq size check
Differential Revision: D85219803 Pull Request resolved: #3228
1 parent 204cd48 commit 53b5efd

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#endif
3434

3535
#define OPERATOR_NAME "rowwise_scaled_linear_sparse_cutlass"
36+
#define PAD_TO_MULTIPLE_OF_16(x) (((x) + 15) / 16 * 16)
3637

3738
namespace torchao {
3839

@@ -448,7 +449,7 @@ check_inputs(
448449
// W_meta may be padded, thus expected shape calculations for this
449450
// tensor are as follows.
450451
const auto W_meta_size_0_expected = std::max((int)Wq_sizes[0], 64);
451-
const auto W_meta_size_1_expected = std::max((int)Wq_sizes[1] / 4, 16);
452+
const auto W_meta_size_1_expected = std::max(PAD_TO_MULTIPLE_OF_16((int)Wq_sizes[1]/4), 16);
452453
TORCH_CHECK(W_meta.size(0) == W_meta_size_0_expected, OPERATOR_NAME,
453454
" : Expected Wq meta argument to have ", W_meta_size_0_expected,
454455
" rows, got ", W_meta.size(0), " rows");

0 commit comments

Comments
 (0)