Skip to content

feat(ws): adds ws transport client #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*/

package io.modelcontextprotocol.client.transport;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.util.retry.Retry;

/**
* The WebSocket (WS) implementation of the
* {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with WS
* transport specification, using Java's HttpClient.
*
* @author Aliaksei Darafeyeu
*/
public class WebSocketClientTransport implements McpClientTransport {

private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketClientTransport.class);

private final HttpClient httpClient;

private final ObjectMapper objectMapper;

private final URI uri;

private final AtomicReference<WebSocket> webSocketRef = new AtomicReference<>();

private final AtomicReference<TransportState> state = new AtomicReference<>(TransportState.DISCONNECTED);

private final Sinks.Many<Throwable> errorSink = Sinks.many().multicast().onBackpressureBuffer();

/**
* The constructor for the WebSocketClientTransport.
* @param uri the URI to connect to
* @param clientBuilder the HttpClient builder
* @param objectMapper the ObjectMapper for JSON serialization/deserialization
*/
WebSocketClientTransport(final URI uri, final HttpClient.Builder clientBuilder, final ObjectMapper objectMapper) {
this.uri = uri;
this.httpClient = clientBuilder.build();
this.objectMapper = objectMapper;
}

/**
* Creates a new WebSocketClientTransport instance with the specified URI.
* @param uri the URI to connect to
* @return a new Builder instance
*/
public static Builder builder(final URI uri) {
return new Builder().uri(uri);
}

/**
* The state of the Transport connection.
*/
public enum TransportState {

DISCONNECTED, CONNECTING, CONNECTED, CLOSED

}

/**
* A builder for creating instances of WebSocketClientTransport.
*/
public static class Builder {

private URI uri;

private final HttpClient.Builder clientBuilder = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10));

private ObjectMapper objectMapper = new ObjectMapper();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of object mappers defined in lots of files. It would seem like it might be time to build a Jackson module and lock in on one that can be used across the code base.


public Builder uri(final URI uri) {
this.uri = uri;
return this;
}

public Builder customizeClient(final Consumer<HttpClient.Builder> clientCustomizer) {
Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
clientCustomizer.accept(clientBuilder);
return this;
}

public Builder objectMapper(final ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
return this;
}

public WebSocketClientTransport build() {
return new WebSocketClientTransport(uri, clientBuilder, objectMapper);
}

}

public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
if (!state.compareAndSet(TransportState.DISCONNECTED, TransportState.CONNECTING)) {
return Mono.error(new IllegalStateException("WebSocket is already connecting or connected"));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to point the actual state in the exception. I always hate error messages like "its either this or that", instead tell me what it was.

}

return Mono.fromFuture(httpClient.newWebSocketBuilder().buildAsync(uri, new WebSocket.Listener() {
private final StringBuilder messageBuffer = new StringBuilder();

@Override
public void onOpen(WebSocket webSocket) {
webSocketRef.set(webSocket);
state.set(TransportState.CONNECTED);
}

@Override
public CompletionStage<?> onText(WebSocket webSocket, CharSequence data, boolean last) {
messageBuffer.append(data);
if (last) {
final String fullMessage = messageBuffer.toString();
messageBuffer.setLength(0);
try {
final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
fullMessage);
handler.apply(Mono.just(msg)).subscribe();
}
catch (Exception e) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed this in #156. It would be nice to merge that and then we can stop catch Exception use a more fine grained type.

errorSink.tryEmitNext(e);
LOGGER.error("Error processing WS event", e);
}
}

webSocket.request(1);
return CompletableFuture.completedFuture(null);
}

@Override
public void onError(WebSocket webSocket, Throwable error) {
errorSink.tryEmitNext(error);
state.set(TransportState.CLOSED);
LOGGER.error("WS connection error", error);
}

@Override
public CompletionStage<?> onClose(WebSocket webSocket, int statusCode, String reason) {
state.set(TransportState.CLOSED);
return CompletableFuture.completedFuture(null);
}

})).then();
}

@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {

return Mono.defer(() -> {
WebSocket ws = webSocketRef.get();
if (ws == null && state.get() == TransportState.CONNECTING) {
return Mono.error(new IllegalStateException("WebSocket is connecting."));
}

if (ws == null || state.get() == TransportState.DISCONNECTED || state.get() == TransportState.CLOSED) {
return Mono.error(new IllegalStateException("WebSocket is closed."));
}

try {
String json = objectMapper.writeValueAsString(message);
return Mono.fromFuture(ws.sendText(json, true)).then();
}
catch (Exception e) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UncheckedIOException seems nice here.

return Mono.error(e);
}
}).retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Externalize or at least put in constants.

