Skip to content

Commit

Permalink
Merge pull request #150 from elbeno/periodic
Browse files Browse the repository at this point in the history
🎨 Allow `time_scheduler` to get its expiry time from a receiver
  • Loading branch information
elbeno authored Feb 19, 2025
2 parents b9a275e + 8ba53fb commit 30b574f
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 26 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ target_sources(
include/async/read_env.hpp
include/async/repeat.hpp
include/async/retry.hpp
include/async/schedulers/get_expiration.hpp
include/async/schedulers/inline_scheduler.hpp
include/async/schedulers/priority_scheduler.hpp
include/async/schedulers/requeue_policy.hpp
Expand Down
3 changes: 1 addition & 2 deletions include/async/env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
#include <utility>

namespace async {
template <typename Query, typename Value> struct prop {
template <typename Query, typename Value> struct prop : Query {
[[nodiscard]] constexpr auto query(Query) const noexcept -> Value const & {
return value;
}

[[no_unique_address]] Query _{};
[[no_unique_address]] Value value{};
};
template <typename Query, typename Value>
Expand Down
22 changes: 22 additions & 0 deletions include/async/schedulers/get_expiration.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include <async/forwarding_query.hpp>

#include <stdx/ct_string.hpp>

#include <utility>

namespace async {
namespace timer_mgr {
constexpr inline struct get_expiration_t : forwarding_query_t {
constexpr static auto name = stdx::ct_string{"get_expiration"};

template <typename T>
constexpr auto operator()(T &&t) const noexcept(
noexcept(std::forward<T>(t).query(std::declval<get_expiration_t>())))
-> decltype(std::forward<T>(t).query(*this)) {
return std::forward<T>(t).query(*this);
}
} get_expiration{};
} // namespace timer_mgr
} // namespace async
99 changes: 88 additions & 11 deletions include/async/schedulers/time_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
#include <async/connect.hpp>
#include <async/debug.hpp>
#include <async/env.hpp>
#include <async/schedulers/get_expiration.hpp>
#include <async/schedulers/timer_manager_interface.hpp>
#include <async/type_traits.hpp>

#include <stdx/concepts.hpp>
#include <stdx/ct_string.hpp>

#include <memory>
#include <optional>
Expand All @@ -31,14 +33,27 @@ struct op_state_base : Task {
[[no_unique_address]] Rcvr rcvr;
};

struct DurationExpirationPolicy {
template <typename Domain, typename Task> static auto schedule(Task &&t) {
detail::run_after<Domain>(std::forward<Task>(t), t.d);
}
};

struct TimepointExpirationPolicy {
template <typename Domain, typename Task> static auto schedule(Task &&t) {
detail::run_at<Domain>(std::forward<Task>(t),
get_expiration(get_env(t.rcvr)));
}
};

template <typename Domain, stdx::ct_string Name, typename Duration,
typename Rcvr, typename Task>
typename Rcvr, typename Task, typename ExpirationPolicy>
struct op_state;

template <typename Domain, stdx::ct_string Name, typename Duration,
typename Rcvr, typename Task>
typename Rcvr, typename Task, typename ExpirationPolicy>
requires unstoppable_token<stop_token_of_t<env_of_t<Rcvr>>>
struct op_state<Domain, Name, Duration, Rcvr, Task> final
struct op_state<Domain, Name, Duration, Rcvr, Task, ExpirationPolicy> final
: op_state_base<Rcvr, Name, Task> {
template <stdx::same_as_unqualified<Rcvr> R>
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
Expand All @@ -49,16 +64,16 @@ struct op_state<Domain, Name, Duration, Rcvr, Task> final
debug_signal<"start", debug::erased_context_for<
op_state_base<Rcvr, Name, Task>>>(
get_env(this->rcvr));
detail::run_after<Domain>(*this, d);
ExpirationPolicy::template schedule<Domain>(*this);
}

[[no_unique_address]] Duration d{};
};

template <typename Domain, stdx::ct_string Name, typename Duration,
typename Rcvr, typename Task>
typename Rcvr, typename Task, typename ExpirationPolicy>
requires(not unstoppable_token<stop_token_of_t<env_of_t<Rcvr>>>)
struct op_state<Domain, Name, Duration, Rcvr, Task> final
struct op_state<Domain, Name, Duration, Rcvr, Task, ExpirationPolicy> final
: op_state_base<Rcvr, Name, Task> {
template <stdx::same_as_unqualified<Rcvr> R>
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
Expand All @@ -76,11 +91,13 @@ struct op_state<Domain, Name, Duration, Rcvr, Task> final
get_env(this->rcvr));
set_stopped(std::move(this->rcvr));
} else {
detail::run_after<Domain>(*this, d);
ExpirationPolicy::template schedule<Domain>(*this);
stop_cb.emplace(token, stop_callback_fn{this});
}
}

[[no_unique_address]] Duration d{};

