Commit 78344d6

Fix operator registration bugs with the PyTorch Dispatcher
**Background:** There are two principles for operator registration in PyTorch:

- The same namespace can only be registered once via `TORCH_LIBRARY`.
- An operator signature can only be registered once via `def`.

Therefore, for the first problem, we use `TORCH_LIBRARY_FRAGMENT` to extend an existing namespace with additional operators. For the second problem, registering an implementation does not require redefining the operator signature, so we only need to explicitly define the operators that exist solely in vllm-ascend.

Signed-off-by: fffrog <[email protected]>
1 parent 7d47d8f commit 78344d6
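As a brief illustration of these two rules, here is a minimal, hypothetical sketch (not code from this commit; the op name `my_op`, the namespace `myns`, and the CPU kernel are invented for the example): the schema is declared once in a `TORCH_LIBRARY_FRAGMENT` block, and backend implementations are attached separately with `TORCH_LIBRARY_IMPL`, which does not re-declare the schema.

```cpp
#include <torch/library.h>

namespace demo {
// Hypothetical CPU kernel used only for this sketch.
at::Tensor my_op_cpu(const at::Tensor &x) { return x + 1; }
} // namespace demo

// Rule 1: TORCH_LIBRARY(ns, m) may appear only once per namespace in a process;
// other translation units contributing ops to the same namespace must use
// TORCH_LIBRARY_FRAGMENT instead.
TORCH_LIBRARY_FRAGMENT(myns, m) {
  // Rule 2: each schema may be `def`-ed only once, so only declare ops that
  // are not already defined elsewhere in this namespace.
  m.def("my_op(Tensor x) -> Tensor");
}

// Implementations are registered per dispatch key; this does not redefine the
// schema, so ops whose schemas already exist upstream only need an impl block.
TORCH_LIBRARY_IMPL(myns, CPU, m) {
  m.impl("my_op", &demo::my_op_cpu);
}
```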

File tree

6 files changed (+45, -43 lines)

csrc/torch_binding.cpp

Lines changed: 26 additions & 30 deletions
@@ -38,7 +38,7 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
     }
 }
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
+void rotary_embedding(at::Tensor &positions, at::Tensor &query, std::optional<at::Tensor> key,
                       int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
 {
     int32_t deviceId = 0;
@@ -47,22 +47,23 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     TORCH_CHECK(
         positions_ndim == 1 || positions_ndim == 2,
         "positions must have shape [num_tokens] or [batch_size, seq_len]");
+    TORCH_CHECK(key.has_value(), "rotary_embedding: key must have a value");
     if (positions_ndim == 1) {
         TORCH_CHECK(
-            query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+            query.size(0) == positions.size(0) && key.value().size(0) == positions.size(0),
             "query, key and positions must have the same number of tokens");
     }
     if (positions_ndim == 2) {
         TORCH_CHECK(
             query.size(0) == positions.size(0) &&
-                key.size(0) == positions.size(0) &&
+                key.value().size(0) == positions.size(0) &&
                 query.size(1) == positions.size(1) &&
-                key.size(1) == positions.size(1),
+                key.value().size(1) == positions.size(1),
             "query, key and positions must have the same batch_size and seq_len");
     }
     TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
     int query_hidden_size = query.numel() / num_tokens;
-    int key_hidden_size = key.numel() / num_tokens;
+    int key_hidden_size = key.value().numel() / num_tokens;
     TORCH_CHECK(query_hidden_size % head_size == 0);
     TORCH_CHECK(key_hidden_size % head_size == 0);
     TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
@@ -72,18 +73,18 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     int num_kv_heads = key_hidden_size / head_size;
     TORCH_CHECK(num_heads % num_kv_heads == 0);
     at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
+    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.value().options());
 
     int rot_dim = cos_sin_cache.size(1);
     int seq_dim_idx = positions_ndim - 1;
     int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
     void *query_dst_ptr = query_dst.data_ptr();
     void *key_dst_ptr = key_dst.data_ptr();
     void *query_ptr = query.data_ptr();
-    void *key_ptr = key.data_ptr();
+    void *key_ptr = key.value().data_ptr();
     void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
     int64_t query_stride = query.stride(seq_dim_idx);
-    int64_t key_stride = key.stride(seq_dim_idx);
+    int64_t key_stride = key.value().stride(seq_dim_idx);
     int64_t dst_query_stride = query_dst.stride(0);
     int64_t dst_key_stride = key_dst.stride(0);
     at::ScalarType scalar_type = query.scalar_type();
@@ -104,7 +105,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
         return 0;
     });
     cmd.Run();
-    return {query_dst, key_dst};
+
+    query.copy_(query_dst);
+    key.value().copy_(key_dst);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
@@ -385,43 +388,36 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend
 
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_FRAGMENT_EXPAND(_C, ops)
 {
-    // vLLM-Ascend custom ops
-    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
-    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
-
-    // Rotary embedding
-    // Apply GPT-NeoX style rotary embedding to query and key.
-    ops.def(
-        "rotary_embedding(Tensor positions, Tensor! query,"
-        " Tensor! key, int head_size,"
-        " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
-    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
-
     ops.def(
         "get_masked_input_and_mask(Tensor input, "
         "                          int org_vocab_start_index, "
         "                          int org_vocab_end_index, "
         "                          int num_org_vocab_padding, "
         "                          int added_vocab_start_index, "
         "                          int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
-    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
-
     ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
-    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
-
     ops.def(
         "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
-    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
-
     ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
-    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
-
     ops.def(
         "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(_C, PrivateUse1, ops)
+{
+    // vLLM-Ascend custom ops
+    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
+    // Rotary embedding
+    // Apply GPT-NeoX style rotary embedding to query and key.
+    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
+    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
+    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
+    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
     ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
 }

csrc/torch_binding_meta.cpp

Lines changed: 9 additions & 7 deletions
@@ -36,23 +36,25 @@
 namespace vllm_ascend {
 namespace meta {
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
+void rotary_embedding_meta(
     at::Tensor &positions,
     at::Tensor &query,
-    at::Tensor &key,
+    std::optional<at::Tensor> key,
     int64_t head_size,
     at::Tensor &cos_sin_cache,
     bool is_neox) {
+    TORCH_CHECK(key.has_value(), "rotary_embedding_meta: key must have a value");
     auto num_tokens = positions.sym_numel();
     auto query_hidden_size = query.sym_numel() / num_tokens;
-    auto key_hidden_size = key.sym_numel() / num_tokens;
+    auto key_hidden_size = key.value().sym_numel() / num_tokens;
 
     auto num_heads = query_hidden_size / head_size;
     auto num_kv_heads = key_hidden_size / head_size;
-    at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());
 
-    return {query_dst, key_dst};
+    c10::SymIntArrayRef query_shape({num_tokens, num_heads, head_size});
+    c10::SymIntArrayRef key_shape({num_tokens, num_kv_heads, head_size});
+    query.resize__symint(query_shape);
+    key.value().resize__symint(key_shape);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
@@ -99,4 +101,4 @@ namespace {
     ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
 
 }
-}
+}
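For context, meta ("fake") functions like `rotary_embedding_meta` above exist so that shapes can be inferred without touching device memory during tracing. The registration macro used by this file is not visible in the hunk above, so the following is only a sketch of the usual pattern, assuming the `_EXPAND` wrapper from `csrc/utils.h` and the `Meta` dispatch key:

```cpp
// Sketch only: register the meta function under the Meta dispatch key so that
// FakeTensor tracing / torch.compile can infer output shapes and strides
// without launching the NPU kernel. The exact block in the repository may differ.
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops)
{
    ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
}
```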

csrc/utils.h

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,10 @@
 // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
 // could be a macro instead of a literal token.
 #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+//
+// A version of the TORCH_LIBRARY_FRAGMENT macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_FRAGMENT_EXPAND(NAME, MODULE) TORCH_LIBRARY_FRAGMENT(NAME, MODULE)
 
 // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
 // could be a macro instead of a literal token.
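Usage-wise, the `_EXPAND` wrappers matter when the library name is itself a preprocessor macro rather than a literal token. A minimal, hypothetical usage sketch follows (the op `demo_only_op` is invented; `TORCH_EXTENSION_NAME` is the macro normally defined by PyTorch's extension build):

```cpp
#include <torch/library.h>
#include "utils.h"  // provides TORCH_LIBRARY_FRAGMENT_EXPAND (csrc/utils.h)

// The extra level of indirection lets TORCH_EXTENSION_NAME expand to the real
// library name before it is handed to TORCH_LIBRARY_FRAGMENT.
TORCH_LIBRARY_FRAGMENT_EXPAND(TORCH_EXTENSION_NAME, ops)
{
    ops.def("demo_only_op(Tensor x) -> Tensor");  // hypothetical op for illustration
}
```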

tests/e2e/singlecard/ops/test_rotary_embedding.py

Lines changed: 4 additions & 4 deletions
@@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim(
     )
 
     ref_query, ref_key = rope.forward_native(positions, query, key)
-    query, key = torch.ops._C.rotary_embedding(
+    torch.ops._C.rotary_embedding(
         positions,
         query,
         key,
@@ -239,16 +239,16 @@ def forward(
         # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(3, dim=-1)
-        query, key = torch.ops._C.rotary_embedding(
+        q_shape = q.shape
+        torch.ops._C.rotary_embedding(
            positions,
            q,
            k,
            self.rope.head_size,
            self.rope.cos_sin_cache,
            self.rope.is_neox_style,
        )
-        query = query.view(q.shape)
-        key = key.view(k.shape)
+        query = q.view(q_shape)
        o = self.o_proj(query)
        return o
 

vllm_ascend/ops/rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def _rope_forward_oot(
        # adopt custom kernel path for rotary_embedding
        if _custom_rotary_embedding_enabled(query, neox_style,
                                            self.head_size) and not is_310p():
-            query, key = torch.ops._C.rotary_embedding(
+            torch.ops._C.rotary_embedding(
                positions,
                query,
                key,

vllm_ascend/torchair/ops/torchair_rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def rope_forward_oot(
        # adopt custom kernel path for rotary_embedding
        if custom_rotary_embedding_enabled(query, neox_style,
                                           self.head_size) and not is_310p():
-            query, key = torch.ops._C.rotary_embedding(
+            torch.ops._C.rotary_embedding(
                positions,
                query,
                key,
