
Commit 9b0e91d

Fix bugs in operator registration with the PyTorch Dispatcher

Register op schemas and their PrivateUse1 implementations separately, and make the custom rotary_embedding op update query and key in place instead of returning new tensors.

Signed-off-by: fffrog <[email protected]>
1 parent 7d47d8f commit 9b0e91d

File tree (6 files changed, 45 additions, 46 deletions):

csrc/torch_binding.cpp
csrc/torch_binding_meta.cpp
csrc/utils.h
tests/e2e/singlecard/ops/test_rotary_embedding.py
vllm_ascend/ops/rotary_embedding.py
vllm_ascend/torchair/ops/torchair_rotary_embedding.py

csrc/torch_binding.cpp

26 additions, 30 deletions

@@ -38,7 +38,7 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
     }
 }
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
+void rotary_embedding(at::Tensor &positions, at::Tensor &query, std::optional<at::Tensor> key,
                       int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
 {
     int32_t deviceId = 0;
@@ -47,22 +47,23 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     TORCH_CHECK(
         positions_ndim == 1 || positions_ndim == 2,
         "positions must have shape [num_tokens] or [batch_size, seq_len]");
+    TORCH_CHECK(key.has_value(), "key must have value");
     if (positions_ndim == 1) {
         TORCH_CHECK(
-            query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+            query.size(0) == positions.size(0) && key.value().size(0) == positions.size(0),
             "query, key and positions must have the same number of tokens");
     }
     if (positions_ndim == 2) {
         TORCH_CHECK(
             query.size(0) == positions.size(0) &&
-                key.size(0) == positions.size(0) &&
+                key.value().size(0) == positions.size(0) &&
                 query.size(1) == positions.size(1) &&
-                key.size(1) == positions.size(1),
+                key.value().size(1) == positions.size(1),
             "query, key and positions must have the same batch_size and seq_len");
     }
     TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
     int query_hidden_size = query.numel() / num_tokens;
-    int key_hidden_size = key.numel() / num_tokens;
+    int key_hidden_size = key.value().numel() / num_tokens;
     TORCH_CHECK(query_hidden_size % head_size == 0);
     TORCH_CHECK(key_hidden_size % head_size == 0);
     TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
@@ -72,18 +73,18 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     int num_kv_heads = key_hidden_size / head_size;
     TORCH_CHECK(num_heads % num_kv_heads == 0);
     at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
+    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.value().options());
 
     int rot_dim = cos_sin_cache.size(1);
     int seq_dim_idx = positions_ndim - 1;
     int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
     void *query_dst_ptr = query_dst.data_ptr();
     void *key_dst_ptr = key_dst.data_ptr();
     void *query_ptr = query.data_ptr();
-    void *key_ptr = key.data_ptr();
+    void *key_ptr = key.value().data_ptr();
     void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
     int64_t query_stride = query.stride(seq_dim_idx);
-    int64_t key_stride = key.stride(seq_dim_idx);
+    int64_t key_stride = key.value().stride(seq_dim_idx);
     int64_t dst_query_stride = query_dst.stride(0);
     int64_t dst_key_stride = key_dst.stride(0);
     at::ScalarType scalar_type = query.scalar_type();
@@ -104,7 +105,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
         return 0;
     });
     cmd.Run();
-    return {query_dst, key_dst};
+
+    query.copy_(query_dst);
+    key.value().copy_(key_dst);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
@@ -385,43 +388,36 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend
 
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_FRAGMENT_EXPAND(_C, ops)
 {
-    // vLLM-Ascend custom ops
-    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
-    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
-
-    // Rotary embedding
-    // Apply GPT-NeoX style rotary embedding to query and key.
-    ops.def(
-        "rotary_embedding(Tensor positions, Tensor! query,"
-        " Tensor! key, int head_size,"
-        " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
-    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
-
     ops.def(
         "get_masked_input_and_mask(Tensor input, "
         " int org_vocab_start_index, "
         " int org_vocab_end_index, "
         " int num_org_vocab_padding, "
        " int added_vocab_start_index, "
        " int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
-    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
-
     ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
-    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
-
     ops.def(
         "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
         " int slice_offset, int slice_size) -> Tensor");
-    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
-
     ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
-    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
-
     ops.def(
         "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(_C, PrivateUse1, ops)
+{
+    // vLLM-Ascend custom ops
+    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
+    // Rotary embedding
+    // Apply GPT-NeoX style rotary embedding to query and key.
+    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
+    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
+    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
+    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
     ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
 }
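The fix hinges on separating schema definition from backend implementation: the fragment block above only declares op schemas, while the kernels are bound to the PrivateUse1 (NPU) dispatch key in a separate impl block. Below is a minimal standalone sketch of that pattern; it is not part of this commit and uses a hypothetical my_ops namespace, a toy scale_ op, and the CPU dispatch key purely for illustration (vllm-ascend registers under _C and PrivateUse1 through its *_EXPAND macro wrappers).

#include <ATen/ATen.h>
#include <torch/library.h>

namespace {
// Toy in-place kernel used only for illustration.
void scale_(at::Tensor &x, double factor) {
    x.mul_(factor);
}
} // namespace

// A fragment contributes schema definitions to a library that other translation
// units (or another project) may also add to, instead of claiming sole ownership
// of the library the way a plain TORCH_LIBRARY block does.
TORCH_LIBRARY_FRAGMENT(my_ops, ops) {
    ops.def("scale_(Tensor! x, float factor) -> ()");
}

// The kernel is registered separately, keyed by dispatch key, so the schema
// above stays backend-agnostic.
TORCH_LIBRARY_IMPL(my_ops, CPU, ops) {
    ops.impl("scale_", &scale_);
}

From Python such an op would be invoked as torch.ops.my_ops.scale_(t, 2.0), with the dispatcher routing the call to the kernel registered for the tensor's backend.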

csrc/torch_binding_meta.cpp

8 additions, 7 deletions

@@ -36,23 +36,24 @@
 namespace vllm_ascend {
 namespace meta {
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
+void rotary_embedding_meta(
     at::Tensor &positions,
     at::Tensor &query,
-    at::Tensor &key,
+    std::optional<at::Tensor> key,
     int64_t head_size,
     at::Tensor &cos_sin_cache,
     bool is_neox) {
     auto num_tokens = positions.sym_numel();
     auto query_hidden_size = query.sym_numel() / num_tokens;
-    auto key_hidden_size = key.sym_numel() / num_tokens;
+    auto key_hidden_size = key.value().sym_numel() / num_tokens;
 
     auto num_heads = query_hidden_size / head_size;
     auto num_kv_heads = key_hidden_size / head_size;
-    at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());
 
-    return {query_dst, key_dst};
+    c10::SymIntArrayRef query_shape({num_tokens, num_heads, head_size});
+    c10::SymIntArrayRef key_shape({num_tokens, num_kv_heads, head_size});
+    query.resize__symint(query_shape);
+    key.value().resize__symint(key_shape);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
@@ -99,4 +100,4 @@ namespace {
     ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
 
 }
-}
+}
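With the schema now declaring rotary_embedding as an in-place op, the meta function above no longer allocates and returns fresh tensors; it only adjusts the metadata of the tensors it is handed, which is all that fake-tensor tracing (e.g. under torch.compile) needs. The sketch below illustrates that idea under the assumption that such meta kernels are registered against the Meta dispatch key; the my_ops namespace and rope_like_ op are hypothetical, not the project's code.

#include <vector>

#include <ATen/ATen.h>
#include <c10/core/SymInt.h>
#include <torch/library.h>

namespace {
// Meta kernels run on fake tensors, so they may only touch sizes/strides/dtype.
void rope_like__meta(const at::Tensor &positions, at::Tensor &q, int64_t head_size) {
    auto num_tokens = positions.sym_numel();        // symbolic arithmetic, no data access
    auto hidden_size = q.sym_numel() / num_tokens;
    auto num_heads = hidden_size / head_size;
    std::vector<c10::SymInt> shape{num_tokens, num_heads, c10::SymInt(head_size)};
    q.resize__symint(c10::SymIntArrayRef(shape));   // metadata-only resize of the fake tensor
}
} // namespace

TORCH_LIBRARY_FRAGMENT(my_ops, ops) {
    ops.def("rope_like_(Tensor positions, Tensor! q, int head_size) -> ()");
}

TORCH_LIBRARY_IMPL(my_ops, Meta, ops) {
    ops.impl("rope_like_", &rope_like__meta);
}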

csrc/utils.h

4 additions, 0 deletions

@@ -13,6 +13,10 @@
 // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
 // could be a macro instead of a literal token.
 #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+//
+// A version of the TORCH_LIBRARY_FRAGMENT macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_FRAGMENT_EXPAND(NAME, MODULE) TORCH_LIBRARY_FRAGMENT(NAME, MODULE)
 
 // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
 // could be a macro instead of a literal token.

tests/e2e/singlecard/ops/test_rotary_embedding.py

3 additions, 5 deletions

@@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim(
     )
 
     ref_query, ref_key = rope.forward_native(positions, query, key)
-    query, key = torch.ops._C.rotary_embedding(
+    torch.ops._C.rotary_embedding(
         positions,
         query,
         key,
@@ -239,17 +239,15 @@ def forward(
         # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(3, dim=-1)
-        query, key = torch.ops._C.rotary_embedding(
+        torch.ops._C.rotary_embedding(
             positions,
             q,
             k,
             self.rope.head_size,
             self.rope.cos_sin_cache,
             self.rope.is_neox_style,
         )
-        query = query.view(q.shape)
-        key = key.view(k.shape)
-        o = self.o_proj(query)
+        o = self.o_proj(q)
         return o
 
 

vllm_ascend/ops/rotary_embedding.py

2 additions, 2 deletions

@@ -51,15 +51,15 @@ def _rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if _custom_rotary_embedding_enabled(query, neox_style,
                                         self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        torch.ops._C.rotary_embedding(
             positions,
             query,
             key,
             self.head_size,
             self.cos_sin_cache,
             neox_style,
         )
-        return query.view(query_shape), key.view(key_shape)
+        return query, key
     if offsets is not None:
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")

vllm_ascend/torchair/ops/torchair_rotary_embedding.py

2 additions, 2 deletions

@@ -62,15 +62,15 @@ def rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if custom_rotary_embedding_enabled(query, neox_style,
                                        self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        torch.ops._C.rotary_embedding(
             positions,
             query,
             key,
             self.head_size,
             self.cos_sin_cache,
             neox_style,
         )
-        return query.view(query_shape), key.view(key_shape)
+        return query, key
     if offsets is not None:
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