private:
struct stop_callback_fn {
auto operator()() -> void {
Expand All @@ -97,7 +114,6 @@ struct op_state<Domain, Name, Duration, Rcvr, Task> final
using stop_callback_t =
stop_callback_for_t<stop_token_of_t<env_of_t<Rcvr>>, stop_callback_fn>;

[[no_unique_address]] Duration d{};
std::optional<stop_callback_t> stop_cb{};
};
} // namespace timer_mgr
Expand Down Expand Up @@ -132,7 +148,8 @@ class time_scheduler {
[[nodiscard]] constexpr auto connect(R &&r) const & {
check_connect<sender, R>();
return timer_mgr::op_state<Domain, Name, Duration,
std::remove_cvref_t<R>, Task>{
std::remove_cvref_t<R>, Task,
timer_mgr::DurationExpirationPolicy>{
std::forward<R>(r), d};
}
};
Expand All @@ -151,6 +168,66 @@ class time_scheduler {
[[no_unique_address]] Duration d{};
};

namespace detail {
struct no_duration_t {};
struct no_task_t;

template <typename Env> constexpr auto query_expiration(Env const &e) {
if constexpr (async::detail::valid_query_for<timer_mgr::get_expiration_t,
Env>) {
return timer_mgr::get_expiration(e);
} else {
return 0;
}
}
} // namespace detail

template <typename Domain, stdx::ct_string Name>
class time_scheduler<Domain, Name, detail::no_duration_t, detail::no_task_t> {
struct sender {
using is_sender = void;

[[nodiscard]] constexpr auto query(get_env_t) const noexcept {
return prop{get_completion_scheduler_t<set_value_t>{},
time_scheduler{}};
}

template <typename Env>
[[nodiscard]] constexpr static auto
get_completion_signatures(Env const &) noexcept
-> completion_signatures<set_value_t(), set_stopped_t()> {
return {};
}

template <typename Env>
requires unstoppable_token<stop_token_of_t<Env>>
[[nodiscard]] constexpr static auto get_completion_signatures(
Env const &) noexcept -> completion_signatures<set_value_t()> {
return {};
}

template <receiver R>
[[nodiscard]] constexpr auto connect(R &&r) const & {
check_connect<sender, R>();
using TP = decltype(detail::query_expiration(get_env(r)));
using task_t = timer_task<TP>;
return timer_mgr::op_state<Domain, Name, detail::no_duration_t,
std::remove_cvref_t<R>, task_t,
timer_mgr::TimepointExpirationPolicy>{
std::forward<R>(r), {}};
}
};

[[nodiscard]] friend constexpr auto
operator==(time_scheduler, time_scheduler) -> bool = default;

public:
[[nodiscard]] constexpr auto schedule() -> sender { return {}; }
};

time_scheduler() -> time_scheduler<timer_mgr::default_domain, "time_scheduler",
detail::no_duration_t, detail::no_task_t>;

template <typename D>
time_scheduler(D)
-> time_scheduler<timer_mgr::default_domain, "time_scheduler", D>;
Expand All @@ -171,8 +248,8 @@ struct debug::context_for<timer_mgr::op_state_base<Rcvr, Name, Task>> {
};

template <typename Domain, stdx::ct_string Name, typename Duration,
typename Rcvr, typename Task>
typename Rcvr, typename Task, typename ExpirationPolicy>
struct debug::context_for<
timer_mgr::op_state<Domain, Name, Duration, Rcvr, Task>>
timer_mgr::op_state<Domain, Name, Duration, Rcvr, Task, ExpirationPolicy>>
: debug::context_for<timer_mgr::op_state_base<Rcvr, Name, Task>> {};
} // namespace async
46 changes: 35 additions & 11 deletions include/async/schedulers/timer_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,27 @@ template <detail::timer_hal H> struct generic_timer_manager {
stdx::intrusive_list<task_t> task_queue{};
stdx::atomic<int> task_count;

auto schedule(task_t *t, duration_t d) -> void {
auto enqueue(task_t *t) -> void {
auto pos = std::find_if(std::begin(task_queue), std::end(task_queue),
[&](auto const &task) { return *t < task; });
if (pos == std::begin(task_queue)) {
H::set_event_time(t->expiration_time);
}
task_queue.insert(pos, t);
}

auto schedule_at(task_t *t, time_point_t tp) -> void {
t->expiration_time = tp;
if (std::empty(task_queue)) {
task_queue.push_back(t);
H::enable();
H::set_event_time(t->expiration_time);
} else {
enqueue(t);
}
}

auto schedule_after(task_t *t, duration_t d) -> void {
if (std::empty(task_queue)) {
task_queue.push_back(t);
if constexpr (detail::fused_enable_timer_hal<H>) {
Expand All @@ -65,15 +85,7 @@ template <detail::timer_hal H> struct generic_timer_manager {
}
} else {
t->expiration_time = H::now() + d;
auto pos = std::find_if(
std::begin(task_queue), std::end(task_queue),
[&](auto const &task) {
return task.expiration_time > t->expiration_time;
});
if (pos == std::begin(task_queue)) {
H::set_event_time(t->expiration_time);
}
task_queue.insert(pos, t);
enqueue(t);
}
}

Expand All @@ -93,7 +105,19 @@ template <detail::timer_hal H> struct generic_timer_manager {
return conc::call_in_critical_section<mutex>([&]() -> bool {
if (auto const added = not std::exchange(t.pending, true); added) {
++task_count;
schedule(std::addressof(t), d);
schedule_after(std::addressof(t), d);
return true;
}
return false;
});
}

template <std::derived_from<task_t> T, std::convertible_to<time_point_t> TP>
auto run_at(T &t, TP tp) -> bool {
return conc::call_in_critical_section<mutex>([&]() -> bool {
if (auto const added = not std::exchange(t.pending, true); added) {
++task_count;
schedule_at(std::addressof(t), tp);
return true;
}
return false;
Expand Down
15 changes: 15 additions & 0 deletions include/async/schedulers/timer_manager_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ struct undefined_timer_manager {
return false;
}

template <typename... Args> static auto run_at(Args &&...) -> bool {
static_assert(stdx::always_false_v<Args...>,
"Inject a timer manager by specializing "
"async::injected_timer_manager.");
return false;
}

template <typename... Args> static auto service_task(Args &&...) -> void {
static_assert(stdx::always_false_v<Args...>,
"Inject a timer manager by specializing "
Expand Down Expand Up @@ -97,6 +104,14 @@ auto run_after(Args &&...args) -> bool {
std::forward<Args>(args)...);
}

template <typename Domain = default_domain, typename... DummyArgs,
typename... Args>
requires(sizeof...(DummyArgs) == 0)
auto run_at(Args &&...args) -> bool {
return get_injected_manager<Domain, DummyArgs...>().run_at(
std::forward<Args>(args)...);
}

template <typename Domain = default_domain, typename... DummyArgs,
typename... Args>
requires(sizeof...(DummyArgs) == 0)
Expand Down
2 changes: 1 addition & 1 deletion test/detail/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ template <typename T, typename Env> struct with_env : T {
[[no_unique_address]] Env e;

[[nodiscard]] constexpr auto query(async::get_env_t) const noexcept {
return e;
return async::env{e, async::get_env(static_cast<T const &>(*this))};
}
};
template <typename T, typename Env> with_env(T, Env) -> with_env<T, Env>;
Expand Down
38 changes: 38 additions & 0 deletions test/schedulers/time_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,41 @@ TEST_CASE("time_scheduler produces set_stopped debug signal",
CHECK(debug_events ==
std::vector{"op sched start"s, "op sched set_stopped"s});
}

TEST_CASE("time_scheduler with no argument produces a scheduler that gets its "
"expiration time externally",
"[time_scheduler]") {
auto s = async::time_scheduler{};
int var{};
async::sender auto sndr =
async::start_on(s, async::just_result_of([&] { var = 42; }));
auto r = with_env{universal_receiver{},
async::prop{async::timer_mgr::get_expiration,
std::chrono::steady_clock::time_point{}}};
auto op = async::connect(sndr, r);

async::start(op);
CHECK(enabled<default_domain>);
async::timer_mgr::service_task();
CHECK(var == 42);
CHECK(async::timer_mgr::is_idle());
CHECK(not enabled<default_domain>);
}

TEST_CASE("time_scheduler with no argument is cancellable",
"[time_scheduler]") {
auto s = async::time_scheduler{};
int var{};
async::sender auto sndr =
async::start_on(s, async::just_result_of([&] { var = 42; }));
auto r = with_env{stoppable_receiver{[&] { var = 17; }},
async::prop{async::timer_mgr::get_expiration,
std::chrono::steady_clock::time_point{}}};
auto op = async::connect(sndr, r);

r.request_stop();
async::start(op);
CHECK(not enabled<default_domain>);
CHECK(var == 17);
CHECK(async::timer_mgr::is_idle());
}
13 changes: 12 additions & 1 deletion test/schedulers/timer_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ TEST_CASE("nothing pending", "[timer_manager]") {
CHECK(m.is_idle());
}

TEST_CASE("queue a task", "[timer_manager]") {
TEST_CASE("queue a task (run_after)", "[timer_manager]") {
hal::calls.clear();
auto t = timer_manager_t::create_task([] {});

Expand All @@ -90,6 +90,17 @@ TEST_CASE("queue a task", "[timer_manager]") {
CHECK(hal::calls[0] == 3);
}

TEST_CASE("queue a task (run_at)", "[timer_manager]") {
hal::calls.clear();
auto t = timer_manager_t::create_task([] {});

auto m = timer_manager_t{};
m.run_at(t, 3);
CHECK(not m.is_idle());
REQUIRE(hal::calls.size() == 1);
CHECK(hal::calls[0] == 3);
}

TEST_CASE("run a queued task", "[timer_manager]") {
hal::calls.clear();

Expand Down

0 comments on commit 30b574f

Please sign in to comment.