Skip to content

sycl: use DNN in the first part of ggml_sycl_mul_mat_batched_sycl #12972

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions ggml/src/ggml-sycl/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,31 @@ class DnnlGemmWrapper {
else static_assert(0);
}

static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
// matrix A has m rows, k columns
// matrix B has k rows, n columns
// nra - number of elements to skip when moving into next row in A
// nrb - number of elements to skip when moving into next row in B
// nca - number of elements to skip when moving into next column in A
// ncb - number of elements to skip when moving into next column in B
// stride_a - number of elements to skip when moving to next A matrix
// stride_b - number of elements to skip when moving to next B matrix
// batches - number of A matrices, equal to number of B matrices
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches) {

auto stream = ctx.stream_dnnl(q);
auto eng = ctx.engine_dnnl(q);
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
dnnl::memory::dims a_dims = { batches, m, k };
dnnl::memory::dims b_dims = { batches, k, n };
dnnl::memory::dims c_dims = { batches, m, n };
dnnl::memory::dims a_strides = { stride_a, nra, nca };
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };

const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);

dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
Expand All @@ -63,6 +78,15 @@ class DnnlGemmWrapper {

matmul_prim.execute(stream, matmul_args);
}

// matrices A and B are column major, both having k rows
// matrix A has m column, matrix B has n columns
// output: column major matrix C = A transposed * B
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {

gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1);
}
};

#endif
Expand Down
42 changes: 23 additions & 19 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1986,7 +1986,7 @@ inline void ggml_sycl_op_mul_mat_sycl(

const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];

GGML_ASSERT(ne00 == ne10);

const int64_t row_diff = row_high - row_low;

Expand Down Expand Up @@ -2047,7 +2047,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
Expand Down Expand Up @@ -2081,7 +2081,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
#endif
Expand Down Expand Up @@ -2737,10 +2737,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

GGML_TENSOR_BINARY_OP_LOCALS


SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();;

Expand All @@ -2761,39 +2761,42 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
: src1_f16_alloc.get();

char * dst_t;

dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];
const dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
const dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const void * alpha = &alpha_f32;
const void * beta = &beta_f32;

dst_t = (char *) dst_ddf;
char * dst_t = (char *) dst_ddf;

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
GGML_ASSERT(ne10 == ne00);

// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
const auto r2 = ne12/ne02;
const auto r3 = ne13/ne03;
const auto ne23 = ne12*ne13;

if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
#if GGML_SYCL_DNNL
DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
src1_f16, DnnlGemmWrapper::to_dt<sycl::half>(), nb11/nb10, 1, nb12/nb10,
src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
dst_t, DnnlGemmWrapper::to_dt<float>(), main_stream, ne23);
#else
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
(const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
cu_data_type, ne01, nb2 / nb0, ne23, cu_compute_type)));
#endif
} else {
const int ne23 = ne12*ne13;

ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
Expand Down Expand Up @@ -2821,7 +2824,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
dst_t, ptrs_src_get,
ptrs_dst_get, ne12, ne13, ne23,
nb02, nb03, nb12_scaled, nb13_scaled,
nbd2, nbd3, r2, r3, item_ct1);
nb2, nb3, r2, r3, item_ct1);
});
});
}
Expand Down Expand Up @@ -3660,7 +3663,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
return GGML_STATUS_SUCCESS;
}

sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});

model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
model_sycl_graph.end_recording();
Expand Down
4 changes: 3 additions & 1 deletion tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3865,7 +3865,7 @@ static const ggml_type other_types[] = {
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
std::vector<std::unique_ptr<test_case>> test_cases;
std::default_random_engine rng(0);
[[maybe_unused]] std::default_random_engine rng(0);

// unary ops
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
Expand Down Expand Up @@ -4182,6 +4182,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 1}, {1, 1}, {0, 2, 1, 3}));
}
}
for (ggml_type type_a : other_types) {
Expand Down
Loading