Skip to content

Commit 8357159

Browse files
committed
On graceful shutdown don't block thread while waiting for work to finish
1 parent bce397f commit 8357159

File tree

6 files changed

+157
-57
lines changed

6 files changed

+157
-57
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
import java.util.Set;
6363
import java.util.concurrent.ConcurrentHashMap;
6464
import java.util.concurrent.ConcurrentLinkedDeque;
65-
import java.util.function.Consumer;
65+
import java.util.function.BiConsumer;
6666

6767
import static org.elasticsearch.core.Strings.format;
6868
import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ASSIGNMENT_TASK_ACTION;
@@ -274,20 +274,27 @@ public void gracefullyStopDeploymentAndNotify(
274274
public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
275275
logger.debug(() -> format("[%s] Forcefully stopping deployment due to reason %s", task.getDeploymentId(), reason));
276276

277-
stopAndNotifyHelper(task, reason, listener, deploymentManager::stopDeployment);
277+
stopAndNotifyHelper(task, reason, listener, (t, l) -> {
278+
deploymentManager.stopDeployment(t);
279+
l.onResponse(AcknowledgedResponse.TRUE);
280+
});
278281
}
279282

280283
private void stopAndNotifyHelper(
281284
TrainedModelDeploymentTask task,
282285
String reason,
283286
ActionListener<AcknowledgedResponse> listener,
284-
Consumer<TrainedModelDeploymentTask> stopDeploymentFunc
287+
BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc
285288
) {
286289
// Removing the entry from the map to avoid the possibility of a node shutdown triggering a concurrent graceful stopping of the
287290
// process while we are attempting to forcefully stop the native process
288291
// The graceful stopping will only occur if there is an entry in the map
289292
deploymentIdToTask.remove(task.getDeploymentId());
290-
ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(task.getDeploymentId(), reason, listener);
293+
ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
294+
task.getDeploymentId(),
295+
reason,
296+
listener
297+
);
291298

292299
updateStoredState(
293300
task.getDeploymentId(),
@@ -541,7 +548,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
541548
)
542549
);
543550

544-
ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
551+
ActionListener<AcknowledgedResponse> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
545552
task.getDeploymentId(),
546553
NODE_IS_SHUTTING_DOWN,
547554
routingStateListener
@@ -550,7 +557,7 @@ private void gracefullyStopDeployment(String deploymentId, String currentNode) {
550557
stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped);
551558
}
552559

553-
private ActionListener<Void> updateRoutingStateToStoppedListener(
560+
private ActionListener<AcknowledgedResponse> updateRoutingStateToStoppedListener(
554561
String deploymentId,
555562
String reason,
556563
ActionListener<AcknowledgedResponse> listener
@@ -587,34 +594,40 @@ private void stopUnreferencedDeployment(String deploymentId, String currentNode)
587594
);
588595
}
589596

590-
private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<Void> listener) {
591-
stopDeploymentHelper(task, reason, deploymentManager::stopDeployment, listener);
597+
private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
598+
stopDeploymentHelper(task, reason, (t, l) -> {
599+
deploymentManager.stopDeployment(t);
600+
l.onResponse(AcknowledgedResponse.TRUE);
601+
}, listener);
592602
}
593603

594604
private void stopDeploymentHelper(
595605
TrainedModelDeploymentTask task,
596606
String reason,
597-
Consumer<TrainedModelDeploymentTask> stopDeploymentFunc,
598-
ActionListener<Void> listener
607+
BiConsumer<TrainedModelDeploymentTask, ActionListener<AcknowledgedResponse>> stopDeploymentFunc,
608+
ActionListener<AcknowledgedResponse> listener
599609
) {
600610
if (stopped) {
611+
listener.onResponse(AcknowledgedResponse.FALSE);
601612
return;
602613
}
603614
task.markAsStopped(reason);
604615

605616
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
606617
try {
607-
stopDeploymentFunc.accept(task);
608618
taskManager.unregister(task);
609619
deploymentIdToTask.remove(task.getDeploymentId());
610-
listener.onResponse(null);
620+
stopDeploymentFunc.accept(task, listener);
611621
} catch (Exception e) {
612622
listener.onFailure(e);
613623
}
614624
});
615625
}
616626

