
Commit 7078cb8

FFFrog committed
Fix bugs in operator registration with the PyTorch Dispatcher

**Background:** There are two principles governing operator registration in PyTorch:

- A namespace can only be registered once by `TORCH_LIBRARY`.
- An operator signature can only be registered once by `def`.

Therefore:

- For the first problem, we can use `TORCH_LIBRARY_FRAGMENT` to add operators to an existing namespace.
- For the second problem, the best fix is to define all the general operator schemas in vLLM itself instead of in every plugin repo.

Signed-off-by: FFFrog <[email protected]>
1 parent 0c0789b commit 7078cb8
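To make the two principles concrete, here is a minimal, compilable sketch; the namespace `my_ns` and the op names are placeholders invented for illustration, not part of this commit:

```cpp
#include <torch/library.h>

// Exactly one TORCH_LIBRARY block may own a namespace. A second
// TORCH_LIBRARY(my_ns, ...) block, e.g. in a plugin shared library,
// aborts at load time with a duplicate-registration error.
TORCH_LIBRARY(my_ns, m) {
    m.def("op_a(Tensor x) -> Tensor");
}

// TORCH_LIBRARY_FRAGMENT extends the namespace instead of claiming it,
// so any number of fragments can be spread across files and plugins.
TORCH_LIBRARY_FRAGMENT(my_ns, m) {
    m.def("op_b(Tensor x) -> Tensor");
}
```

The same restriction applies at the operator level: `def` may introduce a given schema only once per process, which is why the shared schemas belong centrally in vLLM while each plugin registers only its backend kernels.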

File tree

8 files changed: +79 −59 lines changed

csrc/torch_binding.cpp

Lines changed: 36 additions & 42 deletions
@@ -38,7 +38,7 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
     }
 }
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
+void rotary_embedding(at::Tensor &positions, at::Tensor &query, std::optional<at::Tensor> key,
                       int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
 {
     int32_t deviceId = 0;
@@ -47,22 +47,23 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     TORCH_CHECK(
         positions_ndim == 1 || positions_ndim == 2,
         "positions must have shape [num_tokens] or [batch_size, seq_len]");
+    TORCH_CHECK(key.has_value(), "rotary_embedding: key must have a value");
     if (positions_ndim == 1) {
         TORCH_CHECK(
-            query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+            query.size(0) == positions.size(0) && key.value().size(0) == positions.size(0),
             "query, key and positions must have the same number of tokens");
     }
     if (positions_ndim == 2) {
         TORCH_CHECK(
             query.size(0) == positions.size(0) &&
-                key.size(0) == positions.size(0) &&
+                key.value().size(0) == positions.size(0) &&
                 query.size(1) == positions.size(1) &&
-                key.size(1) == positions.size(1),
+                key.value().size(1) == positions.size(1),
             "query, key and positions must have the same batch_size and seq_len");
     }
     TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
     int query_hidden_size = query.numel() / num_tokens;
-    int key_hidden_size = key.numel() / num_tokens;
+    int key_hidden_size = key.value().numel() / num_tokens;
     TORCH_CHECK(query_hidden_size % head_size == 0);
     TORCH_CHECK(key_hidden_size % head_size == 0);
     TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
@@ -72,18 +73,18 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     int num_kv_heads = key_hidden_size / head_size;
     TORCH_CHECK(num_heads % num_kv_heads == 0);
     at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
+    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.value().options());
 
     int rot_dim = cos_sin_cache.size(1);
     int seq_dim_idx = positions_ndim - 1;
     int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
     void *query_dst_ptr = query_dst.data_ptr();
     void *key_dst_ptr = key_dst.data_ptr();
     void *query_ptr = query.data_ptr();
-    void *key_ptr = key.data_ptr();
+    void *key_ptr = key.value().data_ptr();
     void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
     int64_t query_stride = query.stride(seq_dim_idx);
-    int64_t key_stride = key.stride(seq_dim_idx);
+    int64_t key_stride = key.value().stride(seq_dim_idx);
     int64_t dst_query_stride = query_dst.stride(0);
     int64_t dst_key_stride = key_dst.stride(0);
     at::ScalarType scalar_type = query.scalar_type();
@@ -104,7 +105,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
         return 0;
     });
     cmd.Run();
-    return {query_dst, key_dst};
+
+    query.copy_(query_dst);
+    key.value().copy_(key_dst);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
@@ -142,7 +145,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     TP2, rank 1:
     |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
     corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
-    index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | 
+    index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
     Parameters:
     org_vocab_start_index //base embeddings start
     org_vocab_end_index //base embeddings end
@@ -165,22 +168,22 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     // Create output tensors
     at::Tensor masked_input = at::empty_like(input);
     at::Tensor mask = at::empty_like(input).to(at::kBool);
-    
+
     // Get data pointers
     void *input_ptr = input.data_ptr();
     void *masked_input_ptr = masked_input.data_ptr();
     void *mask_ptr = mask.data_ptr();
-    
+
     // Get current stream
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
-    
+
     // Get scalar type
     at::ScalarType scalar_type = input.scalar_type();
-    
+
     // Create and configure OpCommand
     at_npu::native::OpCommand cmd;
     cmd.Name("get_masked_input_and_mask");
-    cmd.SetCustomHandler([scalar_type, size, stream, 
+    cmd.SetCustomHandler([scalar_type, size, stream,
                           input_ptr, masked_input_ptr, mask_ptr,
                           org_vocab_start_index, org_vocab_end_index,
                           num_org_vocab_padding, added_vocab_start_index,
@@ -194,7 +197,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
             get_masked_input_and_mask_impl(
                 stream,
                 input_ptr,
-                masked_input_ptr, 
+                masked_input_ptr,
                 mask_ptr,
                 org_vocab_start_index,
                 org_vocab_end_index,
@@ -204,7 +207,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
                 size,
                 loop_cnt,
                 aiv_num);
-        
+
             return 0;
         });
     cmd.Run();
@@ -321,8 +324,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
     at_npu::native::OpCommand cmd;
     cmd.Name("sgmv_shrink");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, 
-                          seq_len_ptr, seq_len_size, y_ptr, 
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
+                          seq_len_ptr, seq_len_size, y_ptr,
                           batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
         auto dtype = get_dtype_from_torch(scalar_type);
         int device_id = 0;
@@ -331,7 +334,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
         int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
         TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
         sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
-                         y_ptr, batch_size, 
+                         y_ptr, batch_size,
                          num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
         return 0;
     });
@@ -368,15 +371,15 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
     at_npu::native::OpCommand cmd;
     cmd.Name("sgmv_expand");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, 
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                           batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
         auto dtype = get_dtype_from_torch(scalar_type);
         int device_id = 0;
         int64_t aiv_num = 0;
         TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
         int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
         TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
-        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, 
+        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                          batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
         return 0;
     });
@@ -385,43 +388,34 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend
 
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_FRAGMENT_EXPAND(_C, ops)
 {
-    // vLLM-Ascend custom ops
     ops.def("weak_ref_tensor(Tensor input) -> Tensor");
-    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
-
-    // Rotary embedding
-    // Apply GPT-NeoX style rotary embedding to query and key.
-    ops.def(
-        "rotary_embedding(Tensor positions, Tensor! query,"
-        " Tensor! key, int head_size,"
-        " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
-    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
-
     ops.def(
         "get_masked_input_and_mask(Tensor input, "
         "                          int org_vocab_start_index, "
        "                          int org_vocab_end_index, "
        "                          int num_org_vocab_padding, "
        "                          int added_vocab_start_index, "
        "                          int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
-    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
-
     ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
-    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
-
     ops.def(
         "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
         " int slice_offset, int slice_size) -> Tensor");
-    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
-
     ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
-    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
-
     ops.def(
         "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
         " int slice_offset, int slice_size) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(_C, PrivateUse1, ops)
+{
+    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
+    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
+    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
+    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
+    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
     ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
 }
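In summary, this file now follows the def/impl split. A hedged sketch of the pattern, with a placeholder namespace and names rather than the literal vllm-ascend code: the schema marks mutable arguments with `!` and returns nothing, so the kernel writes its results back into the inputs, while the backend implementation is registered separately under the `PrivateUse1` dispatch key.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>
#include <optional>

// In-place kernel: results are copied back into the mutable inputs.
void rope_inplace(at::Tensor &positions, at::Tensor &query,
                  std::optional<at::Tensor> key, int64_t head_size,
                  at::Tensor &cos_sin_cache, bool is_neox) {
    at::Tensor query_dst = at::empty_like(query);
    // ... compute the rotated values into query_dst (and a key_dst) ...
    query.copy_(query_dst);  // callers observe the mutation in place
}

// Schemas: defined once per process (per the commit message, ideally by
// vLLM itself rather than by each plugin).
TORCH_LIBRARY_FRAGMENT(my_ns, ops) {
    ops.def("rope_inplace(Tensor positions, Tensor! query, Tensor!? key,"
            " int head_size, Tensor cos_sin_cache, bool is_neox) -> ()");
}

// Kernels: each backend registers its own implementation for the schema.
TORCH_LIBRARY_IMPL(my_ns, PrivateUse1, ops) {
    ops.impl("rope_inplace", &rope_inplace);
}
```

Splitting `def` from `impl` means the schema can later move into vLLM core while this repo keeps only the `PrivateUse1` implementations.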

csrc/torch_binding_meta.cpp

Lines changed: 9 additions & 7 deletions
@@ -36,23 +36,25 @@
 namespace vllm_ascend {
 namespace meta {
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
+void rotary_embedding_meta(
     at::Tensor &positions,
     at::Tensor &query,
-    at::Tensor &key,
+    std::optional<at::Tensor> key,
     int64_t head_size,
     at::Tensor &cos_sin_cache,
     bool is_neox) {
+    TORCH_CHECK(key.has_value(), "rotary_embedding_meta: key must have a value");
     auto num_tokens = positions.sym_numel();
     auto query_hidden_size = query.sym_numel() / num_tokens;
-    auto key_hidden_size = key.sym_numel() / num_tokens;
+    auto key_hidden_size = key.value().sym_numel() / num_tokens;
 
     auto num_heads = query_hidden_size / head_size;
     auto num_kv_heads = key_hidden_size / head_size;
-    at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());
 
-    return {query_dst, key_dst};
+    c10::SymIntArrayRef query_shape({num_tokens, num_heads, head_size});
+    c10::SymIntArrayRef key_shape({num_tokens, num_kv_heads, head_size});
+    query.resize__symint(query_shape);
+    key.value().resize__symint(key_shape);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
@@ -99,4 +101,4 @@ namespace {
     ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
 
 }
-}
\ No newline at end of file
+}
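For context, a meta ("fake tensor") kernel runs during tracing and compilation without real storage; it exists only to propagate shapes and dtypes. Since the real kernel is now in-place, the meta version must mutate its inputs the same way instead of allocating and returning fresh outputs. A minimal sketch under the same placeholder names as above, assuming the `resize__symint` method this commit uses:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>
#include <optional>

// Meta kernel for an in-place op: resize the (fake) inputs symbolically so
// the traced graph observes the same output shape as the real kernel.
void rope_inplace_meta(at::Tensor &positions, at::Tensor &query,
                       std::optional<at::Tensor> key, int64_t head_size,
                       at::Tensor &cos_sin_cache, bool is_neox) {
    auto num_tokens = positions.sym_numel();
    auto num_heads = query.sym_numel() / num_tokens / head_size;
    query.resize__symint({num_tokens, num_heads, c10::SymInt(head_size)});
}

// Fake-tensor tracing dispatches to the Meta key.
TORCH_LIBRARY_IMPL(my_ns, Meta, ops) {
    ops.impl("rope_inplace", &rope_inplace_meta);
}
```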

csrc/utils.h

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,10 @@
 // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
 // could be a macro instead of a literal token.
 #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+//
+// A version of the TORCH_LIBRARY_FRAGMENT macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_FRAGMENT_EXPAND(NAME, MODULE) TORCH_LIBRARY_FRAGMENT(NAME, MODULE)
 
 // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
 // could be a macro instead of a literal token.
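A short illustration of the indirection these wrappers provide (`MY_LIB` and `my_ns` are hypothetical): the `TORCH_LIBRARY*` macros paste NAME into generated identifiers, and the preprocessor's `##` operator suppresses argument expansion, so passing a macro directly can leave it unexpanded; routing it through a wrapper whose body does no pasting forces one expansion pass first.

```cpp
#include <torch/library.h>

#define MY_LIB my_ns  // the library name is itself a macro

// Passed straight to the underlying macro, MY_LIB could be token-pasted
// before being expanded; the _EXPAND wrapper expands MY_LIB to my_ns first.
TORCH_LIBRARY_FRAGMENT_EXPAND(MY_LIB, ops) {
    ops.def("noop(Tensor x) -> Tensor");
}
```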

tests/e2e/singlecard/ops/test_rotary_embedding.py

Lines changed: 4 additions & 4 deletions
@@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim(
     )
 
     ref_query, ref_key = rope.forward_native(positions, query, key)
-    query, key = torch.ops._C.rotary_embedding(
+    torch.ops._C.rotary_embedding(
         positions,
         query,
         key,
@@ -239,16 +239,16 @@ def forward(
         # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(3, dim=-1)
-        query, key = torch.ops._C.rotary_embedding(
+        q_shape = q.shape
+        torch.ops._C.rotary_embedding(
             positions,
             q,
             k,
             self.rope.head_size,
             self.rope.cos_sin_cache,
             self.rope.is_neox_style,
         )
-        query = query.view(q.shape)
-        key = key.view(k.shape)
+        query = q.view(q_shape)
         o = self.o_proj(query)
         return o
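A note on the second hunk: the callers stop unpacking return values because `rotary_embedding` is now in-place, and `q.shape` is saved in `q_shape` before the call. Presumably this is because the op's meta function resizes its inputs to `[num_tokens, num_heads, head_size]` under fake-tensor tracing, so the pre-call shape is needed to view the mutated `q` back to the projection layout.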

tests/ut/test_utils.py

Lines changed: 14 additions & 0 deletions
@@ -261,6 +261,20 @@ def test_update_aclgraph_sizes(self):
         self.assertEqual(
             147,
             len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
+
+        test_vllm_config.speculative_config = mock.MagicMock()
+        test_vllm_config.speculative_config.draft_model_config = mock.MagicMock(
+        )
+        test_vllm_config.speculative_config.draft_model_config.hf_config = mock.MagicMock(
+        )
+        test_vllm_config.speculative_config.draft_model_config.hf_config.num_hidden_layers = 2
+        os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
+        utils.update_aclgraph_sizes(test_vllm_config)
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+        self.assertEqual(
+            120,
+            len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
+
         # max_num_batch_sizes >= len(original_sizes)
         test_compilation_config = CompilationConfig(
             cudagraph_capture_sizes=[1, 2, 3])

vllm_ascend/ops/rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def _rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if _custom_rotary_embedding_enabled(query, neox_style,
                                         self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        torch.ops._C.rotary_embedding(
            positions,
            query,
            key,

vllm_ascend/torchair/ops/torchair_rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if custom_rotary_embedding_enabled(query, neox_style,
                                        self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        torch.ops._C.rotary_embedding(
            positions,
            query,
            key,

vllm_ascend/utils.py

Lines changed: 10 additions & 4 deletions
@@ -304,6 +304,12 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     num_hidden_layers = get_max_hidden_layers(hf_config)
     parallel_config = vllm_config.parallel_config
 
+    # Calculate maximum supported batch sizes considering model architecture
+    resources_per_graph = num_hidden_layers + 1
+    if vllm_config.speculative_config is not None:
+        draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config
+        resources_per_graph += draft_model_hf_config.num_hidden_layers + 1
+
     # TODO: Find out whether we need to take into account the pp_size
     num_comm_groups = sum(size > 1 for size in [
         parallel_config.data_parallel_size,
@@ -318,8 +324,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         # Assume the following case:
         # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
         # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
-        max_num_batch_sizes = math.floor(
-            MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)
+        max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
+                                         resources_per_graph / parallel_factor)
         logger.info(
             "Calculated maximum supported batch sizes for ACL graph: %s",
             max_num_batch_sizes)
@@ -335,8 +341,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
         # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
         max_num_batch_sizes = math.floor(
-            (MAX_CAPTURE_SIZE - num_comm_groups * 40) /
-            (num_hidden_layers + 1) / (1 + num_comm_groups * 2))
+            (MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph /
+            (1 + num_comm_groups * 2))
         logger.info(
             "Calculated maximum supported batch sizes for ACL graph: %s",
             max_num_batch_sizes)
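A worked example of the new formula, reusing the numbers from the comments above and assuming a hypothetical 2-layer draft model: without speculative decoding, resources_per_graph = 48 + 1 = 49 and max_num_batch_sizes = math.floor(1920 / 49 / 2) = 19; with the draft model, resources_per_graph = 49 + (2 + 1) = 52 and max_num_batch_sizes = math.floor(1920 / 52 / 2) = 18, so each captured graph now also pays for the draft model's layers.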
