diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index 9ae3661962e1..72268cc5fe40 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -2703,18 +2703,19 @@ StopReason Task::leaveSuspended(ThreadState& state) { ++numThreads_; } }); - if (state.isTerminated) { - return StopReason::kAlreadyTerminated; - } - if (terminateRequested_) { - state.isTerminated = true; - return StopReason::kTerminate; - } if (state.numSuspensions > 1 || !pauseRequested_) { + if (state.isTerminated) { + return StopReason::kAlreadyTerminated; + } + if (terminateRequested_) { + state.isTerminated = true; + return StopReason::kTerminate; + } // If we have more than one suspension requests on this driver thread or // the task has been resumed, then we return here. return StopReason::kNone; } + VELOX_CHECK_GT(state.numSuspensions, 0); VELOX_CHECK_GE(numThreads_, 0); leaveGuard.dismiss(); diff --git a/velox/exec/Task.h b/velox/exec/Task.h index 4e24c0814463..015600b99008 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -1120,7 +1120,7 @@ class Task : public std::enable_shared_from_this { // queued split groups. std::queue queuedSplitGroups_; - TaskState state_ = TaskState::kRunning; + TaskState state_{TaskState::kRunning}; // Stores splits state structure for each plan node. At construction populated // with all leaf plan nodes that require splits. Afterwards accessed with diff --git a/velox/exec/tests/DriverTest.cpp b/velox/exec/tests/DriverTest.cpp index f37940fdfb3a..2ec28809c072 100644 --- a/velox/exec/tests/DriverTest.cpp +++ b/velox/exec/tests/DriverTest.cpp @@ -1416,6 +1416,49 @@ DEBUG_ONLY_TEST_F(DriverTest, driverSuspensionCalledFromOffThread) { VELOX_ASSERT_THROW(driver->task()->leaveSuspended(driver->state()), ""); } +// This test case verifies that the driver thread leaves suspended state after +// task termiates and before resuming. +DEBUG_ONLY_TEST_F(DriverTest, driverSuspendedAfterTaskTerminateBeforeResume) { + std::shared_ptr driver; + std::atomic_bool triggerSuspended{false}; + std::atomic_bool taskPaused{false}; + // std::atomic_bool driverExecutionWaitFlag{true}; + folly::EventCount taskPausedWait; + std::atomic_bool driverLeaveSuspended{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Values::getOutput", + std::function([&](const exec::Values* values) { + if (triggerSuspended.exchange(true)) { + return; + } + driver = values->testingOperatorCtx()->driver()->shared_from_this(); + driver->task()->enterSuspended(driver->state()); + driver->task()->requestPause().wait(); + taskPaused = true; + taskPausedWait.notifyAll(); + const StopReason ret = driver->task()->leaveSuspended(driver->state()); + ASSERT_EQ(ret, StopReason::kAlreadyTerminated); + driverLeaveSuspended = true; + })); + + auto task = createAndStartTaskToReadValues(1); + + taskPausedWait.await([&]() { return taskPaused.load(); }); + task->requestCancel().wait(); + // Wait for 1 second and check the driver is still under suspended state + // without resuming. + std::this_thread::sleep_for(std::chrono::milliseconds(1'000)); + ASSERT_FALSE(driverLeaveSuspended); + + Task::resume(task); + std::this_thread::sleep_for(std::chrono::milliseconds(1'000)); + // Check the driver leaves the suspended state after task is resumed. Wait for + // 1 second to avoid timing flakiness. + ASSERT_TRUE(driverLeaveSuspended); + + ASSERT_TRUE(waitForTaskCancelled(task.get(), 100'000'000)); +} + DEBUG_ONLY_TEST_F(DriverTest, driverThreadContext) { ASSERT_TRUE(driverThreadContext() == nullptr); std::thread nonDriverThread(