Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,11 @@ Error Method::experimental_step() {
return step();
}

bool Method::in_progress() const {
return (step_state_.chain_idx != 0 || step_state_.instr_idx != 0) &&
step_state_.chain_idx < n_chains_;
}

Error Method::execute() {
internal::event_tracer_create_event_block(event_tracer_, "Execute");
EventTracerEntry event_tracer_entry =
Expand All @@ -1584,6 +1589,8 @@ Error Method::execute() {
initialized(),
NotSupported,
"Cannot execute until method has been initialized.");
ET_CHECK_OR_RETURN_ERROR(
!in_progress(), InvalidState, "Method execution is in progress");
const size_t n_input = inputs_size();
for (size_t i = 0; i < n_input; ++i) {
ET_CHECK_OR_RETURN_ERROR(
Expand Down Expand Up @@ -1622,6 +1629,7 @@ Error Method::execute() {
static_cast<DebugHandle>(step_state_.instr_idx));
auto status = execute_instruction();
if (status != Error::Ok) {
step_state_ = StepState{0, 0};
return status;
}
}
Expand Down
29 changes: 24 additions & 5 deletions runtime/executor/method.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,12 @@ class Method final {
/**
* Execute the method.
*
* NOTE: Will fail if the method has been partially executed using the
* `step()` api.
* NOTE: Will fail with Error::InvalidState if the method has been partially
* executed using the `step()` api. Call reset_execution() after step()
* reaches EndOfMethod before calling execute().
*
* If execute() encounters an error mid-execution, the state is automatically
* reset to allow retry attempts.
*
* @returns Error::Ok on success, non-Ok on failure.
*/
Expand All @@ -236,14 +240,29 @@ class Method final {
/// DEPRECATED: Use `step()` instead.
ET_DEPRECATED ET_NODISCARD Error experimental_step();

/**
* EXPERIMENTAL: Returns true if method execution is in progress.
*
* Execution is considered "in progress" when step() has been called and
* execution is mid-way through the instruction sequence.
*
* @retval true if execution is in progress (not at initial state, not at end)
* @retval false if at initial state, after completion, or after any error
*
* @note execute() automatically resets state on error, so in_progress()
* returns false immediately after execute() fails.
*/
ET_EXPERIMENTAL ET_NODISCARD bool in_progress() const;

/**
* EXPERIMENTAL: Resets execution state to the start of the Method. For use
* with the `step()` API.
*
* @retval Error:Ok on success
* @retval Error::Ok on success
* @retval Error::InvalidState if called before step-based execution reached
* the end of the Method. This means it is not possible to recover a
* Method that failed mid-execution.
* the end of the Method. When using step(), you must step through to
* EndOfMethod before resetting. Note: execute() handles its own error
* recovery and does not require manual reset.
*/
ET_EXPERIMENTAL ET_NODISCARD Error reset_execution();

Expand Down
90 changes: 90 additions & 0 deletions runtime/executor/test/method_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,96 @@ TEST_F(MethodTest, MethodGetAttributeTest) {
EXPECT_EQ(res->const_data_ptr<int32_t>()[0], 1);
}

TEST_F(MethodTest, InProgressInitialState) {
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method = programs_["add"]->load_method("forward", &mmm.get());
ASSERT_EQ(method.error(), Error::Ok);

EXPECT_FALSE(method->in_progress());
}

TEST_F(MethodTest, InProgressDuringStepExecution) {
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method =
programs_["add_mul"]->load_method("forward", &mmm.get());
ASSERT_EQ(method.error(), Error::Ok);

auto input_cleanup = prepare_input_tensors(*method);
ASSERT_EQ(input_cleanup.error(), Error::Ok);

Error err = method->step();
ASSERT_EQ(err, Error::Ok);

EXPECT_TRUE(method->in_progress());

while (err != Error::EndOfMethod) {
err = method->step();
ASSERT_TRUE(err == Error::Ok || err == Error::EndOfMethod);
}

EXPECT_FALSE(method->in_progress());
}

TEST_F(MethodTest, ExecuteFailsWhenInProgress) {
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method =
programs_["add_mul"]->load_method("forward", &mmm.get());
ASSERT_EQ(method.error(), Error::Ok);

auto input_cleanup = prepare_input_tensors(*method);
ASSERT_EQ(input_cleanup.error(), Error::Ok);

Error err = method->step();
ASSERT_EQ(err, Error::Ok);
ASSERT_TRUE(method->in_progress());

err = method->execute();
EXPECT_EQ(err, Error::InvalidState);
}

TEST_F(MethodTest, ExecuteSucceedsAfterReset) {
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method = programs_["add"]->load_method("forward", &mmm.get());
ASSERT_EQ(method.error(), Error::Ok);

auto input_cleanup = prepare_input_tensors(*method);
ASSERT_EQ(input_cleanup.error(), Error::Ok);
auto input_err = method->set_input(EValue(1.0), 2);
ASSERT_EQ(input_err, Error::Ok);

Error err = Error::Ok;
while (err != Error::EndOfMethod) {
err = method->step();
ASSERT_TRUE(err == Error::Ok || err == Error::EndOfMethod);
}

err = method->reset_execution();
ASSERT_EQ(err, Error::Ok);
EXPECT_FALSE(method->in_progress());

err = method->execute();
EXPECT_EQ(err, Error::Ok);
}

TEST_F(MethodTest, ExecuteResetsOnError) {
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
Result<Method> method = programs_["add"]->load_method("forward", &mmm.get());
ASSERT_EQ(method.error(), Error::Ok);

Error err = method->execute();
EXPECT_NE(err, Error::Ok);

EXPECT_FALSE(method->in_progress());

auto input_cleanup = prepare_input_tensors(*method);
ASSERT_EQ(input_cleanup.error(), Error::Ok);
auto input_err = method->set_input(EValue(1.0), 2);
ASSERT_EQ(input_err, Error::Ok);

err = method->execute();
EXPECT_EQ(err, Error::Ok);
}

/*
* TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of
* the portable op lib
Expand Down
Loading