diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
index 1407a2b533..84b827ccc6 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java
@@ -47,9 +47,13 @@
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.script.ScriptService;
+import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.transport.client.Client;
 
+import lombok.Builder;
+
 public interface RemoteConnectorExecutor {
+    public String RETRY_EXECUTOR = "opensearch_ml_predict_remote";
 
     default void executeAction(String action, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
@@ -253,38 +257,23 @@ default void invokeRemoteServiceWithRetry(
         ExecutionContext executionContext,
         ActionListener<Tuple<Integer, ModelTensors>> actionListener
     ) {
-        final RetryableAction<Tuple<Integer, ModelTensors>> invokeRemoteModelAction = new RetryableAction<>(
+        final RetryableAction<Tuple<Integer, ModelTensors>> invokeRemoteModelAction = new RetryableActionExtension(
             getLogger(),
             getClient().threadPool(),
             TimeValue.timeValueMillis(getConnectorClientConfig().getRetryBackoffMillis()),
             TimeValue.timeValueSeconds(getConnectorClientConfig().getRetryTimeoutSeconds()),
             actionListener,
             getRetryBackoffPolicy(getConnectorClientConfig()),
-            RETRY_EXECUTOR
-        ) {
-            int retryTimes = 0;
-
-            @Override
-            public void tryAction(ActionListener<Tuple<Integer, ModelTensors>> listener) {
-                // the listener here is RetryingListener
-                // If the request success, or can not retry, will call delegate listener
-                invokeRemoteService(action, mlInput, parameters, payload, executionContext, listener);
-            }
-
-            @Override
-            public boolean shouldRetry(Exception e) {
-                Throwable cause = ExceptionsHelper.unwrapCause(e);
-                Integer maxRetryTimes = getConnectorClientConfig().getMaxRetryTimes();
-                boolean shouldRetry = cause instanceof RemoteConnectorThrottlingException;
-                if (++retryTimes > maxRetryTimes && maxRetryTimes != -1) {
-                    shouldRetry = false;
-                }
-                if (shouldRetry) {
-                    getLogger().debug(String.format(Locale.ROOT, "The %d-th retry for invoke remote model", retryTimes), e);
-                }
-                return shouldRetry;
-            }
-        };
+            RetryableActionExtensionArgs
+                .builder()
+                .connectionExecutor(this)
+                .mlInput(mlInput)
+                .action(action)
+                .parameters(parameters)
+                .executionContext(executionContext)
+                .payload(payload)
+                .build()
+        );
         invokeRemoteModelAction.run();
     };
@@ -296,4 +285,56 @@ void invokeRemoteService(
         ExecutionContext executionContext,
         ActionListener<Tuple<Integer, ModelTensors>> actionListener
     );
+
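+    /**
+     * Named extraction of the previously anonymous {@link RetryableAction} subclass,
+     * so the retry decision logic can be unit tested directly.
+     */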
+    static class RetryableActionExtension extends RetryableAction<Tuple<Integer, ModelTensors>> {
+        private final RetryableActionExtensionArgs args;
+        int retryTimes = 0;
+
+        RetryableActionExtension(
+            Logger logger,
+            ThreadPool threadPool,
+            TimeValue initialDelay,
+            TimeValue timeoutValue,
+            ActionListener<Tuple<Integer, ModelTensors>> listener,
+            BackoffPolicy backoffPolicy,
+            RetryableActionExtensionArgs args
+        ) {
+            super(logger, threadPool, initialDelay, timeoutValue, listener, backoffPolicy, RETRY_EXECUTOR);
+            this.args = args;
+        }
+
+        @Override
+        public void tryAction(ActionListener<Tuple<Integer, ModelTensors>> listener) {
+            // The listener here is the RetryingListener; if the request succeeds,
+            // or can no longer be retried, it calls through to the delegate listener.
+            args.connectionExecutor
+                .invokeRemoteService(args.action, args.mlInput, args.parameters, args.payload, args.executionContext, listener);
+        }
+
+        @Override
+        public boolean shouldRetry(Exception e) {
+            Throwable cause = ExceptionsHelper.unwrapCause(e);
+            Integer maxRetryTimes = args.connectionExecutor.getConnectorClientConfig().getMaxRetryTimes();
+            boolean shouldRetry = cause instanceof RemoteConnectorThrottlingException;
+            if (++retryTimes > maxRetryTimes && maxRetryTimes != -1) {
+                shouldRetry = false;
+            }
+            if (shouldRetry) {
+                args.connectionExecutor
+                    .getLogger()
+                    .debug(String.format(Locale.ROOT, "The %d-th retry for invoke remote model", retryTimes), e);
+            }
+            return shouldRetry;
+        }
+    }
+
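+    /** Argument holder, built via Lombok's {@code @Builder}, for the values each retry attempt passes to invokeRemoteService. */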
+    @Builder
+    class RetryableActionExtensionArgs {
+        private final RemoteConnectorExecutor connectionExecutor;
+        private final MLInput mlInput;
+        private final String action;
+        private final Map<String, String> parameters;
+        private final ExecutionContext executionContext;
+        private final String payload;
+    }
 }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor_RetryableActionExtensionTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor_RetryableActionExtensionTest.java
new file mode 100644
index 0000000000..af303ba293
--- /dev/null
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor_RetryableActionExtensionTest.java
@@ -0,0 +1,114 @@
+package org.opensearch.ml.engine.algorithms.remote;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.util.Map;
+import java.util.function.Supplier;
+
+import org.apache.logging.log4j.Logger;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.opensearch.action.bulk.BackoffPolicy;
+import org.opensearch.common.collect.Tuple;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.ml.common.connector.ConnectorClientConfig;
+import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor.RetryableActionExtension;
+import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor.RetryableActionExtensionArgs;
+import org.opensearch.threadpool.ThreadPool;
+
+@RunWith(MockitoJUnitRunner.class)
+public class RemoteConnectorExecutor_RetryableActionExtensionTest {
+
+    private static final int TEST_ATTEMPT_LIMIT = 5;
+
+    @Mock
+    Logger logger;
+    @Mock
+    ThreadPool threadPool;
+    @Mock
+    TimeValue initialDelay;
+    @Mock
+    TimeValue timeoutValue;
+    @Mock
+    ActionListener<Tuple<Integer, ModelTensors>> listener;
+    @Mock
+    BackoffPolicy backoffPolicy;
+    @Mock
+    ConnectorClientConfig connectorClientConfig;
+    @Mock
+    RemoteConnectorExecutor connectionExecutor;
+
+    RetryableActionExtension retryableAction;
+
+    @Before
+    public void setup() {
+        when(connectionExecutor.getConnectorClientConfig()).thenReturn(connectorClientConfig);
+        when(connectionExecutor.getLogger()).thenReturn(logger);
+        var args = RetryableActionExtensionArgs.builder()
+            .action("action")
+            .connectionExecutor(connectionExecutor)
+            .mlInput(mock(MLInput.class))
+            .parameters(Map.of())
+            .executionContext(mock(ExecutionContext.class))
+            .payload("payload")
+            .build();
+        var settings = Settings.builder().put("node.name", "test").build();
+        retryableAction = new RetryableActionExtension(logger, new ThreadPool(settings), TimeValue.timeValueMillis(5), TimeValue.timeValueMillis(500), listener, backoffPolicy, args);
+    }
+
+    @Test
+    public void test_ShouldRetry_hitLimitOnRetries() {
+        var attempts = retryAttempts(-1, this::createThrottleException);
+
+        assertThat(attempts, equalTo(TEST_ATTEMPT_LIMIT));
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void test_ShouldRetry_OnlyOnThrottleExceptions() {
+        var exceptions = mock(Supplier.class);
+        when(exceptions.get())
+            .thenReturn(createThrottleException())
+            .thenReturn(createThrottleException())
+            .thenReturn(new RuntimeException()); // Stop retrying on 3rd exception
+        var attempts = retryAttempts(-1, exceptions);
+
+        assertThat(attempts, equalTo(2));
+        verify(exceptions, times(3)).get();
+    }
+
+    @Test
+    public void test_ShouldRetry_stopAtMaxAttempts() {
+        int maxAttempts = 3;
+        var attempts = retryAttempts(maxAttempts, this::createThrottleException);
+
+        assertThat(attempts, equalTo(maxAttempts));
+    }
+
+    private int retryAttempts(int maxAttempts, Supplier<Exception> exception) {
+        when(connectorClientConfig.getMaxRetryTimes()).thenReturn(maxAttempts);
+        int attempt = 0;
+        boolean shouldRetry;
+        do {
+            shouldRetry = retryableAction.shouldRetry(exception.get());
+        } while (shouldRetry && ++attempt < TEST_ATTEMPT_LIMIT);
+        return attempt;
+    }
+
+    private RemoteConnectorThrottlingException createThrottleException() {
+        return new RemoteConnectorThrottlingException("Throttle", RestStatus.TOO_MANY_REQUESTS);
+    }
+}