Skip to content

Commit 363896a

Browse files
committed
feat(lora) Add lora shrink & expand ops
Signed-off-by: chzhang <[email protected]>
1 parent 925b092 commit 363896a

File tree

11 files changed

+166
-190
lines changed

11 files changed

+166
-190
lines changed

CMakeLists.txt

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ define_gpu_extension_target(
183183
if(VLLM_GPU_LANG STREQUAL "SYCL")
184184
set(VLLM_EXT_XPU_SRC
185185
"csrc/xpu/torch_bindings.cpp"
186+
"csrc/xpu/lora/lora_shrink.cpp"
187+
"csrc/xpu/lora/lora_expand.cpp"
186188
)
187189
include_directories("/usr/include")
188190
set(CMPLR_ROOT $ENV{CMPLR_ROOT})
@@ -236,27 +238,3 @@ define_gpu_extension_target(
236238
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
237239
USE_SABI 3
238240
WITH_SOABI)
239-
240-
#
241-
# _lora_C extension
242-
#
243-
244-
set(VLLM_LORA_EXT_SRC
245-
"csrc/lora/torch_bindings.cpp"
246-
"csrc/lora/lora_shrink.cpp"
247-
"csrc/lora/lora_expand.cpp"
248-
)
249-
250-
message(STATUS "Enabling lora extension.")
251-
define_gpu_extension_target(
252-
_lora_C
253-
DESTINATION vllm_xpu_kernels
254-
LANGUAGE ${VLLM_GPU_LANG}
255-
SOURCES ${VLLM_LORA_EXT_SRC}
256-
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
257-
LINK_FLAGS ${VLLM_GPU_LINK_FLAGS}
258-
ARCHITECTURES ${VLLM_GPU_ARCHES}
259-
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
260-
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
261-
USE_SABI 3
262-
WITH_SOABI)

benchmark/benchmark_lora.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
23

34
import argparse
45
import copy
@@ -17,7 +18,7 @@
1718
from utils import ArgPool, Bench, CudaGraphBenchParams
1819
from weight_shapes import WEIGHT_SHAPES
1920

20-
from tests.register_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
21+
from tests.lora.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
2122

2223
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
2324
DEFAULT_TP_SIZES = [1]

csrc/lora/torch_bindings.cpp

Lines changed: 0 additions & 23 deletions
This file was deleted.

csrc/lora/lora_expand.cpp renamed to csrc/xpu/lora/lora_expand.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,16 @@ class bgmv_expand_kernel {
166166
sycl::group_barrier(sg);
167167

168168
if (vec_id == 0) {
169-
accscalar_t result = 0;
169+
float result = 0.0f;
170170
#pragma unroll
171171
for (uint32_t i = 0; i < workitem_per_hidden_; ++i) {
172-
result += slm_[slm_base + i];
172+
result += static_cast<float>(slm_[slm_base + i]);
173173
}
174174
const size_t out_off = static_cast<size_t>(batch_id) * output_hidden_ +
175175
slice_offset_ + hidden_id;
176176
if (add_to_output_) {
177-
outputs_[out_off] = static_cast<output_t>(
178-
static_cast<accscalar_t>(outputs_[out_off]) + result);
177+
result += static_cast<float>(outputs_[out_off]);
178+
outputs_[out_off] = static_cast<output_t>(result);
179179
} else {
180180
outputs_[out_off] = static_cast<output_t>(result);
181181
}
File renamed without changes.
File renamed without changes.

csrc/xpu/torch_bindings.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "core/registration.h"
22
#include "xpu/ops.h"
3+
#include "xpu/lora/lora_ops.h"
34

45
#include <torch/library.h>
56
#include <torch/version.h>
@@ -11,6 +12,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
1112
"fp8_gemm_w8a16(Tensor! A, Tensor! B, bool trans_B, Tensor? B_scale_, "
1213
"Tensor? bias_) -> Tensor");
1314
xpu_ops.impl("fp8_gemm_w8a16", torch::kXPU, &fp8_gemm_w8a16);
15+
16+
xpu_ops.def(
17+
"bgmv_shrink(Tensor! outputs, Tensor inputs, Tensor weights, Tensor "
18+
"indices, float scale) -> ()");
19+
xpu_ops.impl("bgmv_shrink", torch::kXPU, &bgmv_shrink);
20+
21+
xpu_ops.def(
22+
"bgmv_expand(Tensor! outputs, Tensor inputs, Tensor weights, Tensor "
23+
"indices, bool add_to_output) -> ()");
24+
xpu_ops.impl("bgmv_expand", torch::kXPU, &bgmv_expand);
25+
26+
xpu_ops.def(
27+
"bgmv_expand_slice(Tensor! outputs, Tensor inputs, Tensor weights, "
28+
"Tensor indices, int slice_offset,bool add_to_output) -> ()");
29+
xpu_ops.impl("bgmv_expand_slice", torch::kXPU, &bgmv_expand_slice);
1430
}
1531

1632
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@ def run(self):
272272
if _build_custom_ops():
273273
ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._C"))
274274
ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._moe_C"))
275-
ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._lora_C"))
276275
ext_modules.append(CMakeExtension(name="vllm_xpu_kernels._xpu_C"))
277276

278277
if ext_modules:

tests/lora/lora_ops.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import torch
3+
4+
import vllm_xpu_kernels._xpu_C # noqa: F401
5+
6+
7+
def bgmv_shrink(
8+
inputs: torch.Tensor,
9+
lora_a_weights: torch.Tensor,
10+
output_tensor: torch.Tensor,
11+
lora_indices_tensor: torch.Tensor,
12+
scaling: float = 1.0,
13+
) -> None:
14+
torch.ops._xpu_C.bgmv_shrink(
15+
output_tensor,
16+
inputs,
17+
lora_a_weights,
18+
lora_indices_tensor,
19+
scaling,
20+
)
21+
22+
23+
def bgmv_expand(
24+
inputs: torch.Tensor,
25+
lora_b_weights: torch.Tensor,
26+
output_tensor: torch.Tensor,
27+
lora_indices_tensor: torch.Tensor,
28+
add_inputs: bool = True,
29+
):
30+
"""
31+
Args:
32+
inputs (torch.Tensor): Shape: `[batch_size, hidden_size]`.
33+
lora_b_weights (torch.Tensor): Shape: `[lora_num, rank, hidden_size]`.
34+
output_tensor (torch.Tensor): Shape: `[batch_size, rank]`.
35+
lora_indices_tensor (torch.Tensor): Shape: `[batch_size]`.
36+
The LoRA index corresponding to each batch. An index of -1 means
37+
no lora should be applied.
38+
add_inputs (bool, optional): Defaults to False. adds the final lora
39+
results to the output.
40+
41+
Semantics:
42+
for i in range(inputs.size(0)):
43+
output_tensor[i] =
44+
inputs[i] @ lora_b_weights[lora_indices_tensor[i]]
45+
+ (inputs[i] if add_inputs else 0)
46+
"""
47+
torch.ops._xpu_C.bgmv_expand(
48+
output_tensor,
49+
inputs,
50+
lora_b_weights,
51+
lora_indices_tensor,
52+
add_inputs,
53+
)
54+
55+
56+
def bgmv_expand_slice(
57+
inputs: torch.Tensor,
58+
lora_b_weights: torch.Tensor,
59+
output_tensor: torch.Tensor,
60+
lora_indices_tensor: torch.Tensor,
61+
slice_offset: int,
62+
slice_size: int,
63+
add_inputs: bool = True,
64+
):
65+
"""
66+
Args:
67+
inputs (torch.Tensor): Shape: `[batch_size, hidden_size]`.
68+
lora_b_weights (torch.Tensor): Shape: `[lora_num, rank, hidden_size]`.
69+
output_tensor (torch.Tensor): Shape: `[batch_size, rank]`.
70+
lora_indices_tensor (torch.Tensor): Shape: `[batch_size]`.
71+
The LoRA index
72+
corresponding to each batch. An index of -1 means no lora should be
73+
applied.
74+
slice_offset (int): output_tensor's offset
75+
slice_size (int): current output_tensor's size
76+
add_inputs (bool, optional): Defaults to False. adds the final lora
77+
results to the output.
78+
79+
Semantics:
80+
for i in range(inputs.size(0)):
81+
output_tensor[i][slice_offset:slice_offset+slice_size] =
82+
inputs[i] @ lora_b_weights[lora_indices_tensor[i]]
83+
+ (inputs[i] if add_inputs else 0)
84+
"""
85+
torch.ops._xpu_C.bgmv_expand_slice(
86+
output_tensor,
87+
inputs,
88+
lora_b_weights,
89+
lora_indices_tensor,
90+
slice_offset,
91+
add_inputs,
92+
)

tests/register_ops.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
import vllm_xpu_kernels._C # noqa: F401
77
import vllm_xpu_kernels._moe_C # noqa: F401
8-
import vllm_xpu_kernels._lora_C # noqa: F401
98

109

1110
# layer norm ops
@@ -162,91 +161,3 @@ def swigluoai_and_mul(
162161
# moe
163162
def moe_sum(input: torch.Tensor, output: torch.Tensor) -> None:
164163
torch.ops._moe_C.moe_sum(input, output)
165-
166-
167-
def bgmv_shrink(
168-
inputs: torch.Tensor,
169-
lora_a_weights: torch.Tensor,
170-
output_tensor: torch.Tensor,
171-
lora_indices_tensor: torch.Tensor,
172-
scaling: float = 1.0,
173-
) -> None:
174-
torch.ops._lora_C.bgmv_shrink(
175-
output_tensor,
176-
inputs,
177-
lora_a_weights,
178-
lora_indices_tensor,
179-
scaling,
180-
)
181-
182-
183-
def bgmv_expand(
184-
inputs: torch.Tensor,
185-
lora_b_weights: torch.Tensor,
186-
output_tensor: torch.Tensor,
187-
lora_indices_tensor: torch.Tensor,
188-
add_inputs: bool = True,
189-
):
190-
"""
191-
Args:
192-
inputs (torch.Tensor): Shape: `[batch_size, hidden_size]`.
193-
lora_b_weights (torch.Tensor): Shape: `[lora_num, rank, hidden_size]`.
194-
output_tensor (torch.Tensor): Shape: `[batch_size, rank]`.
195-
lora_indices_tensor (torch.Tensor): Shape: `[batch_size]`.
196-
The LoRA index corresponding to each batch. An index of -1 means
197-
no lora should be applied.
198-
add_inputs (bool, optional): Defaults to False. adds the final lora
199-
results to the output.
200-
201-
Semantics:
202-
for i in range(inputs.size(0)):
203-
output_tensor[i] =
204-
inputs[i] @ lora_b_weights[lora_indices_tensor[i]]
205-
+ (inputs[i] if add_inputs else 0)
206-
"""
207-
torch.ops._lora_C.bgmv_expand(
208-
output_tensor,
209-
inputs,
210-
lora_b_weights,
211-
lora_indices_tensor,
212-
add_inputs,
213-
)
214-
215-
216-
def bgmv_expand_slice(
217-
inputs: torch.Tensor,
218-
lora_b_weights: torch.Tensor,
219-
output_tensor: torch.Tensor,
220-
lora_indices_tensor: torch.Tensor,
221-
slice_offset: int,
222-
slice_size: int,
223-
add_inputs: bool = True,
224-
):
225-
"""
226-
Args:
227-
inputs (torch.Tensor): Shape: `[batch_size, hidden_size]`.
228-
lora_b_weights (torch.Tensor): Shape: `[lora_num, rank, hidden_size]`.
229-
output_tensor (torch.Tensor): Shape: `[batch_size, rank]`.
230-
lora_indices_tensor (torch.Tensor): Shape: `[batch_size]`.
231-
The LoRA index
232-
corresponding to each batch. An index of -1 means no lora should be
233-
applied.
234-
slice_offset (int): output_tensor's offset
235-
slice_size (int): current output_tensor's size
236-
add_inputs (bool, optional): Defaults to False. adds the final lora
237-
results to the output.
238-
239-
Semantics:
240-
for i in range(inputs.size(0)):
241-
output_tensor[i][slice_offset:slice_offset+slice_size] =
242-
inputs[i] @ lora_b_weights[lora_indices_tensor[i]]
243-
+ (inputs[i] if add_inputs else 0)
244-
"""
245-
torch.ops._lora_C.bgmv_expand_slice(
246-
output_tensor,
247-
inputs,
248-
lora_b_weights,
249-
lora_indices_tensor,
250-
slice_offset,
251-
add_inputs,
252-
)

0 commit comments

Comments
 (0)