if (err instanceof IllegalStateException) {
return err.getMessage().equals("WebSocket is connecting.");
}
return true;
})).onErrorResume(e -> {
LOGGER.error("Failed to send message after retries", e);
errorSink.tryEmitNext(e);
return Mono.error(new IllegalStateException("WebSocket send failed after retries", e));
});

}

@Override
public Mono<Void> closeGracefully() {
WebSocket webSocket = webSocketRef.getAndSet(null);
if (webSocket != null && state.get() == TransportState.CONNECTED) {
state.set(TransportState.CLOSED);
return Mono.fromFuture(webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Closing")).then();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A constant here would be better possibly.

}
return Mono.empty();
}

@Override
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return objectMapper.convertValue(data, typeRef);
}

public TransportState getState() {
return state.get();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright 2024-2024 the original author or authors.
*/

package io.modelcontextprotocol.client.transport;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.net.URI;
import java.util.List;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.images.builder.ImageFromDockerfile;

import io.modelcontextprotocol.spec.McpSchema;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

/**
* Tests for the {@link WebSocketClientTransport} class.
*
* @author Aliaksei Darafeyeu
*/
class WebSocketClientTransportTest {

private static GenericContainer<?> wsContainer;

private static URI websocketUri;

private WebSocketClientTransport transport;

@BeforeAll
static void startContainer() {
wsContainer = new GenericContainer<>(
new ImageFromDockerfile().withFileFromClasspath("server.js", "ws/server.js")
.withFileFromClasspath("Dockerfile", "ws/Dockerfile"))
.withExposedPorts(8080);

wsContainer.start();

int port = wsContainer.getMappedPort(8080);
websocketUri = URI.create("ws://localhost:" + port);
}

@BeforeEach
public void setUp() {
transport = WebSocketClientTransport.builder(websocketUri).build();
}

@AfterAll
static void tearDown() {
wsContainer.stop();
}

@Test
void testConnectSuccessfully() {
// Try to connect to the WebSocket server
Mono<Void> connection = transport.connect(message -> Mono.empty());

// Wait for the connection to complete
StepVerifier.create(connection).expectComplete().verify();

// Ensure that connection is established
assertEquals(WebSocketClientTransport.TransportState.CONNECTED, transport.getState());
}

@Test
void testSendMessage() {
// Connect to the server
Mono<Void> connection = transport.connect(message -> Mono.empty());

// Ensure connection is successful
StepVerifier.create(connection).expectComplete().verify();

// Create a simple message to send
var messageRequest = new McpSchema.CreateMessageRequest(
List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))),
null, null, null, null, 0, null, null);
McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION,
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest);

// Send a message to the server
Mono<Void> sendMessage = transport.sendMessage(message);

// Ensure message is sent successfully
StepVerifier.create(sendMessage).expectComplete().verify();
}

@Test
void testCloseConnectionGracefully() {
Mono<Void> connection = transport.connect(message -> Mono.empty());

StepVerifier.create(connection).expectComplete().verify();

// Close the connection gracefully
Mono<Void> closeConnection = transport.closeGracefully();

// Verify that the connection is closed successfully
StepVerifier.create(closeConnection).expectComplete().verify();

assertEquals(WebSocketClientTransport.TransportState.CLOSED, transport.getState());
}

@Test
void testSendMessageAfterConnectionClosed() {
// Send a message before connection is established
// Create a simple message to send
var messageRequest = new McpSchema.CreateMessageRequest(
List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))),
null, null, null, null, 0, null, null);
McpSchema.JSONRPCMessage message = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION,
McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest);

Mono<Void> sendMessageBeforeConnect = transport.sendMessage(message);

// Verify that the transport returns an error because the connection is closed
StepVerifier.create(sendMessageBeforeConnect).expectError(IllegalStateException.class).verify();
}

}
17 changes: 17 additions & 0 deletions mcp/src/test/resources/ws/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Use a Node.js base image
FROM node:14

# Set the working directory inside the container
WORKDIR /usr/src/app

# Copy the server.js file into the container
COPY server.js /usr/src/app/

# Install dependencies (e.g., the ws package)
RUN npm init -y && npm install ws

# Expose the port for WebSocket (e.g., 8080)
EXPOSE 8080

# Command to run the WebSocket server
CMD ["node", "server.js"]
21 changes: 21 additions & 0 deletions mcp/src/test/resources/ws/server.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Import the WebSocket package
const WebSocket = require('ws');

// Set up the WebSocket server to listen on port 8080
const wss = new WebSocket.Server({ port: 8080 });

// When a new WebSocket connection is established
wss.on('connection', function connection(ws) {
console.log('New client connected');

// When a message is received from the client
ws.on('message', function incoming(message) {
console.log('received: %s', message);
});

// Send a welcome message to the client
ws.send('Welcome to the WebSocket server!');
});

// Log the WebSocket server start
console.log('WebSocket server is listening on ws://localhost:8080');
Loading