From 52019706066693ff24ab14288a0a855a8a4c432c Mon Sep 17 00:00:00 2001 From: nirsa Date: Fri, 11 Jul 2025 15:41:36 +0900 Subject: [PATCH 1/4] Feat(rerank): Add Cohere Reranker support with topN result filtering - Implemented rerank logic using Cohere Rerank API with WebClient Signed-off-by: nirsa --- .../rag/postretrieval/rerank/CohereApi.java | 34 ++++++ .../postretrieval/rerank/CohereReranker.java | 113 ++++++++++++++++++ .../postretrieval/rerank/RerankConfig.java | 32 +++++ .../postretrieval/rerank/RerankResponse.java | 51 ++++++++ .../rerank/RerankerPostProcessor.java | 49 ++++++++ 5 files changed, 279 insertions(+) create mode 100644 spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java create mode 100644 spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java create mode 100644 spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java create mode 100644 spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java create mode 100644 spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java new file mode 100644 index 00000000000..3876b0d90ab --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java @@ -0,0 +1,34 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +/** + * Represents the API key holder for Cohere API authentication. + * + * @author KoreaNirsa + */ +public class CohereApi { + private String apiKey; + + public static Builder builder() { + return new Builder(); + } + + public String getApiKey() { + return apiKey; + } + + public static class Builder { + private final CohereApi instance = new CohereApi(); + + public Builder apiKey(String key) { + instance.apiKey = key; + return this; + } + + public CohereApi build() { + if (instance.apiKey == null || instance.apiKey.isBlank()) { + throw new IllegalArgumentException("API key must be provided."); + } + return instance; + } + } +} diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java new file mode 100644 index 00000000000..ec93660742e --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java @@ -0,0 +1,113 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.http.HttpHeaders; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * A Reranker implementation that integrates with Cohere's Rerank API. + * This component reorders retrieved documents based on semantic relevance to the input query. + * + * @author KoreaNirsa + * @see Cohere Rerank API Documentation + */ +public class CohereReranker { + private static final String COHERE_RERANK_ENDPOINT = "https://api.cohere.ai/v1/rerank"; + + private static final Logger logger = LoggerFactory.getLogger(CohereReranker.class); + + private static final int MAX_DOCUMENTS = 1000; + + private final WebClient webClient; + + /** + * Constructs a CohereReranker that communicates with the Cohere Rerank API. + * Initializes the internal WebClient with the provided API key for authorization. + * + * @param cohereApi the API configuration object containing the required API key (must not be null) + * @throws IllegalArgumentException if cohereApi is null + */ + CohereReranker(CohereApi cohereApi) { + if (cohereApi == null) { + throw new IllegalArgumentException("CohereApi must not be null"); + } + + this.webClient = WebClient.builder() + .baseUrl(COHERE_RERANK_ENDPOINT) + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) + .build(); + } + + /** + * Reranks a list of documents based on the provided query using the Cohere API. + * + * @param query The user input query. + * @param documents The list of documents to rerank. + * @param topN The number of top results to return (at most). + * @return A reranked list of documents. If the API fails, returns the original list. + */ + public List rerank(String query, List documents, int topN) { + if (topN < 1) { + throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); + } + + if (documents == null || documents.isEmpty()) { + logger.warn("Empty document list provided. Skipping rerank."); + return Collections.emptyList(); + } + + if (documents.size() > MAX_DOCUMENTS) { + logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", MAX_DOCUMENTS); + return documents; + } + + int adjustedTopN = Math.min(topN, documents.size()); + + Map payload = Map.of( + "query", query, + "documents", documents.stream().map(Document::getText).toList(), + "top_n", adjustedTopN + ); + + // Call the API and process the result + return sendRerankRequest(payload) + .map(results -> results.stream() + .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) + .map(r -> documents.get(r.getIndex())) + .toList()) + .orElseGet(() -> { + logger.warn("Cohere response is null or invalid"); + return documents; + }); + } + + /** + * Sends a rerank request to the Cohere API and returns the result list. + * + * @param payload The request body including query, documents, and top_n. + * @return An Optional list of reranked results, or empty if failed. + */ + private Optional> sendRerankRequest(Map payload) { + try { + RerankResponse response = webClient.post() + .bodyValue(payload) + .retrieve() + .bodyToMono(RerankResponse.class) + .block(); + + return Optional.ofNullable(response) + .map(RerankResponse::getResults); + } catch (Exception e) { + logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); + return Optional.empty(); + } + } +} diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java new file mode 100644 index 00000000000..1a780b3da3d --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java @@ -0,0 +1,32 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * Rerank configuration that conditionally registers a DocumentPostProcessor + * when rerank is enabled via application properties. + * + * This configuration is activated only when the following properties are set + * + *
    + *
  • spring.ai.rerank.enabled=true
  • + *
  • spring.ai.rerank.cohere.api-key=your-api-key
  • + *
