@@ -38,7 +38,7 @@ AscendType get_dtype_from_torch(at::ScalarType scalarType)
    }
}

- std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
+ void rotary_embedding(at::Tensor &positions, at::Tensor &query, std::optional<at::Tensor> key,
    int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
    int32_t deviceId = 0;
@@ -47,22 +47,23 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
    TORCH_CHECK(
        positions_ndim == 1 || positions_ndim == 2,
        "positions must have shape [num_tokens] or [batch_size, seq_len]");
+     TORCH_CHECK(key.has_value(), "rotary_embedding: key must have a value");
    if (positions_ndim == 1) {
        TORCH_CHECK(
-             query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+             query.size(0) == positions.size(0) && key.value().size(0) == positions.size(0),
            "query, key and positions must have the same number of tokens");
    }
    if (positions_ndim == 2) {
        TORCH_CHECK(
            query.size(0) == positions.size(0) &&
-                 key.size(0) == positions.size(0) &&
+                 key.value().size(0) == positions.size(0) &&
            query.size(1) == positions.size(1) &&
-                 key.size(1) == positions.size(1),
+                 key.value().size(1) == positions.size(1),
            "query, key and positions must have the same batch_size and seq_len");
    }
    TORCH_CHECK(head_size % 32 == 0, "rotary_embedding: headSize should be divisible by 32");
    int query_hidden_size = query.numel() / num_tokens;
-     int key_hidden_size = key.numel() / num_tokens;
+     int key_hidden_size = key.value().numel() / num_tokens;
    TORCH_CHECK(query_hidden_size % head_size == 0);
    TORCH_CHECK(key_hidden_size % head_size == 0);
    TORCH_CHECK(is_neox == true, "rotary_embedding: neox=false is not supported as custom kernel in vllm-ascend");
@@ -72,18 +73,18 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
    int num_kv_heads = key_hidden_size / head_size;
    TORCH_CHECK(num_heads % num_kv_heads == 0);
    at::Tensor query_dst = at::empty({num_tokens, num_heads, head_size}, query.options());
-     at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.options());
+     at::Tensor key_dst = at::empty({num_tokens, num_kv_heads, head_size}, key.value().options());

    int rot_dim = cos_sin_cache.size(1);
    int seq_dim_idx = positions_ndim - 1;
    int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
    void *query_dst_ptr = query_dst.data_ptr();
    void *key_dst_ptr = key_dst.data_ptr();
    void *query_ptr = query.data_ptr();
-     void *key_ptr = key.data_ptr();
+     void *key_ptr = key.value().data_ptr();
    void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
    int64_t query_stride = query.stride(seq_dim_idx);
-     int64_t key_stride = key.stride(seq_dim_idx);
+     int64_t key_stride = key.value().stride(seq_dim_idx);
    int64_t dst_query_stride = query_dst.stride(0);
    int64_t dst_key_stride = key_dst.stride(0);
    at::ScalarType scalar_type = query.scalar_type();
@@ -104,7 +105,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
        return 0;
    });
    cmd.Run();
-     return {query_dst, key_dst};
+
+     query.copy_(query_dst);
+     key.value().copy_(key_dst);
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
@@ -142,7 +145,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    TP2, rank 1:
    |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
-     index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
+     index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
    Parameters:
        org_vocab_start_index //base embeddings start
        org_vocab_end_index //base embeddings end
@@ -165,22 +168,22 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
    // Create output tensors
    at::Tensor masked_input = at::empty_like(input);
    at::Tensor mask = at::empty_like(input).to(at::kBool);
- 
+
    // Get data pointers
    void *input_ptr = input.data_ptr();
    void *masked_input_ptr = masked_input.data_ptr();
    void *mask_ptr = mask.data_ptr();
- 
+
    // Get current stream
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
- 
+
    // Get scalar type
    at::ScalarType scalar_type = input.scalar_type();
- 
+
    // Create and configure OpCommand
    at_npu::native::OpCommand cmd;
    cmd.Name("get_masked_input_and_mask");
-     cmd.SetCustomHandler([scalar_type, size, stream,
+     cmd.SetCustomHandler([scalar_type, size, stream,
        input_ptr, masked_input_ptr, mask_ptr,
        org_vocab_start_index, org_vocab_end_index,
        num_org_vocab_padding, added_vocab_start_index,
@@ -194,7 +197,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
        get_masked_input_and_mask_impl(
            stream,
            input_ptr,
-             masked_input_ptr,
+             masked_input_ptr,
            mask_ptr,
            org_vocab_start_index,
            org_vocab_end_index,
@@ -204,7 +207,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
            size,
            loop_cnt,
            aiv_num);
- 
+
        return 0;
    });
    cmd.Run();
@@ -321,8 +324,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_shrink");
-     cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
-         seq_len_ptr, seq_len_size, y_ptr,
+     cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
+         seq_len_ptr, seq_len_size, y_ptr,
        batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
@@ -331,7 +334,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
        sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
-             y_ptr, batch_size,
+             y_ptr, batch_size,
            num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
        return 0;
    });
@@ -368,15 +371,15 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("sgmv_expand");
-     cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
+     cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
        batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
-         sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
+         sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
            batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
@@ -385,43 +388,34 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
}
} // namespace vllm_ascend

- TORCH_LIBRARY_EXPAND(_C, ops)
+ TORCH_LIBRARY_FRAGMENT_EXPAND(_C, ops)
{
-     // vLLM-Ascend custom ops
    ops.def("weak_ref_tensor(Tensor input) -> Tensor");
-     ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
- 
-     // Rotary embedding
-     // Apply GPT-NeoX style rotary embedding to query and key.
-     ops.def(
-         "rotary_embedding(Tensor positions, Tensor! query,"
-         " Tensor! key, int head_size,"
-         " Tensor cos_sin_cache, bool is_neox) -> (Tensor query, Tensor key)");
-     ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
- 
    ops.def(
        "get_masked_input_and_mask(Tensor input, "
        "int org_vocab_start_index, "
        "int org_vocab_end_index, "
        "int num_org_vocab_padding, "
        "int added_vocab_start_index, "
        "int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
-     ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
- 
    ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
-     ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
- 
    ops.def(
        "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
-     ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
- 
    ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
-     ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
- 
    ops.def(
        "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
        " int slice_offset, int slice_size) -> Tensor");
+ }
+ 
+ TORCH_LIBRARY_IMPL_EXPAND(_C, PrivateUse1, ops)
+ {
+     ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
+     ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
+     ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
+     ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
+     ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
+     ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
    ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
}
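
Note on the rotary_embedding change above: the op now returns void, takes key as std::optional<at::Tensor>, and copies the rotated values back into query and key at the end of the function. Below is a caller-side sketch of that in-place contract as seen through the dispatcher; it is illustrative only, since the updated schema string is not part of this diff, so the op name "_C::rotary_embedding" and the typed signature are assumptions.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Illustrative caller: results land in `query`/`key`; nothing is returned.
void apply_rotary_embedding_inplace(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
                                    at::Tensor &cos_sin_cache, int64_t head_size)
{
    // Assumed op name and signature; the schema registration lives outside this file.
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("_C::rotary_embedding", "")
                         .typed<void(at::Tensor&, at::Tensor&, std::optional<at::Tensor>,
                                     int64_t, at::Tensor&, bool)>();
    op.call(positions, query, key, head_size, cos_sin_cache, /*is_neox=*/true);
    // query and key have been updated in place by the custom kernel.
}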
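The registration block at the end of the diff splits schema declaration from backend binding: TORCH_LIBRARY_FRAGMENT_EXPAND registers only the ops.def(...) schemas, and TORCH_LIBRARY_IMPL_EXPAND binds the vllm_ascend kernels to the PrivateUse1 dispatch key. A minimal sketch of the same pattern with the stock torch/library.h macros follows; the *_EXPAND macros are assumed to be thin project-local wrappers, and the namespace my_ext and kernel my_scale are hypothetical names used only for illustration.

#include <torch/library.h>
#include <ATen/ATen.h>

namespace {
// Hypothetical kernel, for illustration only.
at::Tensor my_scale(const at::Tensor &x, double s)
{
    return x * s;
}
} // namespace

// Fragment: declares the operator schema only; other translation units or
// backends can attach their own implementations to it later.
TORCH_LIBRARY_FRAGMENT(my_ext, ops)
{
    ops.def("my_scale(Tensor x, float s) -> Tensor");
}

// Impl block: binds the kernel to a single dispatch key; PrivateUse1 is the
// key the NPU backend registers against in this diff.
TORCH_LIBRARY_IMPL(my_ext, PrivateUse1, ops)
{
    ops.impl("my_scale", &my_scale);
}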