Skip to content

Commit fcd0d8f

Browse files
committed
refactor: correct retry
1 parent fa101e7 commit fcd0d8f

File tree

3 files changed

+148
-107
lines changed

3 files changed

+148
-107
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java

+106-103
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import java.net.http.HttpClient;
88
import java.net.http.HttpRequest;
99
import java.net.http.HttpResponse;
10+
import java.time.Duration;
1011
import java.util.concurrent.Flow;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Function;
1314
import java.util.regex.Pattern;
1415

1516
import reactor.core.publisher.Mono;
17+
import reactor.util.retry.Retry;
1618

1719
/**
1820
* A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive
@@ -65,6 +67,12 @@ public class FlowSseClient {
6567
*/
6668
private final AtomicReference<Flow.Subscription> currentSubscription = new AtomicReference<>();
6769

70+
/**
71+
* Atomic reference to hold the last event ID received from the SSE stream. This can
72+
* be used to resume the stream from the last known event.
73+
*/
74+
private final AtomicReference<String> lastEventId = new AtomicReference<>();
75+
6876
/**
6977
* Record class representing a Server-Sent Event with its standard fields.
7078
*
@@ -145,125 +153,120 @@ public void subscribe(String url, SseEventHandler eventHandler) {
145153
* @throws RuntimeException if the connection fails with a non-200 status code
146154
*/
147155
public Mono<Void> subscribeAsync(String url, SseEventHandler eventHandler) {
148-
HttpRequest request = this.requestBuilder.uri(URI.create(url))
149-
.header("Accept", "text/event-stream")
150-
.header("Cache-Control", "no-cache")
151-
.GET()
152-
.build();
153-
154-
SseSubscriber lineSubscriber = new SseSubscriber(eventHandler);
155-
Function<Flow.Subscriber<String>, HttpResponse.BodySubscriber<Void>> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber;
156+
final Function<Flow.Subscriber<String>, HttpResponse.BodySubscriber<Void>> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber;
157+
final StringBuilder eventBuilder = new StringBuilder();
158+
final AtomicReference<String> currentEventId = new AtomicReference<>();
159+
final AtomicReference<String> currentEventType = new AtomicReference<>("message");
160+
final Flow.Subscriber<String> lineSubscriber = new Flow.Subscriber<>() {
161+
private Flow.Subscription subscription;
162+
163+
@Override
164+
public void onSubscribe(Flow.Subscription subscription) {
165+
this.subscription = subscription;
166+
currentSubscription.set(subscription);
167+
subscription.request(Long.MAX_VALUE);
168+
}
156169

157-
return Mono
158-
.fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber)))
159-
.doOnTerminate(lineSubscriber::cancelSubscription)
160-
.doOnError(eventHandler::onError)
161-
.doOnSuccess(response -> {
162-
int status = response.statusCode();
163-
if (status != 200 && status != 201 && status != 202 && status != 206) {
164-
throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status);
170+
@Override
171+
public void onNext(String line) {
172+
if (line.isEmpty()) {
173+
// Empty line means end of event
174+
if (eventBuilder.length() > 0) {
175+
String eventData = eventBuilder.toString();
176+
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
177+
lastEventId.set(currentEventId.get());
178+
eventHandler.onEvent(event);
179+
eventBuilder.setLength(0);
180+
}
165181
}
166-
})
167-
.then()
168-
.doOnSubscribe(subscription -> currentSubscription.set(lineSubscriber.getSubscription()));
169-
}
170-
171-
/**
172-
* Gracefully close the SSE stream subscription if active.
173-
*/
174-
public void close() {
175-
Flow.Subscription subscription = currentSubscription.get();
176-
if (subscription != null) {
177-
subscription.cancel();
178-
currentSubscription.set(null);
179-
}
180-
}
181-
182-
/**
183-
* Inner class that implements Flow.Subscriber to handle incoming SSE events.
184-
* It processes the event stream, parsing the data and notifying the event handler.
185-
*/
186-
private static class SseSubscriber implements Flow.Subscriber<String> {
187-
188-
private final SseEventHandler eventHandler;
189-
190-
private final StringBuilder eventBuilder = new StringBuilder();
191-
192-
private final AtomicReference<String> currentEventId = new AtomicReference<>();
193-
194-
private final AtomicReference<String> currentEventType = new AtomicReference<>("message");
195-
196-
private Flow.Subscription subscription;
197-
198-
public SseSubscriber(SseEventHandler eventHandler) {
199-
this.eventHandler = eventHandler;
200-
}
182+
else {
183+
if (line.startsWith("data:")) {
184+
var matcher = EVENT_DATA_PATTERN.matcher(line);
185+
if (matcher.find()) {
186+
eventBuilder.append(matcher.group(1).trim()).append("\n");
187+
}
188+
}
189+
else if (line.startsWith("id:")) {
190+
var matcher = EVENT_ID_PATTERN.matcher(line);
191+
if (matcher.find()) {
192+
currentEventId.set(matcher.group(1).trim());
193+
}
194+
}
195+
else if (line.startsWith("event:")) {
196+
var matcher = EVENT_TYPE_PATTERN.matcher(line);
197+
if (matcher.find()) {
198+
currentEventType.set(matcher.group(1).trim());
199+
}
200+
}
201+
}
202+
subscription.request(1);
203+
}
201204

202-
@Override
203-
public void onSubscribe(Flow.Subscription subscription) {
204-
this.subscription = subscription;
205-
subscription.request(Long.MAX_VALUE);
206-
}
205+
@Override
206+
public void onError(Throwable throwable) {
207+
eventHandler.onError(throwable);
208+
}
207209

