Skip to content

Commit c6d71ce

Browse files
committed
refactor: migrates CompletableFuture to reactive patterns for HttpClientSseClientTransport
1 parent 84adde1 commit c6d71ce

File tree

2 files changed

+201
-131
lines changed

2 files changed

+201
-131
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java

Lines changed: 126 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
import java.net.http.HttpClient;
88
import java.net.http.HttpRequest;
99
import java.net.http.HttpResponse;
10-
import java.util.concurrent.CompletableFuture;
10+
import java.util.Map;
1111
import java.util.concurrent.Flow;
1212
import java.util.concurrent.atomic.AtomicReference;
13+
import java.util.function.BiConsumer;
1314
import java.util.function.Function;
1415
import java.util.regex.Pattern;
1516

17+
import reactor.core.publisher.Mono;
18+
1619
/**
1720
* A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive
1821
* stream processing. This client establishes a connection to an SSE endpoint and
@@ -59,14 +62,19 @@ public class FlowSseClient {
5962
*/
6063
private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE);
6164

65+
/**
66+
* Atomic reference to hold the current subscription for the SSE stream.
67+
*/
68+
private final AtomicReference<Flow.Subscription> currentSubscription = new AtomicReference<>();
69+
6270
/**
6371
* Record class representing a Server-Sent Event with its standard fields.
6472
*
6573
* @param id the event ID (may be null)
6674
* @param type the event type (defaults to "message" if not specified in the stream)
6775
* @param data the event payload data
6876
*/
69-
public static record SseEvent(String id, String type, String data) {
77+
public record SseEvent(String id, String type, String data) {
7078
}
7179

7280
/**
@@ -121,90 +129,143 @@ public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder)
121129
* @throws RuntimeException if the connection fails with a non-200 status code
122130
*/
123131
public void subscribe(String url, SseEventHandler eventHandler) {
132+
subscribeAsync(url, eventHandler).subscribe();
133+
}
134+
135+
/**
136+
* Subscribes to an SSE endpoint and processes the event stream.
137+
*
138+
* <p>
139+
* This method establishes a connection to the specified URL and begins processing the
140+
* SSE stream. Events are parsed and delivered to the provided event handler. The
141+
* connection remains active until either an error occurs or the server closes the
142+
* connection.
143+
* @param url the SSE endpoint URL to connect to
144+
* @param eventHandler the handler that will receive SSE events and error
145+
* notifications
146+
* @return a Mono representing the completion of the subscription
147+
* @throws RuntimeException if the connection fails with a non-200 status code
148+
*/
149+
public Mono<Void> subscribeAsync(String url, SseEventHandler eventHandler) {
124150
HttpRequest request = this.requestBuilder.uri(URI.create(url))
125151
.header("Accept", "text/event-stream")
126152
.header("Cache-Control", "no-cache")
127153
.GET()
128154
.build();
129155

130-
StringBuilder eventBuilder = new StringBuilder();
131-
AtomicReference<String> currentEventId = new AtomicReference<>();
132-
AtomicReference<String> currentEventType = new AtomicReference<>("message");
156+
SseSubscriber lineSubscriber = new SseSubscriber(eventHandler);
157+
Function<Flow.Subscriber<String>, HttpResponse.BodySubscriber<Void>> subscriberFactory = HttpResponse.BodySubscribers::fromLineSubscriber;
133158

134-
Flow.Subscriber<String> lineSubscriber = new Flow.Subscriber<>() {
135-
private Flow.Subscription subscription;
159+
return Mono
160+
.fromFuture(() -> this.httpClient.sendAsync(request, info -> subscriberFactory.apply(lineSubscriber)))
161+
.doOnTerminate(lineSubscriber::cancelSubscription)
162+
.doOnError(eventHandler::onError)
163+
.doOnSuccess(response -> {
164+
int status = response.statusCode();
165+
if (status != 200 && status != 201 && status != 202 && status != 206) {
166+
throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status);
167+
}
168+
})
169+
.then()
170+
.doOnSubscribe(subscription -> currentSubscription.set(lineSubscriber.getSubscription()));
171+
}
136172

137-
@Override
138-
public void onSubscribe(Flow.Subscription subscription) {
139-
this.subscription = subscription;
140-
subscription.request(Long.MAX_VALUE);
141-
}
173+
/**
174+
* Gracefully close the SSE stream subscription if active.
175+
*/
176+
public void close() {
177+
Flow.Subscription subscription = currentSubscription.get();
178+
if (subscription != null) {
179+
subscription.cancel();
180+
currentSubscription.set(null);
181+
}
182+
}
142183

143-
@Override
144-
public void onNext(String line) {
145-
if (line.isEmpty()) {
146-
// Empty line means end of event
147-
if (eventBuilder.length() > 0) {
148-
String eventData = eventBuilder.toString();
149-
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
150-
eventHandler.onEvent(event);
151-
eventBuilder.setLength(0);
152-
}
184+
/**
185+
* Inner class that implements Flow.Subscriber to handle incoming SSE events.
186+
* It processes the event stream, parsing the data and notifying the event handler.
187+
*/
188+
private static class SseSubscriber implements Flow.Subscriber<String> {
189+
190+
private final SseEventHandler eventHandler;
191+
192+
private final StringBuilder eventBuilder = new StringBuilder();
193+
194+
private final AtomicReference<String> currentEventId = new AtomicReference<>();
195+
196+
private final AtomicReference<String> currentEventType = new AtomicReference<>("message");
197+
198+
private Flow.Subscription subscription;
199+
200+
public SseSubscriber(SseEventHandler eventHandler) {
201+
this.eventHandler = eventHandler;
202+
}
203+
204+
@Override
205+
public void onSubscribe(Flow.Subscription subscription) {
206+
this.subscription = subscription;
207+
subscription.request(Long.MAX_VALUE);
208+
}
209+
210+
@Override
211+
public void onNext(String line) {
212+
if (line.isEmpty()) {
213+
// Empty line means end of event
214+
if (eventBuilder.isEmpty()) {
215+
String eventData = eventBuilder.toString();
216+
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
217+
eventHandler.onEvent(event);
218+
eventBuilder.setLength(0);
153219
}
154-
else {
155-
if (line.startsWith("data:")) {
156-
var matcher = EVENT_DATA_PATTERN.matcher(line);
157-
if (matcher.find()) {
158-
eventBuilder.append(matcher.group(1).trim()).append("\n");
159-
}
220+
}
221+
else {
222+
if (line.startsWith("data:")) {
223+
var matcher = EVENT_DATA_PATTERN.matcher(line);
224+
if (matcher.find()) {
225+
eventBuilder.append(matcher.group(1).trim()).append("\n");
160226
}
161-
else if (line.startsWith("id:")) {
162-
var matcher = EVENT_ID_PATTERN.matcher(line);
163-
if (matcher.find()) {
164-
currentEventId.set(matcher.group(1).trim());
165-
}
227+
}
228+
else if (line.startsWith("id:")) {
229+
var matcher = EVENT_ID_PATTERN.matcher(line);
230+
if (matcher.find()) {
231+
currentEventId.set(matcher.group(1).trim());
166232
}
167-
else if (line.startsWith("event:")) {
168-
var matcher = EVENT_TYPE_PATTERN.matcher(line);
169-
if (matcher.find()) {
170-
currentEventType.set(matcher.group(1).trim());
171-
}
233+
}
234+
else if (line.startsWith("event:")) {
235+
var matcher = EVENT_TYPE_PATTERN.matcher(line);
236+
if (matcher.find()) {
237+
currentEventType.set(matcher.group(1).trim());
172238
}
173239
}
174-
subscription.request(1);
175240
}
241+
subscription.request(1);
242+
}
176243

177-
@Override
178-
public void onError(Throwable throwable) {
179-
eventHandler.onError(throwable);
180-
}
244+
@Override
245+
public void onError(Throwable throwable) {
246+
eventHandler.onError(throwable);
247+
}
181248

182-
@Override
183-
public void onComplete() {
184-
// Handle any remaining event data
185-
if (eventBuilder.length() > 0) {
186-
String eventData = eventBuilder.toString();
187-
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
188-
eventHandler.onEvent(event);
189-
}
249+
@Override
250+
public void onComplete() {
251+
// Handle any remaining event data
252+
if (eventBuilder.isEmpty()) {
253+
String eventData = eventBuilder.toString();
254+
SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
255+
eventHandler.onEvent(event);
190256
}
191-
};
257+
}
192258

193-
Function<Flow.Subscriber<String>, HttpResponse.BodySubscriber<Void>> subscriberFactory = subscriber -> HttpResponse.BodySubscribers
194-
.fromLineSubscriber(subscriber);
259+
public Flow.Subscription getSubscription() {
260+
return this.subscription;
261+
}
195262

196-
CompletableFuture<HttpResponse<Void>> future = this.httpClient.sendAsync(request,
197-
info -> subscriberFactory.apply(lineSubscriber));
198-
199-
future.thenAccept(response -> {
200-
int status = response.statusCode();
201-
if (status != 200 && status != 201 && status != 202 && status != 206) {
202-
throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status);
263+
public void cancelSubscription() {
264+
if (subscription != null) {
265+
subscription.cancel();
203266
}
204-
}).exceptionally(throwable -> {
205-
eventHandler.onError(throwable);
206-
return null;
207-
});
267+
}
268+
208269
}
209270

210271
}

0 commit comments

Comments
 (0)