@@ -38,7 +38,7 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
     }
 }
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
+void rotary_embedding(at::Tensor &positions, at::Tensor &query, std::optional<at::Tensor> key,
     int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
 {
     int32_t deviceId = 0;
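For reference, the new signature (void return, optional key) would line up with a registry schema of roughly the shape sketched below. This is only an illustrative assumption based on the C++ change; the actual schema definition is not part of this patch and now lives outside this file.

    // Illustrative only: an approximate schema matching a void-returning,
    // optional-key rotary_embedding. Not taken from this patch.
    ops.def(
        "rotary_embedding(Tensor positions, Tensor! query,"
        " Tensor!? key, int head_size,"
        " Tensor cos_sin_cache, bool is_neox) -> ()");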
@@ -47,22 +47,23 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     TORCH_CHECK(
         positions_ndim == 1 || positions_ndim == 2,
         "positions must have shape [num_tokens] or [batch_size, seq_len]");
+    TORCH_CHECK(key.has_value(), "rotary_embedding: key must have a value");
     if (positions_ndim == 1) {
         TORCH_CHECK(
-            query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+            query.size(0) == positions.size(0) && key.value().size(0) == positions.size(0),
             "query, key and positions must have the same number of tokens");
     }
     if (positions_ndim == 2) {
         TORCH_CHECK(
             query.size(0) == positions.size(0) &&
-                key.size(0) == positions.size(0) &&
+                key.value().size(0) == positions.size(0) &&
                 query.size(1) == positions.size(1) &&
-                key.size(1) == positions.size(1),
+                key.value().size(1) == positions.size(1),
             "query, key and positions must have the same batch_size and seq_len");
     }
     TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
     int query_hidden_size = query.numel() / num_tokens;
-    int key_hidden_size = key.numel() / num_tokens;
+    int key_hidden_size = key.value().numel() / num_tokens;
     TORCH_CHECK(query_hidden_size % head_size == 0);
     TORCH_CHECK(key_hidden_size % head_size == 0);
     TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
@@ -72,18 +73,18 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     int num_kv_heads = key_hidden_size / head_size;
     TORCH_CHECK(num_heads % num_kv_heads == 0);
     at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
+    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.value().options());
 
     int rot_dim = cos_sin_cache.size(1);
     int seq_dim_idx = positions_ndim - 1;
     int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
     void *query_dst_ptr = query_dst.data_ptr();
     void *key_dst_ptr = key_dst.data_ptr();
     void *query_ptr = query.data_ptr();
-    void *key_ptr = key.data_ptr();
+    void *key_ptr = key.value().data_ptr();
     void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
     int64_t query_stride = query.stride(seq_dim_idx);
-    int64_t key_stride = key.stride(seq_dim_idx);
+    int64_t key_stride = key.value().stride(seq_dim_idx);
     int64_t dst_query_stride = query_dst.stride(0);
     int64_t dst_key_stride = key_dst.stride(0);
     at::ScalarType scalar_type = query.scalar_type();
@@ -104,7 +105,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
         return 0;
     });
     cmd.Run();
-    return {query_dst, key_dst};
+
+    query.copy_(query_dst);
+    key.value().copy_(key_dst);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
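The behavioral change above is that rotary_embedding no longer returns fresh tensors: it rotates into temporaries and copies the results back into query and key. A minimal caller-side sketch of that contract follows; all names, shapes, and dtypes below are illustrative assumptions, and in practice the tensors would be allocated on the Ascend (PrivateUse1) device rather than the CPU.

    // Sketch only: assumes the declaration of vllm_ascend::rotary_embedding
    // from this file's header is visible.
    #include <ATen/ATen.h>
    #include <optional>

    void rotary_embedding_example()
    {
        const int64_t num_tokens = 4, num_heads = 8, num_kv_heads = 2;
        const int64_t head_size = 64, rot_dim = 64, max_position = 128;

        at::Tensor positions = at::arange(num_tokens, at::kLong);            // [num_tokens]
        at::Tensor query = at::randn({num_tokens, num_heads * head_size});   // flattened heads
        std::optional<at::Tensor> key = at::randn({num_tokens, num_kv_heads * head_size});
        at::Tensor cos_sin_cache = at::randn({max_position, rot_dim});

        // Post-patch contract: returns void, and the rotated values are
        // copied back into `query` and `key.value()` instead of being
        // returned as a new (query, key) tuple.
        vllm_ascend::rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                                      /*is_neox=*/true);
    }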
@@ -385,43 +388,36 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend
 
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_FRAGMENT_EXPAND(_C, ops)
 {
-    // vLLM-Ascend custom ops
-    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
-    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
-
-    // Rotary embedding
-    // Apply GPT-NeoX style rotary embedding to query and key.
-    ops.def(
-        "rotary_embedding(Tensor positions, Tensor! query,"
-        " Tensor! key, int head_size,"
-        " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
-    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
-
     ops.def(
         "get_masked_input_and_mask(Tensor input, "
         "int org_vocab_start_index, "
         "int org_vocab_end_index, "
         "int num_org_vocab_padding, "
         "int added_vocab_start_index, "
         "int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
-    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
-
     ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
-    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
-
     ops.def(
         "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
-    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
-
     ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
-    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
-
     ops.def(
         "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
         " int slice_offset, int slice_size) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(_C, PrivateUse1, ops)
+{
+    // vLLM-Ascend custom ops
+    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
+    // Rotary embedding
+    // Apply GPT-NeoX style rotary embedding to query and key.
+    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
+    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
+    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
+    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
     ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
 }
 
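The registration change splits schema definitions (TORCH_LIBRARY_FRAGMENT_EXPAND) from backend bindings (TORCH_LIBRARY_IMPL_EXPAND). Assuming those *_EXPAND macros are thin wrappers over PyTorch's stock registration macros, the pattern corresponds roughly to the standalone sketch below; the namespace and op names are made up for illustration.

    #include <ATen/ATen.h>
    #include <torch/library.h>

    // Schema-only fragment: declares the operator without binding a kernel,
    // so multiple libraries/backends can contribute their own implementations.
    TORCH_LIBRARY_FRAGMENT(demo_ops, ops)
    {
        ops.def("demo_add(Tensor a, Tensor b) -> Tensor");
    }

    at::Tensor demo_add_cpu(const at::Tensor &a, const at::Tensor &b)
    {
        return a + b;
    }

    // Backend-specific bindings for one dispatch key; the patch above does the
    // same for the Ascend backend via the PrivateUse1 key.
    TORCH_LIBRARY_IMPL(demo_ops, CPU, ops)
    {
        ops.impl("demo_add", &demo_add_cpu);
    }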