-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathfill.hpp
98 lines (81 loc) · 3.11 KB
/
fill.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause
#pragma once
#include <memory>
#include <type_traits>
#include <sycl/sycl.hpp>
#include <dr/concepts/concepts.hpp>
#include <dr/detail/segments_tools.hpp>
#include <dr/shp/detail.hpp>
#include <dr/shp/device_ptr.hpp>
#include <dr/shp/util.hpp>
namespace dr::shp {
// Asynchronously stores `value` into every element of the contiguous range
// [first, last).
//
// The queue is obtained from `first` through
// __detail::get_queue_for_pointer (presumably the queue associated with the
// device owning that memory -- confirm against the helper). One work item is
// launched per element via dr::__detail::parallel_for.
//
// first: iterator to the first element to overwrite
// last:  iterator one past the final element to overwrite
// value: copied into the kernel lambda and assigned to each element
//
// Returns the event produced by dr::__detail::parallel_for; this function
// does not wait on it.
template <std::contiguous_iterator Iter>
  requires(!std::is_const_v<std::iter_value_t<Iter>> &&
           std::is_trivially_copyable_v<std::iter_value_t<Iter>>)
sycl::event fill_async(Iter first, Iter last,
                       const std::iter_value_t<Iter> &value) {
  using element_type = std::iter_value_t<Iter>;

  auto &&queue = __detail::get_queue_for_pointer(first);
  element_type *data = std::to_address(first);
  const auto count = last - first;

  // not using q.fill because of CMPLRLLVM-46438
  return dr::__detail::parallel_for(queue, sycl::range<>(count),
                                    [=](auto i) { data[i] = value; });
}
// Blocking fill of the contiguous range [first, last) with `value`.
//
// Forwards to fill_async(first, last, value) and waits on the returned event,
// so the fill has completed when this function returns.
//
// The constraint mirrors the one on the contiguous-iterator fill_async
// overload (non-const, trivially copyable element type) so that an
// unsupported element type is rejected at this overload's constraint instead
// of failing inside the body.
//
// first: iterator to the first element to overwrite
// last:  iterator one past the final element to overwrite
// value: value assigned to each element
template <std::contiguous_iterator Iter>
  requires(!std::is_const_v<std::iter_value_t<Iter>> &&
           std::is_trivially_copyable_v<std::iter_value_t<Iter>>)
void fill(Iter first, Iter last, const std::iter_value_t<Iter> &value) {
  fill_async(first, last, value).wait();
}
// Asynchronously stores `value` into every element of the device range
// [first, last).
//
// The queue is obtained from `first` through
// __detail::get_queue_for_pointer, and the elements are written through the
// raw pointer returned by first.get_raw_pointer(). One work item is launched
// per element via dr::__detail::parallel_for.
//
// This function produces no console output: a library fill must not write
// trace messages to stdout.
//
// first: device pointer to the first element to overwrite
// last:  device pointer one past the final element to overwrite
// value: copied into the kernel lambda and assigned to each element; U must
//        be indirectly writable through device_ptr<T>
//
// Returns the event produced by dr::__detail::parallel_for; this function
// does not wait on it.
template <typename T, typename U>
  requires(std::indirectly_writable<device_ptr<T>, U>)
sycl::event fill_async(device_ptr<T> first, device_ptr<T> last,
                       const U &value) {
  auto &&q = __detail::get_queue_for_pointer(first);
  auto *arr = first.get_raw_pointer();
  // not using q.fill because of CMPLRLLVM-46438
  return dr::__detail::parallel_for(q, sycl::range<>(last - first),
                                    [=](auto idx) { arr[idx] = value; });
}
// Blocking fill of the device range [first, last) with `value`.
//
// Forwards to fill_async(first, last, value) and waits on the returned event,
// so the fill has completed when this function returns. This function itself
// writes nothing to stdout.
//
// first: device pointer to the first element to overwrite
// last:  device pointer one past the final element to overwrite
// value: value assigned to each element; U must be indirectly writable
//        through device_ptr<T>
template <typename T, typename U>
  requires(std::indirectly_writable<device_ptr<T>, U>)
void fill(device_ptr<T> first, device_ptr<T> last, const U &value) {
  fill_async(first, last, value).wait();
}
// Asynchronously stores `value` into every element of the remote contiguous
// range `r`.
//
// The queue is selected with __detail::queue(dr::ranges::rank(r)), and the
// elements are written through the address of the first element of
// dr::ranges::local(r). rng::distance(r) work items are launched via
// dr::__detail::parallel_for.
//
// r:     range whose elements are overwritten
// value: copied into the kernel lambda and assigned to each element
//
// Returns the event produced by dr::__detail::parallel_for; this function
// does not wait on it.
template <typename T, dr::remote_contiguous_range R>
sycl::event fill_async(R &&r, const T &value) {
  auto &&queue = __detail::queue(dr::ranges::rank(r));
  auto *data = std::to_address(rng::begin(dr::ranges::local(r)));
  const auto count = rng::distance(r);

  // not using q.fill because of CMPLRLLVM-46438
  return dr::__detail::parallel_for(queue, sycl::range<>(count),
                                    [=](auto i) { data[i] = value; });
}
// Blocking fill of the remote contiguous range `r` with `value`.
//
// Launches fill_async(r, value), waits for the returned event, and then
// returns rng::end(r).
//
// r:     range whose elements are overwritten
// value: value assigned to each element
template <typename T, dr::remote_contiguous_range R>
auto fill(R &&r, const T &value) {
  auto done = fill_async(r, value);
  done.wait();
  return rng::end(r);
}
// Asynchronously stores `value` into every element of the distributed range
// `r`.
//
// One dr::shp::fill_async call is issued per segment of
// dr::ranges::segments(r); the resulting events are collected and merged
// with dr::shp::__detail::combine_events.
//
// r:     distributed range whose elements are overwritten
// value: value assigned to each element
//
// Returns the event produced by combine_events for the per-segment fills;
// this function does not wait on it.
template <typename T, dr::distributed_contiguous_range DR>
sycl::event fill_async(DR &&r, const T &value) {
  std::vector<sycl::event> pending;

  for (auto &&seg : dr::ranges::segments(r)) {
    pending.push_back(dr::shp::fill_async(seg, value));
  }

  return dr::shp::__detail::combine_events(pending);
}
// Blocking fill of the distributed range `r` with `value`.
//
// Launches fill_async(r, value), waits for the returned event, and then
// returns rng::end(r).
//
// r:     distributed range whose elements are overwritten
// value: value assigned to each element
template <typename T, dr::distributed_contiguous_range DR>
auto fill(DR &&r, const T &value) {
  auto done = fill_async(r, value);
  done.wait();
  return rng::end(r);
}
// Blocking fill of the distributed iterator range [first, last) with `value`.
//
// Wraps the iterator pair in rng::subrange, hands it to fill_async, waits
// for the returned event, and then returns `last`.
//
// first: distributed iterator to the first element to overwrite
// last:  distributed iterator one past the final element to overwrite
// value: value assigned to each element
template <typename T, dr::distributed_iterator Iter>
auto fill(Iter first, Iter last, const T &value) {
  auto done = fill_async(rng::subrange(first, last), value);
  done.wait();
  return last;
}
} // namespace dr::shp