@@ -38,7 +38,7 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
     }
 }
 
-std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
+void rotary_embedding(at::Tensor &positions, at::Tensor &query, std::optional<at::Tensor> key,
                       int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
 {
     int32_t deviceId = 0;
@@ -47,22 +47,23 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     TORCH_CHECK(
         positions_ndim == 1 || positions_ndim == 2,
         "positions must have shape [num_tokens] or [batch_size, seq_len]");
+    TORCH_CHECK(key.has_value(), "key must have value");
     if (positions_ndim == 1) {
         TORCH_CHECK(
-            query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+            query.size(0) == positions.size(0) && key.value().size(0) == positions.size(0),
             "query, key and positions must have the same number of tokens");
     }
     if (positions_ndim == 2) {
         TORCH_CHECK(
             query.size(0) == positions.size(0) &&
-                key.size(0) == positions.size(0) &&
+                key.value().size(0) == positions.size(0) &&
                 query.size(1) == positions.size(1) &&
-                key.size(1) == positions.size(1),
+                key.value().size(1) == positions.size(1),
             "query, key and positions must have the same batch_size and seq_len");
     }
     TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
     int query_hidden_size = query.numel() / num_tokens;
-    int key_hidden_size = key.numel() / num_tokens;
+    int key_hidden_size = key.value().numel() / num_tokens;
     TORCH_CHECK(query_hidden_size % head_size == 0);
     TORCH_CHECK(key_hidden_size % head_size == 0);
     TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
@@ -72,18 +73,18 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
     int num_kv_heads = key_hidden_size / head_size;
     TORCH_CHECK(num_heads % num_kv_heads == 0);
     at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
-    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
+    at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.value().options());
 
     int rot_dim = cos_sin_cache.size(1);
     int seq_dim_idx = positions_ndim - 1;
     int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
     void *query_dst_ptr = query_dst.data_ptr();
     void *key_dst_ptr = key_dst.data_ptr();
     void *query_ptr = query.data_ptr();
-    void *key_ptr = key.data_ptr();
+    void *key_ptr = key.value().data_ptr();
     void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
     int64_t query_stride = query.stride(seq_dim_idx);
-    int64_t key_stride = key.stride(seq_dim_idx);
+    int64_t key_stride = key.value().stride(seq_dim_idx);
     int64_t dst_query_stride = query_dst.stride(0);
     int64_t dst_key_stride = key_dst.stride(0);
     at::ScalarType scalar_type = query.scalar_type();
@@ -104,7 +105,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
         return 0;
     });
     cmd.Run();
-    return {query_dst, key_dst};
+
+    query.copy_(query_dst);
+    key.value().copy_(key_dst);
 }
 
 std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
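Note on the hunks above: `rotary_embedding` no longer returns `{query_dst, key_dst}`; it now returns `void`, takes `key` as `std::optional<at::Tensor>`, and copies the rotated values back into the inputs. A minimal caller sketch of the new in-place contract follows — the wrapper name `apply_rope_inplace` and its arguments are illustrative, not part of the patch:

#include <ATen/ATen.h>
#include <optional>

namespace vllm_ascend {
// Declaration matching the patched signature; the definition lives in the extension.
void rotary_embedding(at::Tensor &positions, at::Tensor &query,
                      std::optional<at::Tensor> key, int64_t head_size,
                      at::Tensor &cos_sin_cache, bool is_neox);
} // namespace vllm_ascend

// Hypothetical caller, for illustration only.
void apply_rope_inplace(at::Tensor &positions, at::Tensor &query,
                        at::Tensor &key, at::Tensor &cos_sin_cache,
                        int64_t head_size)
{
    // Old contract: auto [q_rot, k_rot] = rotary_embedding(...);
    // New contract: void return; rotated values are copy_'d back into the inputs.
    std::optional<at::Tensor> key_opt = key; // at::Tensor is a handle, so storage is shared
    vllm_ascend::rotary_embedding(positions, query, key_opt, head_size,
                                  cos_sin_cache, /*is_neox=*/true);
    // query and key now hold the rotated embeddings; no new tensors are allocated for the caller.
}

Because `at::Tensor` is a reference-counted handle, the `std::optional` copy aliases the caller's `key`, so the `copy_` at the end of the kernel wrapper updates the original tensor.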
@@ -385,43 +388,36 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices
 }
 } // namespace vllm_ascend
 
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_FRAGMENT_EXPAND(_C, ops)
 {
-    // vLLM-Ascend custom ops
-    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
-    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
-
-    // Rotary embedding
-    // Apply GPT-NeoX style rotary embedding to query and key.
-    ops.def(
-        "rotary_embedding(Tensor positions, Tensor! query,"
-        " Tensor! key, int head_size,"
-        " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
-    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
-
     ops.def(
         "get_masked_input_and_mask(Tensor input, "
         "int org_vocab_start_index, "
         "int org_vocab_end_index, "
         "int num_org_vocab_padding, "
         "int added_vocab_start_index, "
         "int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
-    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
-
     ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
-    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
-
     ops.def(
         "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
         " int slice_offset, int slice_size) -> Tensor");
-    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
-
     ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
-    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
-
     ops.def(
         "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
         " int slice_offset, int slice_size) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(_C, PrivateUse1, ops)
+{
+    // vLLM-Ascend custom ops
+    ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
+    // Rotary embedding
+    // Apply GPT-NeoX style rotary embedding to query and key.
+    ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+    ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
+    ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
+    ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
+    ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
     ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
 }
 
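The registration hunk follows PyTorch's standard def/impl split: the fragment block only declares op schemas for the `_C` library (a fragment, so other translation units can contribute more definitions), while the `TORCH_LIBRARY_IMPL_EXPAND` block binds the Ascend kernels under the `PrivateUse1` dispatch key. A minimal sketch of the underlying pattern, using the stock `torch/library.h` macros that the repo's `*_EXPAND` wrappers presumably expand to; the op `my_ns::add_one` is hypothetical:

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical kernel standing in for the real vllm-ascend ops.
at::Tensor add_one(const at::Tensor &x) { return x + 1; }

// Schema fragment: declares the op's signature only. Other translation
// units may register further fragments for the same library.
TORCH_LIBRARY_FRAGMENT(my_ns, m) {
    m.def("add_one(Tensor x) -> Tensor");
}

// Backend implementation: binds the kernel for one dispatch key (CPU here;
// the patch uses PrivateUse1 for the Ascend NPU backend).
TORCH_LIBRARY_IMPL(my_ns, CPU, m) {
    m.impl("add_one", &add_one);
}

Keeping schemas in a fragment means the definitions no longer have to live in the same block, or even the same translation unit, as any particular backend's kernel bindings.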