diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 96552bd8df1..2324bfacd22 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -179,51 +179,6 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } -static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, - const int iqs, dfloat2 &v) { - const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib); - - const int8_t * qs_ptr = (const int8_t *)qs; - - v.x() = qs_ptr[iqs + 0]; - v.y() = qs_ptr[iqs + 1]; - -#ifdef GGML_SYCL_F16 - v.s0() *= d; - v.s1() *= d; -#else - v.x() *= d; - v.y() *= d; -#endif // GGML_SYCL_F16 -} - -template -static void dequantize_block_q8_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t k, - const sycl::nd_item<3> &item_ct1) { - - const int64_t i = item_ct1.get_group(2); - - // assume 32 threads - const int64_t tid = item_ct1.get_local_id(2); - const int64_t lane_ib = i * WARP_SIZE + tid; - - if (lane_ib >= k / QK8_0) { - return; - } - - dst_t * y_ptr = yy + lane_ib * QK8_0; - - auto qs = (const int8_t*)vx + lane_ib * QK8_0; - auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k) + lane_ib; - - const float d = float(*s_ptr); - -#pragma unroll - for (int l = 0; l < QK8_0; ++l) { - y_ptr[l] = d * qs[l]; - } -} - template static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b8634fd4ac5..93d048c58a4 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -468,8 +468,8 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); #ifndef _WIN32 - // Copy mmap'd data through a host buffer to avoid Level Zero OOM when - // pinning file-backed pages for direct DMA (affects PVC and Battlemage). + // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU. + // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here. char * host_buf = (char *) malloc(size); memcpy(host_buf, data, size); SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait())); @@ -3399,20 +3399,6 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { sycl::free(ptr, *stream); } -static void * reorder_scratch_buf = nullptr; -static size_t reorder_scratch_size = 0; - -static void * reorder_get_scratch(dpct::queue_ptr stream, size_t size) { - if (size > reorder_scratch_size) { - if (reorder_scratch_buf) { - sycl_ext_free(stream, reorder_scratch_buf); - } - reorder_scratch_buf = sycl_ext_malloc_device(stream, size); - reorder_scratch_size = size; - } - return reorder_scratch_buf; -} - // RAII wrapper for temporary reorder buffers with optional host memory fallback. // When device allocation fails and GGML_SYCL_HOST_MEM_FALLBACK is enabled, // falls back to host memory so the reorder kernel can still run (over PCIe). @@ -3456,7 +3442,12 @@ struct sycl_reorder_temp_buffer { static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3485,16 +3476,17 @@ static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { - GGML_ASSERT(size % sizeof(block_q4_K) == 0); - GGML_ASSERT(offset % sizeof(block_q4_K) == 0); - - const int nblocks = size / sizeof(block_q4_K); - - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); +static bool reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3502,37 +3494,42 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d copy_event.wait(); } - auto * qs_ptr = data_device; - auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; - auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); - - auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { - const block_q4_K * x = (const block_q4_K *) tmp_buf; - const int ib = i; - - for (int j = 0; j < QK_K / 2; ++j) { - qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; - } + GGML_ASSERT((size % sizeof(block_q8_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q8_0) == 0)); + int offset_blks = offset / sizeof(block_q8_0); + auto qs_ptr = data_device + offset_blks * QK8_0; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks; - for (int j = 0; j < K_SCALE_SIZE; ++j) { - scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; - } + auto reorder_event = stream->parallel_for( + size / sizeof(block_q8_0), + [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const block_q8_0* x = (const block_q8_0*)tmp_buf; + const int ib = i; - dm_ptr[ib] = x[ib].dm; - }); + for (int j = 0; j < QK8_0; j++) + { + *((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }); if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { - GGML_ASSERT(size % sizeof(block_q5_K) == 0); - GGML_ASSERT(offset % sizeof(block_q5_K) == 0); +static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q4_K) == 0); + GGML_ASSERT(offset % sizeof(block_q4_K) == 0); - const int nblocks = size / sizeof(block_q5_K); + const int nblocks = size / sizeof(block_q4_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3541,22 +3538,17 @@ static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, d } auto * qs_ptr = data_device; - auto * qh_ptr = qs_ptr + (QK_K / 2) * nblocks; - auto * scales_ptr = qh_ptr + (QK_K / 8) * nblocks; + auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { - const block_q5_K * x = (const block_q5_K *) tmp_buf; + const block_q4_K * x = (const block_q4_K *) tmp_buf; const int ib = i; for (int j = 0; j < QK_K / 2; ++j) { qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; } - for (int j = 0; j < QK_K / 8; ++j) { - qh_ptr[ib * (QK_K / 8) + j] = x[ib].qh[j]; - } - for (int j = 0; j < K_SCALE_SIZE; ++j) { scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; } @@ -3566,7 +3558,7 @@ static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { @@ -3575,7 +3567,12 @@ static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, d const int nblocks = size / sizeof(block_q5_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3609,7 +3606,6 @@ static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); return true; } @@ -3619,7 +3615,12 @@ static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d const int nblocks = size / sizeof(block_q6_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3658,42 +3659,10 @@ static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); -} - -static void reorder_qw_q8_0(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { - GGML_ASSERT(size % sizeof(block_q8_0) == 0); - GGML_ASSERT(offset % sizeof(block_q8_0) == 0); - - const int nblocks = size / sizeof(block_q8_0); - - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); - - sycl::event copy_event; - SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); - if (!g_ggml_sycl_use_async_mem_op) { - copy_event.wait(); - } - - auto * qs_ptr = data_device; - auto * d_ptr = reinterpret_cast(qs_ptr + QK8_0 * nblocks); - - auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { - const block_q8_0 * x = (const block_q8_0*) tmp_buf; - const int ib = i; - for (int j = 0; j < QK8_0; ++j) { - qs_ptr[ib * QK8_0 + j] = reinterpret_cast(x[ib].qs)[j]; - } - d_ptr[ib] = x[ib].d; - }); - - if (!g_ggml_sycl_use_async_mem_op) { - reorder_event.wait_and_throw(); - } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { +static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; size_t nrows = src0->ne[1]; @@ -3703,7 +3672,7 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { case GGML_TYPE_Q4_0: return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q8_0: - return reorder_qw_q8_0(data_device, size, 0, stream); + return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q4_K: return reorder_qw_q4_k(data_device, size, 0, stream); case GGML_TYPE_Q5_K: @@ -4608,6 +4577,55 @@ catch (sycl::exception const &exc) { std::exit(1); } +static bool ggml_sycl_can_fuse(ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor * rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx + 1]; + const ggml_tensor * add = nullptr; + + if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) { + add = cgraph->nodes[node_idx + 2]; + } + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + if (add && (add->src[0]->type != GGML_TYPE_F32 || + add->src[1]->type != GGML_TYPE_F32 || + add->type != GGML_TYPE_F32)) { + return false; + } + + // If rms_norm is the B operand, this fusion path does not support expansion of the A operand. + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + + // rms_norm kernel assumes contiguous rows. + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + + if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) { + return false; + } + + return true; + } + + return true; +} + static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) { ggml_sycl_set_main_device(sycl_ctx->device); @@ -4621,14 +4639,14 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc } if (node->op == GGML_OP_RMS_NORM && - ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD })) { + ggml_sycl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD })) { ggml_sycl_op_rms_norm_fused_add(*sycl_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); i += 2; continue; } if (node->op == GGML_OP_RMS_NORM && - ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ggml_sycl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ggml_sycl_op_rms_norm_fused(*sycl_ctx, node, cgraph->nodes[i + 1]); i++; continue; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 45f92e7c473..97b40f8e460 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -722,26 +722,6 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, } } -static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, - const int nrows, dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK8_0 == 0); - - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); - - const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); - const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); - - stream->submit([&](sycl::handler & cgh) { - cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), - [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_reorder>(vx, vy, dst, ncols, - nrows, nd_item); - }); - }); -} - static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 880029d7476..270a9b6262f 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -207,8 +207,12 @@ static void rms_norm_mul_f32(const float * x, const float * mul, const float * a const int64_t stride_sample, const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, + const int64_t mul_ncols, const int64_t mul_nrows, + const int64_t mul_nchannels, const int64_t mul_nsamples, const int64_t add_stride_row, const int64_t add_stride_channel, const int64_t add_stride_sample, + const int64_t add_ncols, const int64_t add_nrows, + const int64_t add_nchannels, const int64_t add_nsamples, const float eps, const sycl::nd_item<3> & item_ct1, float * s_sum, int block_size) { const int nrows = item_ct1.get_group_range(2); @@ -226,7 +230,10 @@ static void rms_norm_mul_f32(const float * x, const float * mul, const float * a x += strided_offset; dst += packed_offset; - const auto mul_offset = calculate_offset<3>({mul_stride_sample, mul_stride_channel, mul_stride_row}, {sample, channel, row}); + const auto mul_row = row % mul_nrows; + const auto mul_channel = channel % mul_nchannels; + const auto mul_sample = sample % mul_nsamples; + const auto mul_offset = calculate_offset<3>({mul_stride_sample, mul_stride_channel, mul_stride_row}, {mul_sample, mul_channel, mul_row}); mul += mul_offset; float tmp = 0.0f; @@ -256,13 +263,20 @@ static void rms_norm_mul_f32(const float * x, const float * mul, const float * a const float scale = sycl::rsqrt(mean + eps); if (add) { - const auto add_off = calculate_offset<3>({add_stride_sample, add_stride_channel, add_stride_row}, {sample, channel, row}); + const auto add_row = row % add_nrows; + const auto add_channel = channel % add_nchannels; + const auto add_sample = sample % add_nsamples; + const auto add_off = calculate_offset<3>({add_stride_sample, add_stride_channel, add_stride_row}, {add_sample, add_channel, add_row}); + add += add_off; for (int col = tid; col < ncols; col += block_size) { - dst[col] = scale * x[col] * mul[col] + (add + add_off)[col]; + const auto mul_col = col % mul_ncols; + const auto add_col = col % add_ncols; + dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; } } else { for (int col = tid; col < ncols; col += block_size) { - dst[col] = scale * x[col] * mul[col]; + const auto mul_col = col % mul_ncols; + dst[col] = scale * x[col] * mul[mul_col]; } } } @@ -425,7 +439,9 @@ static void rms_norm_mul_f32_sycl(const float * x, const float * mul, const floa const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, + const int64_t mul_ncols, const int64_t mul_nrows, const int64_t mul_nchannels, const int64_t mul_nsamples, const int64_t add_stride_row, const int64_t add_stride_channel, const int64_t add_stride_sample, + const int64_t add_ncols, const int64_t add_nrows, const int64_t add_nchannels, const int64_t add_nsamples, const float eps, queue_ptr stream, int device) { const sycl::range<3> global_dims(nsamples, nchannels, nrows); @@ -437,7 +453,9 @@ static void rms_norm_mul_f32_sycl(const float * x, const float * mul, const floa [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { rms_norm_mul_f32(x, mul, add, dst, ncols, stride_row, stride_channel, stride_sample, mul_stride_row, mul_stride_channel, mul_stride_sample, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, add_stride_row, add_stride_channel, add_stride_sample, + add_ncols, add_nrows, add_nchannels, add_nsamples, eps, item_ct1, nullptr, WARP_SIZE); }); }); @@ -452,7 +470,9 @@ static void rms_norm_mul_f32_sycl(const float * x, const float * mul, const floa [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { rms_norm_mul_f32(x, mul, add, dst, ncols, stride_row, stride_channel, stride_sample, mul_stride_row, mul_stride_channel, mul_stride_sample, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, add_stride_row, add_stride_channel, add_stride_sample, + add_ncols, add_nrows, add_nchannels, add_nsamples, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); @@ -621,11 +641,18 @@ void ggml_sycl_op_rms_norm_fused(ggml_backend_sycl_context & ctx, ggml_tensor * const int64_t mul_s02 = mul_src->nb[2] / ts_mul; const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + const int64_t mul_ne00 = mul_src->ne[0]; + const int64_t mul_ne01 = mul_src->ne[1]; + const int64_t mul_ne02 = mul_src->ne[2]; + const int64_t mul_ne03 = mul_src->ne[3]; + rms_norm_mul_f32_sycl(src0_dd, mul_dd, nullptr, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, + mul_ne00, mul_ne01, mul_ne02, mul_ne03, 0, 0, 0, + 0, 0, 0, 0, eps, main_stream, ctx.device); } @@ -684,17 +711,29 @@ void ggml_sycl_op_rms_norm_fused_add(ggml_backend_sycl_context & ctx, const int64_t mul_s02 = mul_src->nb[2] / ts_mul; const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + const int64_t mul_ne00 = mul_src->ne[0]; + const int64_t mul_ne01 = mul_src->ne[1]; + const int64_t mul_ne02 = mul_src->ne[2]; + const int64_t mul_ne03 = mul_src->ne[3]; + const size_t ts_add = ggml_type_size(add_src->type); GGML_ASSERT(add_src->nb[0] == ts_add); const int64_t add_s01 = add_src->nb[1] / ts_add; const int64_t add_s02 = add_src->nb[2] / ts_add; const int64_t add_s03 = add_src->nb[3] / ts_add; + const int64_t add_ne00 = add_src->ne[0]; + const int64_t add_ne01 = add_src->ne[1]; + const int64_t add_ne02 = add_src->ne[2]; + const int64_t add_ne03 = add_src->ne[3]; + rms_norm_mul_f32_sycl(src0_dd, mul_dd, add_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, + mul_ne00, mul_ne01, mul_ne02, mul_ne03, add_s01, add_s02, add_s03, + add_ne00, add_ne01, add_ne02, add_ne03, eps, main_stream, ctx.device); }