617-
private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener<Void> listener) {
627+
private void stopDeploymentAfterCompletingPendingWorkAsync(
628+
TrainedModelDeploymentTask task,
629+
ActionListener<AcknowledgedResponse> listener
630+
) {
618631
stopDeploymentHelper(task, NODE_IS_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener);
619632
}
620633

@@ -758,6 +771,7 @@ private void handleLoadSuccess(ActionListener<Boolean> retryListener, TrainedMod
758771

759772
private void updateStoredState(String deploymentId, RoutingInfoUpdate update, ActionListener<AcknowledgedResponse> listener) {
760773
if (stopped) {
774+
listener.onResponse(AcknowledgedResponse.FALSE);
761775
return;
762776
}
763777
trainedModelAssignmentService.updateModelAssignmentState(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -331,15 +331,15 @@ public void stopDeployment(TrainedModelDeploymentTask task) {
331331
}
332332
}
333333

334-
public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) {
334+
public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task, ActionListener<AcknowledgedResponse> listener) {
335335
ProcessContext processContext = processContextByAllocation.remove(task.getId());
336336
if (processContext != null) {
337337
logger.info(
338338
"[{}] Stopping deployment after completing pending tasks, reason [{}]",
339339
task.getDeploymentId(),
340340
task.stoppedReason().orElse("unknown")
341341
);
342-
processContext.stopProcessAfterCompletingPendingWork();
342+
processContext.stopProcessAfterCompletingPendingWork(listener);
343343
} else {
344344
logger.warn("[{}] No process context to stop gracefully", task.getDeploymentId());
345345
}
@@ -569,7 +569,7 @@ private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount,
569569

570570
processContextByAllocation.remove(task.getId());
571571
isStopped = true;
572-
resultProcessor.stop();
572+
resultProcessor.signalIntentToStop();
573573
stateStreamer.cancel();
574574

575575
if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
@@ -648,7 +648,7 @@ synchronized void forcefullyStopProcess() {
648648

649649
private void prepareInternalStateForShutdown() {
650650
isStopped = true;
651-
resultProcessor.stop();
651+
resultProcessor.signalIntentToStop();
652652
stateStreamer.cancel();
653653
}
654654

@@ -669,43 +669,33 @@ private void closeNlpTaskProcessor() {
669669
}
670670
}
671671

672-
private synchronized void stopProcessAfterCompletingPendingWork() {
672+
private synchronized void stopProcessAfterCompletingPendingWork(ActionListener<AcknowledgedResponse> listener) {
673673
logger.debug(() -> format("[%s] Stopping process after completing its pending work", task.getDeploymentId()));
674674
prepareInternalStateForShutdown();
675-
signalAndWaitForWorkerTermination();
676-
stopProcessGracefully();
677-
closeNlpTaskProcessor();
678-
}
679-
680-
private void signalAndWaitForWorkerTermination() {
681-
try {
682-
awaitTerminationAfterCompletingWork();
683-
} catch (TimeoutException e) {
684-
logger.warn(format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()), e);
685-
// The process failed to stop in the time period allotted, so we'll mark it for shut down
686-
priorityProcessWorker.shutdown();
687-
priorityProcessWorker.notifyQueueRunnables();
688-
}
689-
}
690-
691-
private void awaitTerminationAfterCompletingWork() throws TimeoutException {
692-
try {
693-
priorityProcessWorker.shutdown();
694-
695-
if (priorityProcessWorker.awaitTermination(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES) == false) {
696-
throw new TimeoutException(
697-
Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME)
698-
);
699-
} else {
700-
priorityProcessWorker.notifyQueueRunnables();
701-
}
702-
} catch (InterruptedException e) {
703-
Thread.currentThread().interrupt();
704-
logger.info(Strings.format("[%s] Interrupted waiting for process worker to complete", PROCESS_NAME));
705-
}
706-
}
707675

708-
private void stopProcessGracefully() {
676+
// Waiting for the process worker to finish the pending work could
677+
// take a long time. Best not to block the thread so register
678+
// a function with the process worker that is called when the
679+
// work is finished. Then proceed to closing the native process
680+
// and wait for all results to be processed, the second part can be
681+
// done synchronously as it is not expected to take long.
682+
// The ShutdownTracker will handle this.
683+
684+
// Shutdown tracker will stop the process work and start a race with
685+
// a timeout condition.
686+
new ShutdownTracker(() -> {
687+
// Stopping the process worker timed out, kill the process
688+
logger.warn(format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()));
689+
forcefullyStopProcess();
690+
}, () -> {
691+
// process worker stopped within allotted time, close process
692+
closeProcessAndWaitForResultProcessor();
693+
closeNlpTaskProcessor();
694+
}, threadPool, priorityProcessWorker, listener);
695+
696+
}
697+
698+
private void closeProcessAndWaitForResultProcessor() {
709699
try {
710700
closeProcessIfPresent();
711701
resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES);
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.inference.deployment;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.master.AcknowledgedResponse;
12+
import org.elasticsearch.core.TimeValue;
13+
import org.elasticsearch.threadpool.Scheduler;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.ml.MachineLearning;
16+
import org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService;
17+
18+
import java.util.concurrent.atomic.AtomicBoolean;
19+
20+
public class ShutdownTracker {
21+
22+
private final ActionListener<AcknowledgedResponse> everythingStoppedListener;
23+
24+
private final Scheduler.Cancellable timeoutHandler;
25+
private final Runnable onWorkerQueueCompletedCallback;
26+
private final Runnable onTimeoutCallback;
27+
private final Object monitor = new Object();
28+
private final AtomicBoolean timedOutOrCompleted = new AtomicBoolean();
29+
30+
private static final TimeValue COMPLETION_TIMEOUT = TimeValue.timeValueMinutes(5);
31+
32+
public ShutdownTracker(
33+
Runnable onTimeoutCallback,
34+
Runnable onWorkerQueueCompletedCallback,
35+
ThreadPool threadPool,
36+
PriorityProcessWorkerExecutorService workerQueue,
37+
ActionListener<AcknowledgedResponse> everythingStoppedListener
38+
) {
39+
this.onTimeoutCallback = onTimeoutCallback;
40+
this.onWorkerQueueCompletedCallback = onWorkerQueueCompletedCallback;
41+
this.everythingStoppedListener = ActionListener.notifyOnce(everythingStoppedListener);
42+
43+
// initiate the worker shutdown and add this as a callback when completed
44+
workerQueue.shutdownWithCallback(this::workerQueueCompleted);
45+
// start the race with the timeout and the worker completing
46+
this.timeoutHandler = threadPool.schedule(
47+
this::onTimeout,
48+
COMPLETION_TIMEOUT,
49+
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
50+
);
51+
}
52+
53+
private void onTimeout() {
54+
synchronized (monitor) { // TODO remove the lock as the atomic should be sufficient
55+
if (timedOutOrCompleted.compareAndSet(false, true) == false) {
56+
// already completed
57+
return;
58+
}
59+
onTimeoutCallback.run();
60+
everythingStoppedListener.onResponse(AcknowledgedResponse.FALSE);
61+
}
62+
}
63+
64+
private void workerQueueCompleted() {
65+
synchronized (monitor) { // TODO remove the lock as the atomic should be sufficient
66+
if (timedOutOrCompleted.compareAndSet(false, true) == false) {
67+
// already completed
68+
return;
69+
}
70+
timeoutHandler.cancel();
71+
onWorkerQueueCompletedCallback.run();
72+
everythingStoppedListener.onResponse(AcknowledgedResponse.TRUE);
73+
}
74+
}
75+
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ public synchronized void updateStats(PyTorchResult result) {
313313
}
314314
}
315315

316-
public void stop() {
316+
public void signalIntentToStop() {
317317
isStopping = true;
318318
}
319319

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public abstract class AbstractProcessWorkerExecutorService<T extends Runnable> e
4444
private final AtomicReference<Exception> error = new AtomicReference<>();
4545
private final AtomicBoolean running = new AtomicBoolean(true);
4646
private final AtomicBoolean shouldShutdownAfterCompletingWork = new AtomicBoolean(false);
47+
private final AtomicReference<Runnable> onCompletion = new AtomicReference<>();
4748

4849
/**
4950
* @param contextHolder the thread context holder
@@ -78,6 +79,11 @@ public void shutdown() {
7879
shouldShutdownAfterCompletingWork.set(true);
7980
}
8081

82+
public void shutdownWithCallback(Runnable onCompletion) {
83+
this.onCompletion.set(onCompletion);
84+
shutdown();
85+
}
86+
8187
/**
8288
* Some of the tasks in the returned list of {@link Runnable}s could have run. Some tasks may have run while the queue was being copied.
8389
*
@@ -124,6 +130,10 @@ public void start() {
124130
} catch (InterruptedException e) {
125131
Thread.currentThread().interrupt();
126132
} finally {
133+
Runnable onComplete = onCompletion.get();
134+
if (onComplete != null) {
135+
onComplete.run();
136+
}
127137
awaitTermination.countDown();
128138
}
129139
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception {
265265
UpdateTrainedModelAssignmentRoutingInfoAction.Request.class
266266
);
267267
verify(deploymentManager, times(1)).startDeployment(startTaskCapture.capture(), any());
268-
assertBusy(() -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any()));
268+
assertBusy(
269+
() -> verify(trainedModelAssignmentService, times(3)).updateModelAssignmentState(requestCapture.capture(), any()),
270+
3,
271+
TimeUnit.SECONDS
272+
);
269273

270274
boolean seenStopping = false;
271275
for (int i = 0; i < 3; i++) {
@@ -398,6 +402,13 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop
398402
return null;
399403
}).when(trainedModelAssignmentService).updateModelAssignmentState(any(), any());
400404

405+
doAnswer(invocationOnMock -> {
406+
@SuppressWarnings({ "unchecked", "rawtypes" })
407+
ActionListener<AcknowledgedResponse> listener = (ActionListener) invocationOnMock.getArguments()[1];
408+
listener.onResponse(AcknowledgedResponse.TRUE);
409+
return null;
410+
}).when(deploymentManager).stopAfterCompletingPendingWork(any(), any());
411+
401412
var taskParams = newParams(deploymentOne, modelOne);
402413

403414
ClusterChangedEvent event = new ClusterChangedEvent(
@@ -430,7 +441,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop
430441
fail("Failed waiting for the stop process call to complete");
431442
}
432443

433-
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
444+
verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture(), any());
434445
assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
435446
assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
436447
verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState(
@@ -480,7 +491,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA
480491
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
481492
trainedModelAssignmentNodeService.clusterChanged(event);
482493

483-
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
494+
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
484495
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
485496
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
486497
any()
@@ -521,7 +532,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready
521532

522533
trainedModelAssignmentNodeService.clusterChanged(event);
523534

524-
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
535+
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
525536
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
526537
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
527538
any()
@@ -563,7 +574,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti
563574
trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
564575
trainedModelAssignmentNodeService.clusterChanged(event);
565576

566-
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
577+
verify(deploymentManager, never()).stopAfterCompletingPendingWork(any(), any());
567578
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
568579
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
569580
any()

0 commit comments

Comments
 (0)