+ * + * @author KoreaNirsa + */ +@Configuration +public class RerankConfig { + @Value("${spring.ai.rerank.cohere.api-key}") + private String apiKey; + + @Bean + @ConditionalOnProperty(name = "spring.ai.rerank.enabled", havingValue = "true") + public DocumentPostProcessor rerankerPostProcessor() { + return new RerankerPostProcessor(CohereApi.builder().apiKey(apiKey).build()); + } +} \ No newline at end of file diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java new file mode 100644 index 00000000000..d69c71a388a --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java @@ -0,0 +1,51 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Represents the response returned from Cohere's Rerank API. + * The response includes a list of result objects that specify document indices + * and their semantic relevance scores. + * + * @author KoreaNirsa + */ +public class RerankResponse { + private List results; + + public List getResults() { + return results; + } + + public void setResults(List results) { + this.results = results; + } + + /** + * Represents a single reranked document result returned by the Cohere API. + * Contains the original index and the computed relevance score. + */ + public static class Result { + private int index; + + @JsonProperty("relevance_score") + private int relevanceScore; + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public int getRelevanceScore() { + return relevanceScore; + } + + public void setRelevanceScore(int relevanceScore) { + this.relevanceScore = relevanceScore; + } + } +} \ No newline at end of file diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java new file mode 100644 index 00000000000..7d4d4215a43 --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java @@ -0,0 +1,49 @@ +package org.springframework.ai.rag.postretrieval.rerank; + +import java.util.List; + +import org.springframework.ai.document.Document; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor; + +/** + * The only supported entrypoint for rerank functionality in Spring AI RAG. + * This component delegates reranking logic to CohereReranker, using the provided API key. + * + * This class is registered as a DocumentPostProcessor bean only if + * spring.ai.rerank.enabled=true is set in the application properties. + * + * @author KoreaNirsa + */ +public class RerankerPostProcessor implements DocumentPostProcessor { + private final CohereReranker reranker; + + RerankerPostProcessor(CohereApi cohereApi) { + this.reranker = new CohereReranker(cohereApi); + } + + /** + * Processes the retrieved documents by applying semantic reranking using the Cohere API + * + * @param query the user's input query + * @param documents the list of documents to be reranked + * @return a list of documents sorted by relevance score + */ + @Override + public List process(Query query, List documents) { + int topN = extractTopN(query); + return reranker.rerank(query.text(), documents, topN); + } + + /** + * Extracts the top-N value from the query context. + * If not present or invalid, it defaults to 3 + * + * @param query the query containing optional context parameters + * @return the number of top documents to return + */ + private int extractTopN(Query query) { + Object value = query.context().get("topN"); + return (value instanceof Number num) ? num.intValue() : 3; + } +} From 3d44a5d329381f771b0dc767cbaa709b91f238cf Mon Sep 17 00:00:00 2001 From: nirsa Date: Fri, 11 Jul 2025 15:41:50 +0900 Subject: [PATCH 2/4] Feat(rerank): Add no-op DocumentPostProcessor as fallback - Register a default DocumentPostProcessor that returns documents as-is Signed-off-by: nirsa --- .../postretrieval/rerank/RerankConfig.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java index 1a780b3da3d..576b57fabc4 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java @@ -2,6 +2,7 @@ import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor; import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -24,9 +25,32 @@ public class RerankConfig { @Value("${spring.ai.rerank.cohere.api-key}") private String apiKey; + /** + * Registers a DocumentPostProcessor bean that enables reranking using Cohere. + * + * This bean is only created when the property `spring.ai.rerank.enabled=true` is set. + * The API key is injected from application properties or environment variables. + * + * @return An instance of RerankerPostProcessor backed by Cohere API + */ @Bean @ConditionalOnProperty(name = "spring.ai.rerank.enabled", havingValue = "true") public DocumentPostProcessor rerankerPostProcessor() { return new RerankerPostProcessor(CohereApi.builder().apiKey(apiKey).build()); } + + /** + * Provides a fallback DocumentPostProcessor when reranking is disabled + * or no custom implementation is registered. + * + * This implementation performs no reranking and simply returns the original list of documents. + * If additional post-processing is required, a custom bean should be defined. + * + * @return A pass-through DocumentPostProcessor that returns input as-is + */ + @Bean + @ConditionalOnMissingBean + public DocumentPostProcessor noOpPostProcessor() { + return (query, documents) -> documents; + } } \ No newline at end of file From c18820443af8e6f180b9212d5ff3ed801c787fa7 Mon Sep 17 00:00:00 2001 From: nirsa Date: Fri, 11 Jul 2025 15:41:57 +0900 Subject: [PATCH 3/4] Fix(rerank): Correct relevanceScore setter to use double to avoid format exception - Changed setRelevanceScore parameter from int to double - Prevented IllegalFormatConversionException when formatting with %.4f Signed-off-by: nirsa --- .../ai/rag/postretrieval/rerank/CohereReranker.java | 8 +++++++- .../ai/rag/postretrieval/rerank/RerankResponse.java | 6 +++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java index ec93660742e..aa7bf16c7ac 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java @@ -2,6 +2,7 @@ import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -81,7 +82,12 @@ public List rerank(String query, List documents, int topN) { return sendRerankRequest(payload) .map(results -> results.stream() .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) - .map(r -> documents.get(r.getIndex())) + .map(r -> { + Document original = documents.get(r.getIndex()); + Map metadata = new HashMap<>(original.getMetadata()); + metadata.put("score", String.format("%.4f", r.getRelevanceScore())); + return new Document(original.getText(), metadata); + }) .toList()) .orElseGet(() -> { logger.warn("Cohere response is null or invalid"); diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java index d69c71a388a..3749dc16287 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java @@ -30,7 +30,7 @@ public static class Result { private int index; @JsonProperty("relevance_score") - private int relevanceScore; + private double relevanceScore; public int getIndex() { return index; @@ -40,11 +40,11 @@ public void setIndex(int index) { this.index = index; } - public int getRelevanceScore() { + public double getRelevanceScore() { return relevanceScore; } - public void setRelevanceScore(int relevanceScore) { + public void setRelevanceScore(double relevanceScore) { this.relevanceScore = relevanceScore; } } From 1906c1f74a3d5225fe782df390217fc22eac483a Mon Sep 17 00:00:00 2001 From: nirsa Date: Fri, 11 Jul 2025 15:42:05 +0900 Subject: [PATCH 4/4] Chore: apply code formatting Signed-off-by: nirsa --- spring-ai-rag/pom.xml | 5 + .../rag/postretrieval/rerank/CohereApi.java | 4 + .../postretrieval/rerank/CohereReranker.java | 171 +++++++++--------- .../postretrieval/rerank/RerankConfig.java | 73 ++++---- .../postretrieval/rerank/RerankResponse.java | 81 +++++---- .../rerank/RerankerPostProcessor.java | 69 +++---- 6 files changed, 207 insertions(+), 196 deletions(-) diff --git a/spring-ai-rag/pom.xml b/spring-ai-rag/pom.xml index 01c8233ad31..b4a83dc8e1e 100644 --- a/spring-ai-rag/pom.xml +++ b/spring-ai-rag/pom.xml @@ -66,6 +66,11 @@ jackson-module-kotlin test + + + org.springframework.boot + spring-boot-starter-webflux + diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java index 3876b0d90ab..6291db91935 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereApi.java @@ -6,6 +6,7 @@ * @author KoreaNirsa */ public class CohereApi { + private String apiKey; public static Builder builder() { @@ -17,6 +18,7 @@ public String getApiKey() { } public static class Builder { + private final CohereApi instance = new CohereApi(); public Builder apiKey(String key) { @@ -30,5 +32,7 @@ public CohereApi build() { } return instance; } + } + } diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java index aa7bf16c7ac..352cc7ac2fe 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/CohereReranker.java @@ -10,21 +10,23 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; -import org.springframework.http.HttpHeaders; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.http.HttpHeaders; /** - * A Reranker implementation that integrates with Cohere's Rerank API. - * This component reorders retrieved documents based on semantic relevance to the input query. + * A Reranker implementation that integrates with Cohere's Rerank API. This component + * reorders retrieved documents based on semantic relevance to the input query. * * @author KoreaNirsa - * @see Cohere Rerank API Documentation + * @see Cohere Rerank API + * Documentation */ public class CohereReranker { + private static final String COHERE_RERANK_ENDPOINT = "https://api.cohere.ai/v1/rerank"; private static final Logger logger = LoggerFactory.getLogger(CohereReranker.class); - + private static final int MAX_DOCUMENTS = 1000; private final WebClient webClient; @@ -32,88 +34,83 @@ public class CohereReranker { /** * Constructs a CohereReranker that communicates with the Cohere Rerank API. * Initializes the internal WebClient with the provided API key for authorization. - * - * @param cohereApi the API configuration object containing the required API key (must not be null) + * @param cohereApi the API configuration object containing the required API key (must + * not be null) * @throws IllegalArgumentException if cohereApi is null */ - CohereReranker(CohereApi cohereApi) { - if (cohereApi == null) { - throw new IllegalArgumentException("CohereApi must not be null"); - } - - this.webClient = WebClient.builder() - .baseUrl(COHERE_RERANK_ENDPOINT) - .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) - .build(); - } - - /** - * Reranks a list of documents based on the provided query using the Cohere API. - * - * @param query The user input query. - * @param documents The list of documents to rerank. - * @param topN The number of top results to return (at most). - * @return A reranked list of documents. If the API fails, returns the original list. - */ - public List rerank(String query, List documents, int topN) { - if (topN < 1) { - throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); - } - - if (documents == null || documents.isEmpty()) { - logger.warn("Empty document list provided. Skipping rerank."); - return Collections.emptyList(); - } - - if (documents.size() > MAX_DOCUMENTS) { - logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", MAX_DOCUMENTS); - return documents; - } - - int adjustedTopN = Math.min(topN, documents.size()); - - Map payload = Map.of( - "query", query, - "documents", documents.stream().map(Document::getText).toList(), - "top_n", adjustedTopN - ); - - // Call the API and process the result - return sendRerankRequest(payload) - .map(results -> results.stream() - .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) - .map(r -> { - Document original = documents.get(r.getIndex()); - Map metadata = new HashMap<>(original.getMetadata()); - metadata.put("score", String.format("%.4f", r.getRelevanceScore())); - return new Document(original.getText(), metadata); - }) - .toList()) - .orElseGet(() -> { - logger.warn("Cohere response is null or invalid"); - return documents; - }); - } - - /** - * Sends a rerank request to the Cohere API and returns the result list. - * - * @param payload The request body including query, documents, and top_n. - * @return An Optional list of reranked results, or empty if failed. - */ - private Optional> sendRerankRequest(Map payload) { - try { - RerankResponse response = webClient.post() - .bodyValue(payload) - .retrieve() - .bodyToMono(RerankResponse.class) - .block(); - - return Optional.ofNullable(response) - .map(RerankResponse::getResults); - } catch (Exception e) { - logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); - return Optional.empty(); - } - } + CohereReranker(CohereApi cohereApi) { + if (cohereApi == null) { + throw new IllegalArgumentException("CohereApi must not be null"); + } + + this.webClient = WebClient.builder() + .baseUrl(COHERE_RERANK_ENDPOINT) + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) + .build(); + } + + /** + * Reranks a list of documents based on the provided query using the Cohere API. + * @param query The user input query. + * @param documents The list of documents to rerank. + * @param topN The number of top results to return (at most). + * @return A reranked list of documents. If the API fails, returns the original list. + */ + public List rerank(String query, List documents, int topN) { + if (topN < 1) { + throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); + } + + if (documents == null || documents.isEmpty()) { + logger.warn("Empty document list provided. Skipping rerank."); + return Collections.emptyList(); + } + + if (documents.size() > MAX_DOCUMENTS) { + logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", + MAX_DOCUMENTS); + return documents; + } + + int adjustedTopN = Math.min(topN, documents.size()); + + Map payload = Map.of("query", query, "documents", + documents.stream().map(Document::getText).toList(), "top_n", adjustedTopN); + + // Call the API and process the result + return sendRerankRequest(payload).map(results -> results.stream() + .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) + .map(r -> { + Document original = documents.get(r.getIndex()); + Map metadata = new HashMap<>(original.getMetadata()); + metadata.put("score", String.format("%.4f", r.getRelevanceScore())); + return new Document(original.getText(), metadata); + }) + .toList()).orElseGet(() -> { + logger.warn("Cohere response is null or invalid"); + return documents; + }); + } + + /** + * Sends a rerank request to the Cohere API and returns the result list. + * @param payload The request body including query, documents, and top_n. + * @return An Optional list of reranked results, or empty if failed. + */ + private Optional> sendRerankRequest(Map payload) { + try { + RerankResponse response = webClient.post() + .bodyValue(payload) + .retrieve() + .bodyToMono(RerankResponse.class) + .block(); + + return Optional.ofNullable(response).map(RerankResponse::getResults); + } + catch (Exception e) { + logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); + return Optional.empty(); + } + } + } diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java index 576b57fabc4..6849a61be18 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankConfig.java @@ -8,49 +8,50 @@ import org.springframework.context.annotation.Configuration; /** - * Rerank configuration that conditionally registers a DocumentPostProcessor - * when rerank is enabled via application properties. + * Rerank configuration that conditionally registers a DocumentPostProcessor when rerank + * is enabled via application properties. * * This configuration is activated only when the following properties are set - * + * *
    - *
  • spring.ai.rerank.enabled=true
  • - *
  • spring.ai.rerank.cohere.api-key=your-api-key
  • + *
  • spring.ai.rerank.enabled=true
  • + *
  • spring.ai.rerank.cohere.api-key=your-api-key
  • *
* * @author KoreaNirsa */ @Configuration public class RerankConfig { - @Value("${spring.ai.rerank.cohere.api-key}") - private String apiKey; - - /** - * Registers a DocumentPostProcessor bean that enables reranking using Cohere. - * - * This bean is only created when the property `spring.ai.rerank.enabled=true` is set. - * The API key is injected from application properties or environment variables. - * - * @return An instance of RerankerPostProcessor backed by Cohere API - */ - @Bean - @ConditionalOnProperty(name = "spring.ai.rerank.enabled", havingValue = "true") - public DocumentPostProcessor rerankerPostProcessor() { - return new RerankerPostProcessor(CohereApi.builder().apiKey(apiKey).build()); - } - - /** - * Provides a fallback DocumentPostProcessor when reranking is disabled - * or no custom implementation is registered. - * - * This implementation performs no reranking and simply returns the original list of documents. - * If additional post-processing is required, a custom bean should be defined. - * - * @return A pass-through DocumentPostProcessor that returns input as-is - */ - @Bean - @ConditionalOnMissingBean - public DocumentPostProcessor noOpPostProcessor() { - return (query, documents) -> documents; - } + + @Value("${spring.ai.rerank.cohere.api-key}") + private String apiKey; + + /** + * Registers a DocumentPostProcessor bean that enables reranking using Cohere. + * + * This bean is only created when the property `spring.ai.rerank.enabled=true` is set. + * The API key is injected from application properties or environment variables. + * @return An instance of RerankerPostProcessor backed by Cohere API + */ + @Bean + @ConditionalOnProperty(name = "spring.ai.rerank.enabled", havingValue = "true") + public DocumentPostProcessor rerankerPostProcessor() { + return new RerankerPostProcessor(CohereApi.builder().apiKey(apiKey).build()); + } + + /** + * Provides a fallback DocumentPostProcessor when reranking is disabled or no custom + * implementation is registered. + * + * This implementation performs no reranking and simply returns the original list of + * documents. If additional post-processing is required, a custom bean should be + * defined. + * @return A pass-through DocumentPostProcessor that returns input as-is + */ + @Bean + @ConditionalOnMissingBean + public DocumentPostProcessor noOpPostProcessor() { + return (query, documents) -> documents; + } + } \ No newline at end of file diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java index 3749dc16287..bdf1861ceea 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankResponse.java @@ -5,47 +5,50 @@ import com.fasterxml.jackson.annotation.JsonProperty; /** - * Represents the response returned from Cohere's Rerank API. - * The response includes a list of result objects that specify document indices - * and their semantic relevance scores. + * Represents the response returned from Cohere's Rerank API. The response includes a list + * of result objects that specify document indices and their semantic relevance scores. * * @author KoreaNirsa */ public class RerankResponse { - private List results; - - public List getResults() { - return results; - } - - public void setResults(List results) { - this.results = results; - } - - /** - * Represents a single reranked document result returned by the Cohere API. - * Contains the original index and the computed relevance score. - */ - public static class Result { - private int index; - - @JsonProperty("relevance_score") - private double relevanceScore; - - public int getIndex() { - return index; - } - - public void setIndex(int index) { - this.index = index; - } - - public double getRelevanceScore() { - return relevanceScore; - } - - public void setRelevanceScore(double relevanceScore) { - this.relevanceScore = relevanceScore; - } - } + + private List results; + + public List getResults() { + return results; + } + + public void setResults(List results) { + this.results = results; + } + + /** + * Represents a single reranked document result returned by the Cohere API. Contains + * the original index and the computed relevance score. + */ + public static class Result { + + private int index; + + @JsonProperty("relevance_score") + private double relevanceScore; + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public double getRelevanceScore() { + return relevanceScore; + } + + public void setRelevanceScore(double relevanceScore) { + this.relevanceScore = relevanceScore; + } + + } + } \ No newline at end of file diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java index 7d4d4215a43..0998b46a1f6 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/postretrieval/rerank/RerankerPostProcessor.java @@ -7,43 +7,44 @@ import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor; /** - * The only supported entrypoint for rerank functionality in Spring AI RAG. - * This component delegates reranking logic to CohereReranker, using the provided API key. - * + * The only supported entrypoint for rerank functionality in Spring AI RAG. This component + * delegates reranking logic to CohereReranker, using the provided API key. + * * This class is registered as a DocumentPostProcessor bean only if * spring.ai.rerank.enabled=true is set in the application properties. - * + * * @author KoreaNirsa */ public class RerankerPostProcessor implements DocumentPostProcessor { - private final CohereReranker reranker; - - RerankerPostProcessor(CohereApi cohereApi) { - this.reranker = new CohereReranker(cohereApi); - } - - /** - * Processes the retrieved documents by applying semantic reranking using the Cohere API - * - * @param query the user's input query - * @param documents the list of documents to be reranked - * @return a list of documents sorted by relevance score - */ - @Override - public List process(Query query, List documents) { - int topN = extractTopN(query); - return reranker.rerank(query.text(), documents, topN); - } - - /** - * Extracts the top-N value from the query context. - * If not present or invalid, it defaults to 3 - * - * @param query the query containing optional context parameters - * @return the number of top documents to return - */ - private int extractTopN(Query query) { - Object value = query.context().get("topN"); - return (value instanceof Number num) ? num.intValue() : 3; - } + + private final CohereReranker reranker; + + RerankerPostProcessor(CohereApi cohereApi) { + this.reranker = new CohereReranker(cohereApi); + } + + /** + * Processes the retrieved documents by applying semantic reranking using the Cohere + * API + * @param query the user's input query + * @param documents the list of documents to be reranked + * @return a list of documents sorted by relevance score + */ + @Override + public List process(Query query, List documents) { + int topN = extractTopN(query); + return reranker.rerank(query.text(), documents, topN); + } + + /** + * Extracts the top-N value from the query context. If not present or invalid, it + * defaults to 3 + * @param query the query containing optional context parameters + * @return the number of top documents to return + */ + private int extractTopN(Query query) { + Object value = query.context().get("topN"); + return (value instanceof Number num) ? num.intValue() : 3; + } + }