Skip to content

[NFC][SYCL] Simplify variadic_iterator usage #19507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions sycl/source/detail/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2282,24 +2282,10 @@ class device_impl : public std::enable_shared_from_this<device_impl> {

}; // class device_impl

struct devices_deref_impl {
template <typename T> static device_impl &dereference(T &Elem) {
using Ty = std::decay_t<decltype(Elem)>;
if constexpr (std::is_same_v<Ty, device>) {
return *getSyclObjImpl(Elem);
} else if constexpr (std::is_same_v<Ty, device_impl>) {
return Elem;
} else {
return *Elem;
}
}
};
using devices_iterator =
variadic_iterator<devices_deref_impl, device,
std::vector<std::shared_ptr<device_impl>>::const_iterator,
std::vector<device>::const_iterator,
std::vector<device_impl *>::const_iterator,
device_impl *>;
using devices_iterator = variadic_iterator<
device, std::vector<std::shared_ptr<device_impl>>::const_iterator,
std::vector<device>::const_iterator,
std::vector<device_impl *>::const_iterator, device_impl *>;

class devices_range : public iterator_range<devices_iterator> {
private:
Expand Down
26 changes: 1 addition & 25 deletions sycl/source/detail/graph/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,37 +749,13 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
}
};

struct nodes_deref_impl {
template <typename T> static node_impl &dereference(T &Elem) {
using Ty = std::decay_t<decltype(Elem)>;
if constexpr (std::is_same_v<Ty, std::weak_ptr<node_impl>>) {
// This assumes that weak_ptr doesn't actually manage lifetime and
// the object is guaranteed to be alive (which seems to be the
// assumption across all graph code).
return *Elem.lock();
} else if constexpr (std::is_same_v<Ty, node>) {
return *getSyclObjImpl(Elem);
} else {
return *Elem;
}
}
};

template <typename... ContainerTy>
using nodes_iterator_impl =
variadic_iterator<nodes_deref_impl, node,
typename ContainerTy::const_iterator...>;
variadic_iterator<node, typename ContainerTy::const_iterator...>;

using nodes_iterator = nodes_iterator_impl<
std::vector<std::shared_ptr<node_impl>>, std::vector<node_impl *>,
// Next one is temporary. It looks like `weak_ptr`s aren't
// used for the actual lifetime management and the objects are
// always guaranteed to be alive. Once the code is cleaned
// from `weak_ptr`s this alternative should be removed too.
std::vector<std::weak_ptr<node_impl>>,
//
std::set<std::shared_ptr<node_impl>>, std::set<node_impl *>,
//
std::list<node_impl *>, std::vector<node>>;

class nodes_range : public iterator_range<nodes_iterator> {
Expand Down
21 changes: 14 additions & 7 deletions sycl/source/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,19 @@ const RTDeviceBinaryImage *
retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
CGExecKernel *CGKernel = nullptr);

template <typename DereferenceImpl, typename SyclTy, typename... Iterators>
class variadic_iterator {
template <typename SyclTy, typename... Iterators> class variadic_iterator {
using storage_iter = std::variant<Iterators...>;

storage_iter It;

public:
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using reference = decltype(DereferenceImpl::dereference(
*std::declval<nth_type_t<0, Iterators...>>()));
using value_type = std::remove_reference_t<reference>;
using value_type = std::remove_reference_t<decltype(*getSyclObjImpl(
std::declval<SyclTy>()))>;
using sycl_type = SyclTy;
using pointer = value_type *;
static_assert(std::is_same_v<reference, value_type &>);
using reference = value_type &;

variadic_iterator(const variadic_iterator &) = default;
variadic_iterator(variadic_iterator &&) = default;
Expand Down Expand Up @@ -79,7 +77,16 @@ class variadic_iterator {
decltype(auto) operator*() {
return std::visit(
[](auto &&It) -> decltype(auto) {
return DereferenceImpl::dereference(*It);
decltype(auto) Elem = *It;
using Ty = std::decay_t<decltype(Elem)>;
static_assert(!std::is_same_v<Ty, decltype(Elem)>);
if constexpr (std::is_same_v<Ty, sycl_type>) {
return *getSyclObjImpl(Elem);
} else if constexpr (std::is_same_v<Ty, value_type>) {
return Elem;
} else {
return *Elem;
}
},
It);
}
Expand Down