208-
@Override
209-
public void onNext(String line) {
210-
if (line.isEmpty()) {
211-
// Empty line means end of event
212-
if (eventBuilder.isEmpty()) {
210+
@Override
211+
public void onComplete() {
212+
// Handle any remaining event data
213+
if (eventBuilder.length() > 0) {
213214
String eventData = eventBuilder.toString();
214215
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
215216
eventHandler.onEvent(event);
216-
eventBuilder.setLength(0);
217-
}
218-
}
219-
else {
220-
if (line.startsWith("data:")) {
221-
var matcher = EVENT_DATA_PATTERN.matcher(line);
222-
if (matcher.find()) {
223-
eventBuilder.append(matcher.group(1).trim()).append("\n");
224-
}
225-
}
226-
else if (line.startsWith("id:")) {
227-
var matcher = EVENT_ID_PATTERN.matcher(line);
228-
if (matcher.find()) {
229-
currentEventId.set(matcher.group(1).trim());
230-
}
231-
}
232-
else if (line.startsWith("event:")) {
233-
var matcher = EVENT_TYPE_PATTERN.matcher(line);
234-
if (matcher.find()) {
235-
currentEventType.set(matcher.group(1).trim());
236-
}
237217
}
238218
}
239-
subscription.request(1);
240-
}
219+
};
241220

242-
@Override
243-
public void onError(Throwable throwable) {
244-
eventHandler.onError(throwable);
245-
}
221+
return Mono.defer(() -> {
222+
HttpRequest.Builder builder = this.requestBuilder.uri(URI.create(url))
223+
.header("Accept", "text/event-stream")
224+
.header("Cache-Control", "no-cache")
225+
.GET();
246226

247-
@Override
248-
public void onComplete() {
249-
// Handle any remaining event data
250-
if (eventBuilder.isEmpty()) {
251-
String eventData = eventBuilder.toString();
252-
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
253-
eventHandler.onEvent(event);
227+
String lastId = lastEventId.get();
228+
if (lastId != null) {
229+
builder.header("Last-Event-ID", lastId);
254230
}
255-
}
256231

257-
public Flow.Subscription getSubscription() {
258-
return this.subscription;
259-
}
232+
HttpRequest request = builder.build();
260233

261-
public void cancelSubscription() {
262-
if (subscription != null) {
263-
subscription.cancel();
234+
return Mono
235+
.fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber)))
236+
.flatMap(response -> {
237+
int status = response.statusCode();
238+
if (status >= 400 && status < 500 && status != 429 && status != 408) {
239+
return Mono.error(new SseConnectionException("Client error." + status, status));
240+
}
241+
if (status != 200 && status != 201 && status != 202 && status != 206) {
242+
return Mono.error(new SseConnectionException("Failed to connect to SSE stream.", status));
243+
}
244+
return Mono.empty();
245+
})
246+
.doOnError(eventHandler::onError)
247+
.doFinally(sig -> {
248+
Flow.Subscription active = currentSubscription.getAndSet(null);
249+
if (active != null)
250+
active.cancel();
251+
})
252+
.then();
253+
}).retryWhen(Retry.backoff(3, Duration.ofSeconds(2)).filter(err -> {
254+
if (err instanceof SseConnectionException exception) {
255+
return exception.isRetryable();
264256
}
265-
}
257+
return true; // Retry on other exceptions
258+
}).onRetryExhaustedThrow((spec, signal) -> signal.failure()));
259+
260+
}
266261

262+
/**
263+
* Gracefully close the SSE stream subscription if active.
264+
*/
265+
public void close() {
266+
Flow.Subscription subscription = currentSubscription.getAndSet(null);
267+
if (subscription != null) {
268+
subscription.cancel();
269+
}
267270
}
268271

269272
}

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

+1-4
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,6 @@ public HttpClientSseClientTransport build() {
332332
public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
333333
state.set(TransportState.CONNECTING);
334334
return Mono.<Void>create(sink -> subscribeSse(handler, sink))
335-
.timeout(Duration.ofSeconds(10))
336-
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(5)))
337335
.doOnError(err -> logger.error("Error during connection", err));
338336

339337
}
@@ -345,13 +343,12 @@ public void onEvent(SseEvent event) {
345343
if (state.get() == TransportState.CLOSING || state.get() == TransportState.DISCONNECTED) {
346344
return;
347345
}
348-
346+
sink.success();
349347
try {
350348
switch (event.type()) {
351349
case ENDPOINT_EVENT_TYPE -> {
352350
messageEndpoint.set(event.data());
353351
state.set(TransportState.CONNECTED);
354-
sink.success();
355352
}
356353
case MESSAGE_EVENT_TYPE -> {
357354
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.client.transport;
6+
7+
/**
8+
* Exception thrown when there is an issue with the SSE connection.
9+
*/
10+
public class SseConnectionException extends RuntimeException {
11+
12+
private final int statusCode;
13+
14+
/**
15+
* Constructor for SseConnectionException.
16+
* @param message the error message
17+
* @param statusCode the HTTP status code associated with the error
18+
*/
19+
public SseConnectionException(final String message, final int statusCode) {
20+
super(message + " (Status code: " + statusCode + ")");
21+
this.statusCode = statusCode;
22+
}
23+
24+
/**
25+
* Gets the HTTP status code associated with this exception.
26+
* @return the HTTP status code.
27+
*/
28+
public int getStatusCode() {
29+
return statusCode;
30+
}
31+
32+
/**
33+
* Checks if the status code indicates a retryable error.
34+
* @return true if the status code is 408, 429, or in the 500-599 range; false
35+
* otherwise.
36+
*/
37+
public boolean isRetryable() {
38+
return statusCode == 408 || statusCode == 429 || (statusCode >= 500 && statusCode < 600);
39+
}
40+
41+
}

0 commit comments

Comments
 (0)