diff --git a/spring-ai-rag/pom.xml b/spring-ai-rag/pom.xml index 01c8233ad31..b4a83dc8e1e 100644 --- a/spring-ai-rag/pom.xml +++ b/spring-ai-rag/pom.xml @@ -66,6 +66,11 @@ jackson-module-kotlin test + + + org.springframework.boot + spring-boot-starter-webflux + diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java new file mode 100644 index 00000000000..6291db91935 --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java @@ -0,0 +1,38 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +/** + * Represents the API key holder for Cohere API authentication. + * + * @author KoreaNirsa + */ +public class CohereApi { + + private String apiKey; + + public static Builder builder() { + return new Builder(); + } + + public String getApiKey() { + return apiKey; + } + + public static class Builder { + + private final CohereApi instance = new CohereApi(); + + public Builder apiKey(String key) { + instance.apiKey = key; + return this; + } + + public CohereApi build() { + if (instance.apiKey == null || instance.apiKey.isBlank()) { + throw new IllegalArgumentException("API key must be provided."); + } + return instance; + } + + } + +} diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java new file mode 100644 index 00000000000..352cc7ac2fe --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java @@ -0,0 +1,116 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.http.HttpHeaders; + +/** + * A Reranker implementation that integrates with Cohere's Rerank API. This component + * reorders retrieved documents based on semantic relevance to the input query. + * + * @author KoreaNirsa + * @see Cohere Rerank API + * Documentation + */ +public class CohereReranker { + + private static final String COHERE_RERANK_ENDPOINT = "https://api.cohere.ai/v1/rerank"; + + private static final Logger logger = LoggerFactory.getLogger(CohereReranker.class); + + private static final int MAX_DOCUMENTS = 1000; + + private final WebClient webClient; + + /** + * Constructs a CohereReranker that communicates with the Cohere Rerank API. + * Initializes the internal WebClient with the provided API key for authorization. + * @param cohereApi the API configuration object containing the required API key (must + * not be null) + * @throws IllegalArgumentException if cohereApi is null + */ + CohereReranker(CohereApi cohereApi) { + if (cohereApi == null) { + throw new IllegalArgumentException("CohereApi must not be null"); + } + + this.webClient = WebClient.builder() + .baseUrl(COHERE_RERANK_ENDPOINT) + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) + .build(); + } + + /** + * Reranks a list of documents based on the provided query using the Cohere API. + * @param query The user input query. + * @param documents The list of documents to rerank. + * @param topN The number of top results to return (at most). + * @return A reranked list of documents. If the API fails, returns the original list. + */ + public List rerank(String query, List documents, int topN) { + if (topN < 1) { + throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); + } + + if (documents == null || documents.isEmpty()) { + logger.warn("Empty document list provided. Skipping rerank."); + return Collections.emptyList(); + } + + if (documents.size() > MAX_DOCUMENTS) { + logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", + MAX_DOCUMENTS); + return documents; + } + + int adjustedTopN = Math.min(topN, documents.size()); + + Map payload = Map.of("query", query, "documents", + documents.stream().map(Document::getText).toList(), "top_n", adjustedTopN); + + // Call the API and process the result + return sendRerankRequest(payload).map(results -> results.stream() + .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) + .map(r -> { + Document original = documents.get(r.getIndex()); + Map metadata = new HashMap<>(original.getMetadata()); + metadata.put("score", String.format("%.4f", r.getRelevanceScore())); + return new Document(original.getText(), metadata); + }) + .toList()).orElseGet(() -> { + logger.warn("Cohere response is null or invalid"); + return documents; + }); + } + + /** + * Sends a rerank request to the Cohere API and returns the result list. + * @param payload The request body including query, documents, and top_n. + * @return An Optional list of reranked results, or empty if failed. + */ + private Optional> sendRerankRequest(Map payload) { + try { + RerankResponse response = webClient.post() + .bodyValue(payload) + .retrieve() + .bodyToMono(RerankResponse.class) + .block(); + + return Optional.ofNullable(response).map(RerankResponse::getResults); + } + catch (Exception e) { + logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); + return Optional.empty(); + } + } + +} diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java new file mode 100644 index 00000000000..6849a61be18 --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java @@ -0,0 +1,57 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * Rerank configuration that conditionally registers a DocumentPostProcessor when rerank + * is enabled via application properties. + * + * This configuration is activated only when the following properties are set + * + *
    + *
  • spring.ai.rerank.enabled=true
  • + *
  • spring.ai.rerank.cohere.api-key=your-api-key
  • + *
+ * + * @author KoreaNirsa + */ +@Configuration +public class RerankConfig { + + @Value("${spring.ai.rerank.cohere.api-key}") + private String apiKey; + + /** + * Registers a DocumentPostProcessor bean that enables reranking using Cohere. + * + * This bean is only created when the property `spring.ai.rerank.enabled=true` is set. + * The API key is injected from application properties or environment variables. + * @return An instance of RerankerPostProcessor backed by Cohere API + */ + @Bean + @ConditionalOnProperty(name = "spring.ai.rerank.enabled", havingValue = "true") + public DocumentPostProcessor rerankerPostProcessor() { + return new RerankerPostProcessor(CohereApi.builder().apiKey(apiKey).build()); + } + + /** + * Provides a fallback DocumentPostProcessor when reranking is disabled or no custom + * implementation is registered. + * + * This implementation performs no reranking and simply returns the original list of + * documents. If additional post-processing is required, a custom bean should be + * defined. + * @return A pass-through DocumentPostProcessor that returns input as-is + */ + @Bean + @ConditionalOnMissingBean + public DocumentPostProcessor noOpPostProcessor() { + return (query, documents) -> documents; + } + +} \ No newline at end of file diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java new file mode 100644 index 00000000000..bdf1861ceea --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java @@ -0,0 +1,54 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Represents the response returned from Cohere's Rerank API. The response includes a list + * of result objects that specify document indices and their semantic relevance scores. + * + * @author KoreaNirsa + */ +public class RerankResponse { + + private List results; + + public List getResults() { + return results; + } + + public void setResults(List results) { + this.results = results; + } + + /** + * Represents a single reranked document result returned by the Cohere API. Contains + * the original index and the computed relevance score. + */ + public static class Result { + + private int index; + + @JsonProperty("relevance_score") + private double relevanceScore; + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public double getRelevanceScore() { + return relevanceScore; + } + + public void setRelevanceScore(double relevanceScore) { + this.relevanceScore = relevanceScore; + } + + } + +} \ No newline at end of file diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java new file mode 100644 index 00000000000..0998b46a1f6 --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java @@ -0,0 +1,50 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import java.util.List; + +import org.springframework.ai.document.Document; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor; + +/** + * The only supported entrypoint for rerank functionality in Spring AI RAG. This component + * delegates reranking logic to CohereReranker, using the provided API key. + * + * This class is registered as a DocumentPostProcessor bean only if + * spring.ai.rerank.enabled=true is set in the application properties. + * + * @author KoreaNirsa + */ +public class RerankerPostProcessor implements DocumentPostProcessor { + + private final CohereReranker reranker; + + RerankerPostProcessor(CohereApi cohereApi) { + this.reranker = new CohereReranker(cohereApi); + } + + /** + * Processes the retrieved documents by applying semantic reranking using the Cohere + * API + * @param query the user's input query + * @param documents the list of documents to be reranked + * @return a list of documents sorted by relevance score + */ + @Override + public List process(Query query, List documents) { + int topN = extractTopN(query); + return reranker.rerank(query.text(), documents, topN); + } + + /** + * Extracts the top-N value from the query context. If not present or invalid, it + * defaults to 3 + * @param query the query containing optional context parameters + * @return the number of top documents to return + */ + private int extractTopN(Query query) { + Object value = query.context().get("topN"); + return (value instanceof Number num) ? num.intValue() : 3; + } + +}