diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java index 50af35c70..b17dbe81c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java @@ -7,12 +7,15 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.util.concurrent.CompletableFuture; +import java.time.Duration; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.regex.Pattern; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + /** * A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive * stream processing. This client establishes a connection to an SSE endpoint and @@ -59,6 +62,17 @@ public class FlowSseClient { */ private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); + /** + * Atomic reference to hold the current subscription for the SSE stream. + */ + private final AtomicReference currentSubscription = new AtomicReference<>(); + + /** + * Atomic reference to hold the last event ID received from the SSE stream. This can + * be used to resume the stream from the last known event. + */ + private final AtomicReference lastEventId = new AtomicReference<>(); + /** * Record class representing a Server-Sent Event with its standard fields. * @@ -66,7 +80,7 @@ public class FlowSseClient { * @param type the event type (defaults to "message" if not specified in the stream) * @param data the event payload data */ - public static record SseEvent(String id, String type, String data) { + public record SseEvent(String id, String type, String data) { } /** @@ -121,22 +135,35 @@ public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) * @throws RuntimeException if the connection fails with a non-200 status code */ public void subscribe(String url, SseEventHandler eventHandler) { - HttpRequest request = this.requestBuilder.uri(URI.create(url)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - StringBuilder eventBuilder = new StringBuilder(); - AtomicReference currentEventId = new AtomicReference<>(); - AtomicReference currentEventType = new AtomicReference<>("message"); + subscribeAsync(url, eventHandler).subscribe(); + } - Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { + /** + * Subscribes to an SSE endpoint and processes the event stream. + * + *

+ * This method establishes a connection to the specified URL and begins processing the + * SSE stream. Events are parsed and delivered to the provided event handler. The + * connection remains active until either an error occurs or the server closes the + * connection. + * @param url the SSE endpoint URL to connect to + * @param eventHandler the handler that will receive SSE events and error + * notifications + * @return a Mono representing the completion of the subscription + * @throws RuntimeException if the connection fails with a non-200 status code + */ + public Mono subscribeAsync(String url, SseEventHandler eventHandler) { + final Function, HttpResponse.BodySubscriber> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber; + final StringBuilder eventBuilder = new StringBuilder(); + final AtomicReference currentEventId = new AtomicReference<>(); + final AtomicReference currentEventType = new AtomicReference<>("message"); + final Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { private Flow.Subscription subscription; @Override public void onSubscribe(Flow.Subscription subscription) { this.subscription = subscription; + currentSubscription.set(subscription); subscription.request(Long.MAX_VALUE); } @@ -147,6 +174,7 @@ public void onNext(String line) { if (eventBuilder.length() > 0) { String eventData = eventBuilder.toString(); SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + lastEventId.set(currentEventId.get()); eventHandler.onEvent(event); eventBuilder.setLength(0); } @@ -190,21 +218,55 @@ public void onComplete() { } }; - Function, HttpResponse.BodySubscriber> subscriberFactory = subscriber -> HttpResponse.BodySubscribers - .fromLineSubscriber(subscriber); + return Mono.defer(() -> { + HttpRequest.Builder builder = this.requestBuilder.uri(URI.create(url)) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .GET(); + + String lastId = lastEventId.get(); + if (lastId != null) { + builder.header("Last-Event-ID", lastId); + } - CompletableFuture> future = this.httpClient.sendAsync(request, - info -> subscriberFactory.apply(lineSubscriber)); + HttpRequest request = builder.build(); - future.thenAccept(response -> { - int status = response.statusCode(); - if (status != 200 && status != 201 && status != 202 && status != 206) { - throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); + return Mono + .fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber))) + .flatMap(response -> { + int status = response.statusCode(); + if (status >= 400 && status < 500 && status != 429 && status != 408) { + return Mono.error(new SseConnectionException("Client error." + status, status)); + } + if (status != 200 && status != 201 && status != 202 && status != 206) { + return Mono.error(new SseConnectionException("Failed to connect to SSE stream.", status)); + } + return Mono.empty(); + }) + .doOnError(eventHandler::onError) + .doFinally(sig -> { + Flow.Subscription active = currentSubscription.getAndSet(null); + if (active != null) + active.cancel(); + }) + .then(); + }).retryWhen(Retry.backoff(3, Duration.ofSeconds(2)).filter(err -> { + if (err instanceof SseConnectionException exception) { + return exception.isRetryable(); } - }).exceptionally(throwable -> { - eventHandler.onError(throwable); - return null; - }); + return true; // Retry on other exceptions + }).onRetryExhaustedThrow((spec, signal) -> signal.failure())); + + } + + /** + * Gracefully close the SSE stream subscription if active. + */ + public void close() { + Flow.Subscription subscription = currentSubscription.getAndSet(null); + if (subscription != null) { + subscription.cancel(); + } } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 99cf2a625..75a85b382 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -9,9 +9,6 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -28,6 +25,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.util.retry.Retry; /** * Server-Sent Events (SSE) implementation of the @@ -90,18 +89,12 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** JSON object mapper for message serialization/deserialization */ protected ObjectMapper objectMapper; - /** Flag indicating if the transport is in closing state */ - private volatile boolean isClosing = false; - - /** Latch for coordinating endpoint discovery */ - private final CountDownLatch closeLatch = new CountDownLatch(1); + /** Enum indicating the transport state */ + private final AtomicReference state = new AtomicReference<>(TransportState.DISCONNECTED); /** Holds the discovered message endpoint URL */ private final AtomicReference messageEndpoint = new AtomicReference<>(); - /** Holds the SSE connection future */ - private final AtomicReference> connectionFuture = new AtomicReference<>(); - /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server @@ -338,48 +331,48 @@ public HttpClientSseClientTransport build() { */ @Override public Mono connect(Function, Mono> handler) { - CompletableFuture future = new CompletableFuture<>(); - connectionFuture.set(future); + state.set(TransportState.CONNECTING); + return Mono.create(sink -> subscribeSse(handler, sink)) + .doOnError(err -> logger.error("Error during connection", err)); - URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + } + + private void subscribeSse(final Function, Mono> handler, MonoSink sink) { + final URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { - if (isClosing) { + if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) { return; } - + sink.success(); try { - if (ENDPOINT_EVENT_TYPE.equals(event.type())) { - String endpoint = event.data(); - messageEndpoint.set(endpoint); - closeLatch.countDown(); - future.complete(null); - } - else if (MESSAGE_EVENT_TYPE.equals(event.type())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); - handler.apply(Mono.just(message)).subscribe(); - } - else { - logger.error("Received unrecognized SSE event type: {}", event.type()); + switch (event.type()) { + case ENDPOINT_EVENT_TYPE -> { + messageEndpoint.set(event.data()); + state.set(TransportState.CONNECTED); + } + case MESSAGE_EVENT_TYPE -> { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); + handler.apply(Mono.just(message)).subscribe(); + } + default -> logger.error("Received unrecognized SSE event type: {}", event.type()); } } - catch (IOException e) { + catch (Exception e) { logger.error("Error processing SSE event", e); - future.completeExceptionally(e); + sink.error(new McpError("Error processing SSE event")); } } @Override public void onError(Throwable error) { - if (!isClosing) { + if (state.get() != TransportState.CLOSING) { logger.error("SSE connection error", error); - future.completeExceptionally(error); + sink.error(error); } } }); - - return Mono.fromFuture(future); } /** @@ -394,44 +387,44 @@ public void onError(Throwable error) { */ @Override public Mono sendMessage(JSONRPCMessage message) { - if (isClosing) { + if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) { return Mono.empty(); } - - try { - if (!closeLatch.await(10, TimeUnit.SECONDS)) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); + return Mono.defer(() -> { + if (messageEndpoint.get() == null) { + return Mono.error(new McpError("No message endpoint available")); } - } - catch (InterruptedException e) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - String endpoint = messageEndpoint.get(); - if (endpoint == null) { - return Mono.error(new McpError("No message endpoint available")); - } + return serializeMessage(message).flatMap(body -> sendHttpPost(messageEndpoint.get(), body)) + .doOnNext(this::logIfNotOk) + .doOnError(err -> logger.error("Error sending message", err)) + .then(); + + }).retryWhen(Retry.fixedDelay(3, Duration.ofSeconds(3)).filter(err -> messageEndpoint.get() == null)); + } + private Mono serializeMessage(final JSONRPCMessage message) { try { - String jsonText = this.objectMapper.writeValueAsString(message); - URI requestUri = Utils.resolveUri(baseUri, endpoint); - HttpRequest request = this.requestBuilder.uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); - - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); + return Mono.just(objectMapper.writeValueAsString(message)); } catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); + return Mono.error(new McpError("Failed to serialize message")); + } + } + + private Mono> sendHttpPost(final String endpoint, final String body) { + final URI requestUri = Utils.resolveUri(baseUri, endpoint); + final HttpRequest request = requestBuilder.uri(requestUri) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); + } + + private void logIfNotOk(final HttpResponse response) { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); } } @@ -445,12 +438,10 @@ public Mono sendMessage(JSONRPCMessage message) { */ @Override public Mono closeGracefully() { + state.set(TransportState.CLOSING); return Mono.fromRunnable(() -> { - isClosing = true; - CompletableFuture future = connectionFuture.get(); - if (future != null && !future.isDone()) { - future.cancel(true); - } + sseClient.close(); + state.set(TransportState.DISCONNECTED); }); } @@ -466,4 +457,19 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { return this.objectMapper.convertValue(data, typeRef); } + /** + * Get the current transport state. + * @return the current transport state + */ + public TransportState getState() { + return state.get(); + } + + // Enum to manage transport states + public enum TransportState { + + DISCONNECTED, CONNECTING, CONNECTED, CLOSING + + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SseConnectionException.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SseConnectionException.java new file mode 100644 index 000000000..58fb4f216 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SseConnectionException.java @@ -0,0 +1,41 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +/** + * Exception thrown when there is an issue with the SSE connection. + */ +public class SseConnectionException extends RuntimeException { + + private final int statusCode; + + /** + * Constructor for SseConnectionException. + * @param message the error message + * @param statusCode the HTTP status code associated with the error + */ + public SseConnectionException(final String message, final int statusCode) { + super(message + " (Status code: " + statusCode + ")"); + this.statusCode = statusCode; + } + + /** + * Gets the HTTP status code associated with this exception. + * @return the HTTP status code. + */ + public int getStatusCode() { + return statusCode; + } + + /** + * Checks if the status code indicates a retryable error. + * @return true if the status code is 408, 429, or in the 500-599 range; false + * otherwise. + */ + public boolean isRetryable() { + return statusCode == 408 || statusCode == 429 || (statusCode >= 500 && statusCode < 600); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 762264de3..962073463 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,10 +7,8 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; -import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -22,8 +20,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; + import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -34,9 +31,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; @@ -370,25 +364,4 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } - @Test - @SuppressWarnings("unchecked") - void testResolvingClientEndpoint() { - HttpClient httpClient = Mockito.mock(HttpClient.class); - HttpResponse httpResponse = Mockito.mock(HttpResponse.class); - CompletableFuture> future = new CompletableFuture<>(); - future.complete(httpResponse); - when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); - - HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), - "http://example.com", "http://example.com/sse", new ObjectMapper()); - - transport.connect(Function.identity()); - - ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); - assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); - - transport.closeGracefully().block(); - } - }