Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,30 @@ jobs:

- name: Run tests
working-directory: typescript-sdk
run: pnpm run test
run: pnpm run test

java:
name: Java SDK Tests
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up JDK
uses: actions/setup-java@v4
with:
java-version: '18'
distribution: 'temurin'

- name: Cache Maven dependencies
uses: actions/cache@v4
with:
path: ~/.m2
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
restore-keys: |
${{ runner.os }}-m2-

- name: Run tests
working-directory: java-sdk
run: mvn test
91 changes: 91 additions & 0 deletions java-sdk/integrations/spring-ai/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.ag-ui</groupId>
<artifactId>ag-ui</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>

<artifactId>spring-ai</artifactId>

<properties>
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<spring.boot.version>3.2.0</spring.boot.version>
<spring-ai.version>1.0.0</spring-ai.version>
</properties>

<dependencyManagement>
<dependencies>
<!-- Import Spring Boot BOM for dependency management -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-dependencies</artifactId>
<version>${spring.boot.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>${spring-ai.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-model</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-ollama</artifactId>
<version>1.0.0</version>
</dependency>


<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.ag-ui</groupId>
<artifactId>client</artifactId>
<version>0.0.1-SNAPSHOT</version>
<scope>compile</scope>
</dependency>
</dependencies>


<build>
<plugins>
<!-- Spring Boot Maven Plugin -->
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<version>${spring.boot.version}</version>
<executions>
<execution>
<goals>
<goal>repackage</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.agui;

import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import org.springframework.web.filter.CorsFilter;

import java.util.Arrays;

@Configuration
public class CorsConfig {

@Bean
public CorsConfigurationSource corsConfigurationSource() {
CorsConfiguration configuration = new CorsConfiguration();
configuration.setAllowedOriginPatterns(Arrays.asList("*")); // Or specify domains
configuration.setAllowedMethods(Arrays.asList("GET", "POST", "PUT", "DELETE", "OPTIONS"));
configuration.setAllowedHeaders(Arrays.asList("*"));

UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
source.registerCorsConfiguration("/**", configuration);
return source;
}

@Bean
public FilterRegistrationBean<CorsFilter> corsFilter() {
FilterRegistrationBean<CorsFilter> bean = new FilterRegistrationBean<>(
new CorsFilter(corsConfigurationSource())
);
bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
return bean;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.agui;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class MainApplication {

public static void main(String[] args) {
SpringApplication.run(MainApplication.class, args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package com.agui.spring;

import com.agui.client.RunAgentParameters;
import com.agui.client.subscriber.AgentSubscriber;
import com.agui.client.subscriber.AgentSubscriberParams;
import com.agui.event.BaseEvent;
import com.agui.types.State;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.http.CacheControl;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;

@RestController
public class AgUiController {

@PostMapping(value = "/sse/{agentId}")
public ResponseEntity<SseEmitter> streamData(@PathVariable("agentId") final String agentId, @RequestBody() final AgUiParameters agUiParameters) {
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);

var chatModel = OllamaChatModel.builder()
.defaultOptions(OllamaOptions.builder().model("llama3.2").build())
.ollamaApi(OllamaApi.builder().baseUrl("http://localhost:11434").build())
.build();

SpringAgent agent = new SpringAgent(
agentId,
"description",
Objects.nonNull(agUiParameters.getThreadId()) ? agUiParameters.getThreadId() : UUID.randomUUID().toString(),
agUiParameters.getMessages().stream().map(m -> {
if (Objects.isNull(m.getName())) {
m.setName("");
}
return m;
}).toList(),
chatModel,
new State(),
true
);

var parameters = RunAgentParameters.builder()
.runId(UUID.randomUUID().toString())
.context(agUiParameters.getContext())
.forwardedProps(agUiParameters.getForwardedProps())
.tools(agUiParameters.getTools())
.build();

var objectMapper = new ObjectMapper();

agent.runAgent(parameters, new AgentSubscriber() {
@Override
public void onEvent(BaseEvent event) {
try {
emitter.send(SseEmitter.event().data(" " + objectMapper.writeValueAsString(event)).build());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public void onRunFinalized(AgentSubscriberParams params) {
emitter.complete();
}
@Override
public void onRunFailed(AgentSubscriberParams params, Throwable throwable) {
emitter.completeWithError(throwable);
}
});

return ResponseEntity
.ok()
.cacheControl(CacheControl.noCache())
.body(emitter);
}

@GetMapping(value = "/{agentId}", produces = MediaType.TEXT_PLAIN_VALUE)
public ResponseBodyEmitter streamData(
@PathVariable("agentId") final String agentId,
HttpServletResponse response
) {
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setContentType("text/plain;charset=UTF-8");

ResponseBodyEmitter emitter = new ResponseBodyEmitter();

// Process data in a separate thread
CompletableFuture.runAsync(() -> {
try {
for (int i = 0; i < 10; i++) {
emitter.send("Data chunk " + i + "\n");
Thread.sleep(1000); // Simulate processing delay
}
emitter.complete();
} catch (Exception e) {
emitter.completeWithError(e);
}
});

return emitter;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.agui.spring;

import com.agui.message.BaseMessage;
import com.agui.types.Context;
import com.agui.types.Tool;

import java.util.List;

public class AgUiParameters {

private String threadId;
private List<Tool> tools;
private List<Context> context;
private Object forwardedProps;
private List<BaseMessage> messages;

public void setThreadId(final String threadId) {
this.threadId = threadId;
}

public String getThreadId() {
return this.threadId;
}

public void setTools(final List<Tool> tools) {
this.tools = tools;
}

public List<Tool> getTools() {
return tools;
}

public void setContext(final List<Context> context) {
this.context = context;
}

public List<Context> getContext() {
return this.context;
}

public void setForwardedProps(final Object forwardedProps) {
this.forwardedProps = forwardedProps;
}

public Object getForwardedProps() {
return this.forwardedProps;
}

public void setMessages(final List<BaseMessage> messages) {
this.messages = messages;
}

public List<BaseMessage> getMessages() {
return this.messages;
}
}
Loading