Commit eb77574

add onednn w8a16 gemm (#24)
* add onednn w8a16 gemm --------- Signed-off-by: Zhu, Zufang <[email protected]>
1 parent 8a7400a commit eb77574

File tree

14 files changed, +1358 −5 lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
[submodule "third_party/oneDNN"]
    path = third_party/oneDNN
    url = https://github.com/uxlfoundation/oneDNN.git

CMakeLists.txt

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,8 @@ message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

 include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

+list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
+
 # Suppress potential warnings about unused manually-specified variables
 set(ignoreMe "${VLLM_PYTHON_PATH}")

@@ -66,6 +68,7 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
 # Import torch cmake configuration.
 find_package(Torch REQUIRED)

+find_package(oneDNN REQUIRED)

 #
 # Forward the non-CUDA device extensions to external CMake scripts.

@@ -191,8 +194,10 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
 endif()

 if(ONEDNN_FOUND)
+  set(_ONEDNN_SRC)
+  file(GLOB _ONEDNN_SRC csrc/xpu/onednn/*.cpp)
   list(APPEND VLLM_EXT_XPU_SRC
-    "csrc/xpu/onednn/*.cpp"
+    ${_ONEDNN_SRC}
   )
   include_directories(${ONEDNN_INCLUDE_DIR})
   link_libraries(${ONEDNN_LIBRARY})

README.md

Lines changed: 23 additions & 3 deletions
@@ -12,7 +12,7 @@ python3 setup.py develop - will be local install if we use develop build or syst

 On vllm side, we will `import vllm_xpu_kernels._C` at start time which should register all custom ops so we can use them directly.

-### prepare
+### Prepare

 Install oneapi 2025.1 deep learning essential [dependency](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html).

@@ -22,7 +22,7 @@ Create a new virtual env, install build dependency and torch dependency
 pip install -r requirements.txt
 ```

-### build & install
+### Build & Install
 Build development installation to current directory:

 ```

@@ -41,5 +41,25 @@ or build wheel (generated .whl in dist folder)
 VLLM_TARGET_DEVICE=xpu python3 setup.py bdist_wheel
 ```

-### how to use in vLLM
+### How to use in vLLM
 Please refer to the temporary branch https://github.com/jikunshang/vllm/tree/xpu_kernel to install and test vLLM, which replaces the `rms_norm` kernel from IPEX with the one from vllm-xpu-kernels.
+
+### Why Static Linking DNNL Instead of Shared Linking?
+
+We chose to **statically link oneDNN (DNNL)** rather than use it as a shared library for the following reasons:
+
+#### 1. **Version Compatibility**
+
+Static linking ensures our application always uses the exact DNNL version it was built against. With shared libraries, there's a risk that system-installed versions are incompatible or introduce subtle bugs due to API/ABI changes.
+
+#### 2. **Performance Consistency**
+
+By linking statically, we avoid potential performance variability introduced by different builds or configurations of DNNL that might be present on the host system.
+
+#### 3. **Avoiding Runtime Errors**
+
+Using shared libraries requires correct paths and environment setup (`LD_LIBRARY_PATH` on Linux). Static linking avoids issues where DNNL cannot be found or loaded at runtime.
+
+#### 4. **Aligning with PyTorch**
+
+One key reason to use static linking is to maintain consistency with the PyTorch ecosystem. PyTorch itself statically links libraries like DNNL to ensure deterministic and reliable behavior across different environments.

cmake/Modules/FindoneDNN.cmake

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# - Try to find oneDNN
#
# The following are set after configuration is done:
#  ONEDNN_FOUND       : set to true if oneDNN is found.
#  ONEDNN_INCLUDE_DIR : path to oneDNN include dir.
#  ONEDNN_LIBRARY     : list of libraries for oneDNN.
#

IF(NOT ONEDNN_FOUND)
  SET(ONEDNN_FOUND OFF)

  SET(ONEDNN_LIBRARY)
  SET(ONEDNN_INCLUDE_DIR)
  SET(DNNL_INCLUDES)

  SET(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/third_party")
  SET(ONEDNN_DIR "oneDNN")
  SET(ONEDNN_ROOT "${THIRD_PARTY_DIR}/${ONEDNN_DIR}")

  FIND_PATH(ONEDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${ONEDNN_ROOT} PATH_SUFFIXES include NO_DEFAULT_PATH)
  IF(NOT ONEDNN_INCLUDE_DIR)
    FIND_PACKAGE(Git)
    IF(NOT Git_FOUND)
      MESSAGE(FATAL_ERROR "Cannot find Git executable!")
    ENDIF()
    EXECUTE_PROCESS(
      COMMAND ${GIT_EXECUTABLE} submodule update --init ${ONEDNN_DIR}
      WORKING_DIRECTORY ${THIRD_PARTY_DIR} COMMAND_ERROR_IS_FATAL ANY)
    FIND_PATH(ONEDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${ONEDNN_ROOT} PATH_SUFFIXES include NO_DEFAULT_PATH)
  ENDIF(NOT ONEDNN_INCLUDE_DIR)

  IF(NOT ONEDNN_INCLUDE_DIR)
    MESSAGE(FATAL_ERROR "oneDNN source files not found!")
  ENDIF(NOT ONEDNN_INCLUDE_DIR)

  SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "oneDNN sycl primitive cache" FORCE)

  SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)

  SET(DNNL_CPU_RUNTIME "THREADPOOL" CACHE STRING "oneDNN cpu backend" FORCE)
  SET(DNNL_GPU_RUNTIME "SYCL" CACHE STRING "oneDNN gpu backend" FORCE)
  SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "build with oneDNN tests" FORCE)
  SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "build with oneDNN examples" FORCE)
  SET(DNNL_ENABLE_CONCURRENT_EXEC TRUE CACHE BOOL "multi-thread primitive execution" FORCE)
  SET(DNNL_EXPERIMENTAL TRUE CACHE BOOL "use one pass for oneDNN BatchNorm" FORCE)

  ADD_SUBDIRECTORY(${ONEDNN_ROOT} oneDNN EXCLUDE_FROM_ALL)
  SET(ONEDNN_LIBRARY ${DNNL_LIBRARY_NAME})
  IF(NOT TARGET ${ONEDNN_LIBRARY})
    MESSAGE(FATAL_ERROR "Failed to include oneDNN target")
  ENDIF(NOT TARGET ${ONEDNN_LIBRARY})

  IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
    TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-uninitialized)
    TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-strict-overflow)
    TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-error=strict-overflow)
  ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)

  TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-tautological-compare)
  GET_TARGET_PROPERTY(DNNL_INCLUDES ${ONEDNN_LIBRARY} INCLUDE_DIRECTORIES)
  TARGET_LINK_LIBRARIES(${ONEDNN_LIBRARY} PRIVATE ze_loader)
  LIST(APPEND ONEDNN_INCLUDE_DIR ${DNNL_INCLUDES})

  # Upper level targets should not load header files from oneDNN's third party.
  LIST(FILTER ONEDNN_INCLUDE_DIR EXCLUDE REGEX
       ".*third_party/oneDNN/third_party.*")

  SET(ONEDNN_FOUND ON)
  MESSAGE(STATUS "Found oneDNN: TRUE")

ENDIF(NOT ONEDNN_FOUND)

csrc/xpu/onednn/fp8_gemm_w8a16.cpp

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
#include <vector>

#include "fp8_gemm_w8a16.h"

torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B,
                             bool trans_B,
                             const std::optional<torch::Tensor>& B_scale_,
                             const std::optional<torch::Tensor>& bias_) {
  TORCH_CHECK(A.dim() == 2 || A.dim() == 3,
              "fp8_gemm_w8a16 only supports 2D and 3D inputs!\n");
  TORCH_CHECK(B.dim() == 2, "fp8_gemm_w8a16 only supports 2D weights!\n");

  std::vector<int64_t> result_shape;
  if (A.dim() == 2) {
    if (trans_B) {
      result_shape = {A.size(0), B.size(0)};
    } else {
      result_shape = {A.size(0), B.size(1)};
    }
    // src{m, k}, wei{k, n}, bias{n}, dst{m, n}
  } else {
    if (trans_B) {
      result_shape = {A.size(0), A.size(1), B.size(0)};
    } else {
      result_shape = {A.size(0), A.size(1), B.size(1)};
    }
    // src{b, m, k}, wei{k, n}, bias{n}, dst{b, m, n}
  }

  // Deal with permuted inputs, e.g. shape [m, b, k] with stride [k, m * k, 1]:
  // rescale every outer stride from k-units to n-units so the output keeps the
  // same memory layout as the input.
  auto k = A.size(A.dim() - 1);
  auto n = result_shape.back();
  auto res_stride = A.strides().vec();
  for (size_t i = 0; i < res_stride.size() - 1; i++) {
    res_stride[i] = res_stride[i] / k * n;
  }

  torch::Tensor result =
      at::empty_strided(result_shape, res_stride, A.options());

  // Check whether the weight layout is in nt format.
  bool is_nt = true;
  if (trans_B) {
    is_nt = B.strides()[B.dim() - 1] == 1;
  } else {
    is_nt = B.strides()[B.dim() - 2] == 1;
  }

  torch::Tensor B_scale = B_scale_.has_value()
                              ? B_scale_.value()
                              : at::ones({1}, B.options().dtype(A.dtype()));

  oneDNN::dnnl_matmul_w8a16_fp8(result, A, B, is_nt, bias_, B_scale);
  return result;
}
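For orientation, here is a minimal, hypothetical C++ caller for the op above. It is only a sketch: it assumes the extension has been built for XPU, that an XPU device is visible to PyTorch, and that this translation unit links against the object defining `fp8_gemm_w8a16`. The tensor shapes, the `main()` harness, and the CPU-side fp8 cast are illustrative choices, not part of this commit; inside vLLM the op is reached through the registered custom-op path instead.

```cpp
// Hypothetical standalone caller -- a sketch, not part of this commit.
#include <iostream>
#include <optional>
#include <torch/torch.h>

// Forward declaration matching the definition in fp8_gemm_w8a16.cpp.
torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B,
                             bool trans_B,
                             const std::optional<torch::Tensor>& B_scale_,
                             const std::optional<torch::Tensor>& bias_);

int main() {
  at::Device xpu(at::kXPU, 0);
  const int64_t m = 4, k = 128, n = 256;

  // fp16 activations of shape [m, k] on the XPU.
  auto A = at::randn({m, k}).to(at::kHalf).to(xpu);

  // fp8-e4m3 weights stored as [n, k]; trans_B = true gives the kernel the
  // wei{k, n} view it expects. Cast to fp8 on the CPU, then move to the XPU.
  auto B = at::randn({n, k}).to(at::ScalarType::Float8_e4m3fn).to(xpu);

  // Per-tensor weight scale; when omitted the op falls back to a scale of 1.
  auto B_scale = at::ones({1}).to(at::kHalf).to(xpu);

  auto C = fp8_gemm_w8a16(A, B, /*trans_B=*/true, B_scale,
                          /*bias_=*/std::nullopt);
  std::cout << C.sizes() << std::endl;  // expected: [4, 256]
  return 0;
}
```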

csrc/xpu/onednn/fp8_gemm_w8a16.h

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
#pragma once

#include <c10/xpu/XPUStream.h>
#include <dnnl.hpp>
#include <torch/torch.h>

#include <functional>  // std::multiplies
#include <numeric>     // std::reduce
#include <optional>
#include <utility>
#include <vector>

#include "onednn_ext.h"

namespace oneDNN {

using bias_type_t = at::native::onednn::bias_type_t;
using trans_type_t = at::native::onednn::trans_type_t;
using GpuStreamManager = at::native::onednn::GpuStreamManager;
using GpuEngineManager = at::native::onednn::GpuEngineManager;

static inline void dnnl_matmul_w8a16_fp8(
    torch::Tensor& result, const torch::Tensor& mat1, const torch::Tensor& mat2,
    bool trans_b, const std::optional<torch::Tensor>& bias,
    const torch::Tensor& m2_sc, const int64_t group_size = 0) {
  TORCH_CHECK(mat2.scalar_type() == at::ScalarType::Float8_e5m2 ||
                  mat2.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "weight must be f8_e5m2 or f8_e4m3fn for fp8 matmul");
  auto src_sz = mat1.sizes();
  auto o_sz = result.sizes();

  const int m = std::reduce(src_sz.begin(), src_sz.end() - 1, 1,
                            std::multiplies<int64_t>());
  const int n = o_sz.back();  // presume channel-last format
  const int k = *(src_sz.end() - 1);

  // Get the joint (activation, weight) dtype combination.
  joint_dtypes_t jd;
  auto in_dtype = mat1.scalar_type();
  auto wei_dtype = mat2.scalar_type();
  if (in_dtype == at::ScalarType::Half) {
    jd = wei_dtype == at::ScalarType::Float8_e5m2 ? joint_dtypes_t::f16_f8_e5m2
                                                  : joint_dtypes_t::f16_f8_e4m3;
  } else if (in_dtype == at::ScalarType::BFloat16) {
    jd = wei_dtype == at::ScalarType::Float8_e5m2
             ? joint_dtypes_t::bf16_f8_e5m2
             : joint_dtypes_t::bf16_f8_e4m3;
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Unsupported data type for fp8 matmul: ", mat1.scalar_type());
  }

  // Get the bias type from the bias shape.
  bias_type_t b_type;
  if (bias.has_value() && bias.value().defined()) {
    auto& b = bias.value();
    const auto numel = b.numel();
    if (numel == 1) {
      b_type = bias_type_t::scalar;
    } else if (numel == m * n) {
      b_type = bias_type_t::mn;
    } else if (b.size(b.dim() - 1) == n && numel == n) {
      b_type = bias_type_t::n;
    } else if (b.size(b.dim() - 1) == 1 && numel == m) {
      b_type = bias_type_t::m;
    } else if (numel == 0) {
      b_type = bias_type_t::none;
    } else {
      TORCH_CHECK(0, "unsupported bias dim in matmul ...", b.sizes());
    }
  } else {
    b_type = bias_type_t::none;
  }

  trans_type_t tt = trans_type_t::nn;
  if (trans_b) {
    // mat2 is transposed
    tt = trans_type_t::nt;
  }

  // Get lda, ldb and ldc.
  auto mat1_strides = mat1.strides();
  int64_t leading_dim = -1;
  if (mat1.dim() == 2) {
    leading_dim = 0;
  } else if (mat1.dim() == 3) {
    leading_dim = mat1_strides[0] < mat1_strides[1] ? 0 : 1;
  } else {
    TORCH_CHECK(false,
                "Unsupported input dimension for fp8 matmul: ", mat1.dim());
  }
  int64_t lda = mat1_strides[leading_dim];
  int64_t ldb = mat2.strides()[mat2.dim() - 1] == 1
                    ? mat2.strides()[mat2.dim() - 2]
                    : mat2.strides()[mat2.dim() - 1];
  int64_t ldc = result.strides()[leading_dim];

  auto f_attr = [&](dnnl::primitive_attr& pattr) {
    pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
  };

  int arg_off = 0;

  // Get device, engine, stream.
  const int dev_id = c10::xpu::getCurrentXPUStream().device_index();
  at::Device curDevice = at::Device(at::kXPU, dev_id);
  auto engine = GpuEngineManager::Instance().get_engine(curDevice);

  auto& matmul_ext = matmul_primitive_create_and_cache(
      jd, tt, b_type, m, n, k, lda, ldb, ldc, dev_id, f_attr, group_size);

  // Attach the weight scale as a oneDNN scales attribute.
  matmul_ext.set_attribute(arg_off++, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
                           m2_sc.data_ptr(), [&]() {
                             return at::native::onednn::make_onednn_memory(
                                 get_onednn_md(m2_sc), engine,
                                 m2_sc.data_ptr());
                           });

  std::vector<std::pair<int, void*>> arg_handles;
  arg_handles.reserve(8);

  arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr());
  arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr());
  arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr());
  if (b_type != bias_type_t::none) {
    arg_handles.emplace_back(DNNL_ARG_BIAS, bias.value().data_ptr());
  }

  int scratchpad_size = matmul_ext.get_scratchpad_size();
  torch::Tensor scratchpad_tensor = at::empty(
      {scratchpad_size}, mat1.options().dtype(at::kByte), c10::nullopt);
  arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr());

  auto& strm = GpuStreamManager::Instance().get_stream();
  auto qfp8_matmul_event =
      matmul_ext.execute(strm, engine, std::move(arg_handles), arg_off);
}

}  // namespace oneDNN
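To make the shape and leading-dimension bookkeeping above easy to check without building the extension, here is a small self-contained C++ sketch that mirrors that arithmetic on plain vectors. The `Dims` struct and `derive_dims` function are illustrative names introduced here, not part of the repository.

```cpp
// Self-contained sketch (not part of the commit) mirroring how
// dnnl_matmul_w8a16_fp8 derives m/n/k and the leading dimensions
// from tensor sizes and strides.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

struct Dims {
  int64_t m, n, k, lda, ldb, ldc;
};

// src: activation sizes/strides, wei: weight strides (2-D),
// dst: output sizes/strides.
Dims derive_dims(const std::vector<int64_t>& src_sz,
                 const std::vector<int64_t>& src_st,
                 const std::vector<int64_t>& wei_st,
                 const std::vector<int64_t>& dst_sz,
                 const std::vector<int64_t>& dst_st) {
  Dims d{};
  // m collapses every activation dim except the innermost (k).
  d.m = std::reduce(src_sz.begin(), src_sz.end() - 1, int64_t{1},
                    std::multiplies<int64_t>());
  d.n = dst_sz.back();
  d.k = src_sz.back();

  // Leading dim of the activation: for 3-D inputs pick whichever of the two
  // outer dims has the smaller stride (handles [m, b, k] layouts).
  size_t lead = 0;
  if (src_sz.size() == 3) lead = src_st[0] < src_st[1] ? 0 : 1;
  d.lda = src_st[lead];

  // ldb: stride of the non-unit axis of the 2-D weight.
  d.ldb = wei_st[1] == 1 ? wei_st[0] : wei_st[1];
  d.ldc = dst_st[lead];
  return d;
}

int main() {
  // Example: A is fp16 [4, 128] row-major, B is fp8 [256, 128] row-major used
  // with trans_B, C is [4, 256] row-major.
  Dims d = derive_dims({4, 128}, {128, 1}, {128, 1}, {4, 256}, {256, 1});
  std::cout << "m=" << d.m << " n=" << d.n << " k=" << d.k << " lda=" << d.lda
            << " ldb=" << d.ldb << " ldc=" << d.ldc
            << "\n";  // m=4 n=256 k=128 lda=128 ldb=128 ldc=256
  return 0;
}
```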
