Save functor type into typename to simplify invocation in parallel_for
oleksandr-pavlyk committed Jan 21, 2025
1 parent 5c412df commit 3f90e9b
Showing 2 changed files with 45 additions and 29 deletions.
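The refactoring pattern applied in both files: instead of spelling the full functor template instantiation inline inside cgh.parallel_for(...), the functor type is bound to a local Impl alias next to the existing KernelName alias. The submission call then shrinks to KernelName plus a short Impl(...) construction, and the alias gives a natural spot to static_assert that the functor is device-copyable. Below is a minimal self-contained sketch of the idiom, assuming a SYCL 2020 compiler; scale_kernel and ScaleFunctor are hypothetical names used for illustration, not dpctl code.

#include <cstddef>
#include <sycl/sycl.hpp>

// kernel name type; a distinct instantiation per element type T
template <typename T> class scale_kernel;

template <typename T> struct ScaleFunctor
{
    T *data;
    T factor;

    ScaleFunctor(T *data_, T factor_) : data(data_), factor(factor_) {}

    void operator()(sycl::id<1> id) const { data[id] *= factor; }
};

template <typename T>
sycl::event scale(sycl::queue &q, T *data, T factor, std::size_t nelems)
{
    return q.submit([&](sycl::handler &cgh) {
        using KernelName = scale_kernel<T>;
        using Impl = ScaleFunctor<T>;
        // the functor is copied into the kernel by value, so it must be
        // device-copyable; asserting here fails fast at the submission site
        static_assert(sycl::is_device_copyable_v<Impl>);

        cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
                                     Impl(data, factor));
    });
}

The hunks below apply exactly this shape to ClipContigFunctor, ClipStridedFunctor, LinearSequenceAffineFunctor, FullStridedFunctor, and EyeFunctor.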
20 changes: 13 additions & 7 deletions dpctl/tensor/libtensor/include/kernels/clip.hpp
@@ -216,22 +216,26 @@ sycl::event clip_contig_impl(sycl::queue &q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
+            using Impl =
+                ClipContigFunctor<T, vec_sz, n_vecs, enable_sg_loadstore>;
+            static_assert(sycl::is_device_copyable_v<Impl>);
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ClipContigFunctor<T, vec_sz, n_vecs, enable_sg_loadstore>(
-                    nelems, x_tp, min_tp, max_tp, dst_tp));
+                Impl(nelems, x_tp, min_tp, max_tp, dst_tp));
         }
         else {
             constexpr bool disable_sg_loadstore = false;
             using InnerKernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
+            using Impl =
+                ClipContigFunctor<T, vec_sz, n_vecs, disable_sg_loadstore>;
+            static_assert(sycl::is_device_copyable_v<Impl>);
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ClipContigFunctor<T, vec_sz, n_vecs, disable_sg_loadstore>(
-                    nelems, x_tp, min_tp, max_tp, dst_tp));
+                Impl(nelems, x_tp, min_tp, max_tp, dst_tp));
         }
     });

@@ -311,10 +315,12 @@ sycl::event clip_strided_impl(sycl::queue &q,
         const FourOffsets_StridedIndexer indexer{
             nd, x_offset, min_offset, max_offset, dst_offset, shape_strides};
 
-        cgh.parallel_for<clip_strided_kernel<T, FourOffsets_StridedIndexer>>(
+        using KernelName = clip_strided_kernel<T, FourOffsets_StridedIndexer>;
+        using Impl = ClipStridedFunctor<T, FourOffsets_StridedIndexer>;
+
+        cgh.parallel_for<KernelName>(
             sycl::range<1>(nelems),
-            ClipStridedFunctor<T, FourOffsets_StridedIndexer>(
-                x_tp, min_tp, max_tp, dst_tp, indexer));
+            Impl(x_tp, min_tp, max_tp, dst_tp, indexer));
     });
 
     return clip_ev;
54 changes: 32 additions & 22 deletions dpctl/tensor/libtensor/include/kernels/constructors.hpp
@@ -24,13 +24,15 @@
 //===----------------------------------------------------------------------===//
 
 #pragma once
+#include <complex>
+#include <cstddef>
+
+#include <sycl/sycl.hpp>
+
 #include "dpctl_tensor_types.hpp"
 #include "utils/offset_utils.hpp"
 #include "utils/strided_iters.hpp"
 #include "utils/type_utils.hpp"
-#include <complex>
-#include <cstddef>
-#include <sycl/sycl.hpp>
 
 namespace dpctl
 {
@@ -200,22 +202,25 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q,
 {
     dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
 
-    bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64);
+    const bool device_supports_doubles =
+        exec_q.get_device().has(sycl::aspect::fp64);
+    const std::size_t den = (include_endpoint) ? nelems - 1 : nelems;
 
     sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
+
         if (device_supports_doubles) {
-            cgh.parallel_for<linear_sequence_affine_kernel<Ty, double>>(
-                sycl::range<1>{nelems},
-                LinearSequenceAffineFunctor<Ty, double>(
-                    array_data, start_v, end_v,
-                    (include_endpoint) ? nelems - 1 : nelems));
+            using KernelName = linear_sequence_affine_kernel<Ty, double>;
+            using Impl = LinearSequenceAffineFunctor<Ty, double>;
+
+            cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
+                                         Impl(array_data, start_v, end_v, den));
         }
         else {
-            cgh.parallel_for<linear_sequence_affine_kernel<Ty, float>>(
-                sycl::range<1>{nelems},
-                LinearSequenceAffineFunctor<Ty, float>(
-                    array_data, start_v, end_v,
-                    (include_endpoint) ? nelems - 1 : nelems));
+            using KernelName = linear_sequence_affine_kernel<Ty, float>;
+            using Impl = LinearSequenceAffineFunctor<Ty, float>;
+
+            cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
+                                         Impl(array_data, start_v, end_v, den));
         }
     });

@@ -312,10 +317,12 @@ sycl::event full_strided_impl(sycl::queue &q,
 
     sycl::event fill_ev = q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
-        cgh.parallel_for<full_strided_kernel<dstTy>>(
-            sycl::range<1>{nelems},
-            FullStridedFunctor<dstTy, decltype(strided_indexer)>(
-                dst_tp, fill_v, strided_indexer));
+
+        using KernelName = full_strided_kernel<dstTy>;
+        using Impl = FullStridedFunctor<dstTy, StridedIndexer>;
+
+        cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
+                                     Impl(dst_tp, fill_v, strided_indexer));
     });
 
     return fill_ev;
@@ -388,9 +395,12 @@ sycl::event eye_impl(sycl::queue &exec_q,
     dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
     sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
-        cgh.parallel_for<eye_kernel<Ty>>(
-            sycl::range<1>{nelems},
-            EyeFunctor<Ty>(array_data, start, end, step));
+
+        using KernelName = eye_kernel<Ty>;
+        using Impl = EyeFunctor<Ty>;
+
+        cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
+                                     Impl(array_data, start, end, step));
     });
 
     return eye_event;
@@ -478,7 +488,7 @@ sycl::event tri_impl(sycl::queue &exec_q,
             ssize_t inner_gid = idx[0] - inner_range * outer_gid;
 
             ssize_t src_inner_offset = 0, dst_inner_offset = 0;
-            bool to_copy(true);
+            bool to_copy{false};
 
             {
                 using dpctl::tensor::strides::CIndexer_array;