From 23140befade62f3cb3ed79c1d9cf3740cfd1669a Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Wed, 7 Jan 2026 09:59:29 -0800 Subject: [PATCH] Add in_progress() state check to Method::execute() Adds in_progress() to detect when execution is mid-way through the instruction sequence. Method::execute() now returns InvalidState if called while in_progress() is true. execute() auto-resets step_state_ on error to allow retry attempts. step() does not auto-reset (requires manual reset_execution() call). Tests verify initial state, mid-execution detection, error recovery, and interaction between execute() and step() APIs. --- runtime/executor/method.cpp | 8 +++ runtime/executor/method.h | 29 +++++++-- runtime/executor/test/method_test.cpp | 90 +++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 5 deletions(-) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index ccb88a03818..a830c2e5e08 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -1574,6 +1574,11 @@ Error Method::experimental_step() { return step(); } +bool Method::in_progress() const { + return (step_state_.chain_idx != 0 || step_state_.instr_idx != 0) && + step_state_.chain_idx < n_chains_; +} + Error Method::execute() { internal::event_tracer_create_event_block(event_tracer_, "Execute"); EventTracerEntry event_tracer_entry = @@ -1584,6 +1589,8 @@ Error Method::execute() { initialized(), NotSupported, "Cannot execute until method has been initialized."); + ET_CHECK_OR_RETURN_ERROR( + !in_progress(), InvalidState, "Method execution is in progress"); const size_t n_input = inputs_size(); for (size_t i = 0; i < n_input; ++i) { ET_CHECK_OR_RETURN_ERROR( @@ -1622,6 +1629,7 @@ Error Method::execute() { static_cast(step_state_.instr_idx)); auto status = execute_instruction(); if (status != Error::Ok) { + step_state_ = StepState{0, 0}; return status; } } diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 78b71945a5a..cd121a56c17 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -217,8 +217,12 @@ class Method final { /** * Execute the method. * - * NOTE: Will fail if the method has been partially executed using the - * `step()` api. + * NOTE: Will fail with Error::InvalidState if the method has been partially + * executed using the `step()` api. Call reset_execution() after step() + * reaches EndOfMethod before calling execute(). + * + * If execute() encounters an error mid-execution, the state is automatically + * reset to allow retry attempts. * * @returns Error::Ok on success, non-Ok on failure. */ @@ -236,14 +240,29 @@ class Method final { /// DEPRECATED: Use `step()` instead. ET_DEPRECATED ET_NODISCARD Error experimental_step(); + /** + * EXPERIMENTAL: Returns true if method execution is in progress. + * + * Execution is considered "in progress" when step() has been called and + * execution is mid-way through the instruction sequence. + * + * @retval true if execution is in progress (not at initial state, not at end) + * @retval false if at initial state, after completion, or after any error + * + * @note execute() automatically resets state on error, so in_progress() + * returns false immediately after execute() fails. + */ + ET_EXPERIMENTAL ET_NODISCARD bool in_progress() const; + /** * EXPERIMENTAL: Resets execution state to the start of the Method. For use * with the `step()` API. * - * @retval Error:Ok on success + * @retval Error::Ok on success * @retval Error::InvalidState if called before step-based execution reached - * the end of the Method. This means it is not possible to recover a - * Method that failed mid-execution. + * the end of the Method. When using step(), you must step through to + * EndOfMethod before resetting. Note: execute() handles its own error + * recovery and does not require manual reset. */ ET_EXPERIMENTAL ET_NODISCARD Error reset_execution(); diff --git a/runtime/executor/test/method_test.cpp b/runtime/executor/test/method_test.cpp index c0a58793248..a60740f8066 100644 --- a/runtime/executor/test/method_test.cpp +++ b/runtime/executor/test/method_test.cpp @@ -430,6 +430,96 @@ TEST_F(MethodTest, MethodGetAttributeTest) { EXPECT_EQ(res->const_data_ptr()[0], 1); } +TEST_F(MethodTest, InProgressInitialState) { + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = programs_["add"]->load_method("forward", &mmm.get()); + ASSERT_EQ(method.error(), Error::Ok); + + EXPECT_FALSE(method->in_progress()); +} + +TEST_F(MethodTest, InProgressDuringStepExecution) { + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = + programs_["add_mul"]->load_method("forward", &mmm.get()); + ASSERT_EQ(method.error(), Error::Ok); + + auto input_cleanup = prepare_input_tensors(*method); + ASSERT_EQ(input_cleanup.error(), Error::Ok); + + Error err = method->step(); + ASSERT_EQ(err, Error::Ok); + + EXPECT_TRUE(method->in_progress()); + + while (err != Error::EndOfMethod) { + err = method->step(); + ASSERT_TRUE(err == Error::Ok || err == Error::EndOfMethod); + } + + EXPECT_FALSE(method->in_progress()); +} + +TEST_F(MethodTest, ExecuteFailsWhenInProgress) { + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = + programs_["add_mul"]->load_method("forward", &mmm.get()); + ASSERT_EQ(method.error(), Error::Ok); + + auto input_cleanup = prepare_input_tensors(*method); + ASSERT_EQ(input_cleanup.error(), Error::Ok); + + Error err = method->step(); + ASSERT_EQ(err, Error::Ok); + ASSERT_TRUE(method->in_progress()); + + err = method->execute(); + EXPECT_EQ(err, Error::InvalidState); +} + +TEST_F(MethodTest, ExecuteSucceedsAfterReset) { + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = programs_["add"]->load_method("forward", &mmm.get()); + ASSERT_EQ(method.error(), Error::Ok); + + auto input_cleanup = prepare_input_tensors(*method); + ASSERT_EQ(input_cleanup.error(), Error::Ok); + auto input_err = method->set_input(EValue(1.0), 2); + ASSERT_EQ(input_err, Error::Ok); + + Error err = Error::Ok; + while (err != Error::EndOfMethod) { + err = method->step(); + ASSERT_TRUE(err == Error::Ok || err == Error::EndOfMethod); + } + + err = method->reset_execution(); + ASSERT_EQ(err, Error::Ok); + EXPECT_FALSE(method->in_progress()); + + err = method->execute(); + EXPECT_EQ(err, Error::Ok); +} + +TEST_F(MethodTest, ExecuteResetsOnError) { + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = programs_["add"]->load_method("forward", &mmm.get()); + ASSERT_EQ(method.error(), Error::Ok); + + Error err = method->execute(); + EXPECT_NE(err, Error::Ok); + + EXPECT_FALSE(method->in_progress()); + + auto input_cleanup = prepare_input_tensors(*method); + ASSERT_EQ(input_cleanup.error(), Error::Ok); + auto input_err = method->set_input(EValue(1.0), 2); + ASSERT_EQ(input_err, Error::Ok); + + err = method->execute(); + EXPECT_EQ(err, Error::Ok); +} + /* * TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of * the portable op lib