Skip to content

Feat(rerank): Add Cohere Reranker with topN filtering and fallback support #3791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions spring-ai-rag/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
<artifactId>jackson-module-kotlin</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.springframework.ai.rag.postretrieval.rerank;

/**
 * Holds the API key used to authenticate requests against the Cohere API.
 *
 * <p>
 * Instances are normally created through {@link #builder()}, which rejects a missing or
 * blank key. The key cannot be changed once an instance has been built.
 *
 * @author KoreaNirsa
 */
public class CohereApi {

	// Written exactly once by Builder#build(); there is intentionally no setter.
	private String apiKey;

	/**
	 * Creates a new builder for {@link CohereApi}.
	 * @return a fresh builder with no API key set
	 */
	public static Builder builder() {
		return new Builder();
	}

	/**
	 * Returns the configured Cohere API key.
	 * @return the API key
	 */
	public String getApiKey() {
		return this.apiKey;
	}

	/**
	 * Builder for {@link CohereApi}. Every call to {@link #build()} produces a new,
	 * independent instance, so reusing or further modifying the builder never affects
	 * objects that were built earlier.
	 */
	public static class Builder {

		private String apiKey;

		/**
		 * Sets the API key to use.
		 * @param key the Cohere API key
		 * @return this builder
		 */
		public Builder apiKey(String key) {
			this.apiKey = key;
			return this;
		}

		/**
		 * Builds a new {@link CohereApi} from the current builder state.
		 * @return a new instance carrying the configured API key
		 * @throws IllegalArgumentException if no API key was set or it is blank
		 */
		public CohereApi build() {
			if (this.apiKey == null || this.apiKey.isBlank()) {
				throw new IllegalArgumentException("API key must be provided.");
			}
			// Create a fresh object per build() so that built instances do not alias
			// the builder's mutable state.
			CohereApi api = new CohereApi();
			api.apiKey = this.apiKey;
			return api;
		}

	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package org.springframework.ai.rag.postretrieval.rerank;

import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.document.Document;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.function.client.WebClient;

/**
 * A reranker backed by Cohere's Rerank API. Given a query and a list of retrieved
 * documents, it asks Cohere for relevance scores and returns the documents ordered by
 * descending relevance.
 *
 * <p>
 * Reranking is best-effort: when the request fails, or the response cannot be mapped back
 * onto the input documents, the original list is returned unchanged.
 *
 * @author KoreaNirsa
 * @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API
 * Documentation</a>
 */
public class CohereReranker {

	private static final String COHERE_RERANK_ENDPOINT = "https://api.cohere.ai/v1/rerank";

	private static final Logger logger = LoggerFactory.getLogger(CohereReranker.class);

	private static final int MAX_DOCUMENTS = 1000;

	private final WebClient webClient;

	/**
	 * Constructs a CohereReranker that communicates with the Cohere Rerank API.
	 * Initializes the internal WebClient with the rerank endpoint as base URL and the
	 * provided API key as a bearer token.
	 * @param cohereApi the API configuration object containing the required API key (must
	 * not be null)
	 * @throws IllegalArgumentException if cohereApi is null
	 */
	CohereReranker(CohereApi cohereApi) {
		if (cohereApi == null) {
			throw new IllegalArgumentException("CohereApi must not be null");
		}

		this.webClient = WebClient.builder()
			.baseUrl(COHERE_RERANK_ENDPOINT)
			.defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey())
			.build();
	}

	/**
	 * Reranks a list of documents based on the provided query using the Cohere API.
	 *
	 * <p>
	 * Each returned document keeps the id and text of its original and carries a copy of
	 * the original metadata plus a {@code "score"} entry holding the relevance score
	 * formatted with four decimals.
	 * @param query The user input query.
	 * @param documents The list of documents to rerank.
	 * @param topN The number of top results to return (at most).
	 * @return A reranked list of at most {@code topN} documents; an empty list if
	 * {@code documents} is null or empty; the original list, unchanged, if it holds more
	 * than {@value #MAX_DOCUMENTS} documents, the API call fails, or the response is
	 * invalid.
	 * @throws IllegalArgumentException if {@code topN} is less than 1, or if
	 * {@code query} is null when a request would be sent
	 */
	public List<Document> rerank(String query, List<Document> documents, int topN) {
		if (topN < 1) {
			throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN);
		}

		if (documents == null || documents.isEmpty()) {
			logger.warn("Empty document list provided. Skipping rerank.");
			return Collections.emptyList();
		}

		if (documents.size() > MAX_DOCUMENTS) {
			// No request is sent for oversized inputs; the caller gets the input back.
			logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.",
					MAX_DOCUMENTS);
			return documents;
		}

		if (query == null) {
			// Map.of rejects null values; fail with a descriptive message instead of a
			// bare NullPointerException.
			throw new IllegalArgumentException("query must not be null");
		}

		int adjustedTopN = Math.min(topN, documents.size());

		Map<String, Object> payload = Map.of("query", query, "documents",
				documents.stream().map(Document::getText).toList(), "top_n", adjustedTopN);

		// Call the API and process the result. A response that refers to documents we
		// never sent is treated like a failed call so the fallback contract holds.
		return sendRerankRequest(payload).filter(results -> hasOnlyValidIndexes(results, documents.size()))
			.map(results -> toRerankedDocuments(results, documents, adjustedTopN))
			.orElseGet(() -> {
				logger.warn("Cohere response is null or invalid");
				return documents;
			});
	}

	/**
	 * Checks that every result is non-null and points at an existing input document.
	 * @param results the results returned by the API
	 * @param documentCount the number of documents that were sent
	 * @return true if all result indexes are within {@code [0, documentCount)}
	 */
	private static boolean hasOnlyValidIndexes(List<RerankResponse.Result> results, int documentCount) {
		return results.stream()
			.allMatch(result -> result != null && result.getIndex() >= 0 && result.getIndex() < documentCount);
	}

	/**
	 * Maps API results back to documents, ordered by descending relevance score.
	 * @param results the validated results returned by the API
	 * @param documents the documents that were sent, in request order
	 * @param limit the maximum number of documents to return
	 * @return the scored documents, most relevant first
	 */
	private static List<Document> toRerankedDocuments(List<RerankResponse.Result> results, List<Document> documents,
			int limit) {
		return results.stream()
			.sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed())
			.limit(limit)
			.map(result -> withScore(documents.get(result.getIndex()), result.getRelevanceScore()))
			.toList();
	}

	/**
	 * Creates a copy of the given document with the relevance score added to its
	 * metadata under the {@code "score"} key.
	 * @param original the document to copy
	 * @param relevanceScore the relevance score reported by the API
	 * @return a new document with the same id and text and the extended metadata
	 */
	private static Document withScore(Document original, double relevanceScore) {
		Map<String, Object> metadata = new HashMap<>(original.getMetadata());
		// Locale.ROOT keeps the decimal separator a '.' regardless of the JVM locale.
		metadata.put("score", String.format(Locale.ROOT, "%.4f", relevanceScore));
		// Keep the original id so the reranked document can still be correlated with
		// its source.
		return new Document(original.getId(), original.getText(), metadata);
	}

	/**
	 * Sends a rerank request to the Cohere API and returns the result list. The call
	 * blocks until the response has been received.
	 * @param payload The request body including query, documents, and top_n.
	 * @return An Optional list of reranked results, or empty if the call failed or the
	 * response carried no result list.
	 */
	private Optional<List<RerankResponse.Result>> sendRerankRequest(Map<String, Object> payload) {
		try {
			RerankResponse response = this.webClient.post()
				.bodyValue(payload)
				.retrieve()
				.bodyToMono(RerankResponse.class)
				.block();

			return Optional.ofNullable(response).map(RerankResponse::getResults);
		}
		catch (Exception e) {
			// Deliberate best-effort boundary: any failure degrades to "no reranking".
			logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e);
			return Optional.empty();
		}
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.springframework.ai.rag.postretrieval.rerank;

import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Rerank configuration that registers a {@link DocumentPostProcessor}.
 *
 * <p>
 * The Cohere-backed post-processor is registered only when both of the following
 * properties are set:
 *
 * <ul>
 * <li>spring.ai.rerank.enabled=true</li>
 * <li>spring.ai.rerank.cohere.api-key=your-api-key</li>
 * </ul>
 *
 * <p>
 * Otherwise a pass-through post-processor is offered as a fallback, unless another
 * {@link DocumentPostProcessor} bean is already defined.
 *
 * @author KoreaNirsa
 */
@Configuration
public class RerankConfig {

	// The empty default keeps the application context startable when reranking is
	// disabled and no key is configured; without it the unresolved placeholder would
	// fail this configuration class regardless of spring.ai.rerank.enabled.
	@Value("${spring.ai.rerank.cohere.api-key:}")
	private String apiKey;

	/**
	 * Registers a DocumentPostProcessor bean that enables reranking using Cohere.
	 *
	 * This bean is only created when the property `spring.ai.rerank.enabled=true` is set.
	 * The API key is injected from application properties or environment variables.
	 * @return An instance of RerankerPostProcessor backed by Cohere API
	 * @throws IllegalArgumentException if reranking is enabled but no API key is
	 * configured (raised by {@link CohereApi.Builder#build()})
	 */
	@Bean
	@ConditionalOnProperty(name = "spring.ai.rerank.enabled", havingValue = "true")
	public DocumentPostProcessor rerankerPostProcessor() {
		return new RerankerPostProcessor(CohereApi.builder().apiKey(this.apiKey).build());
	}

	/**
	 * Provides a fallback DocumentPostProcessor when reranking is disabled or no custom
	 * implementation is registered.
	 *
	 * This implementation performs no reranking and simply returns the original list of
	 * documents. If additional post-processing is required, a custom bean should be
	 * defined.
	 * @return A pass-through DocumentPostProcessor that returns input as-is
	 */
	@Bean
	@ConditionalOnMissingBean
	public DocumentPostProcessor noOpPostProcessor() {
		return (query, documents) -> documents;
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.springframework.ai.rag.postretrieval.rerank;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * Response body of Cohere's Rerank API. It carries one {@link Result} per ranked
 * document, each naming the document's index and its relevance score.
 *
 * @author KoreaNirsa
 */
public class RerankResponse {

	private List<Result> results;

	/**
	 * Replaces the list of results.
	 * @param results the results to store
	 */
	public void setResults(List<Result> results) {
		this.results = results;
	}

	/**
	 * Returns the list of results.
	 * @return the stored results, or {@code null} if none were set
	 */
	public List<Result> getResults() {
		return this.results;
	}

	/**
	 * One entry of the rerank response: the index of a document and the relevance score
	 * computed for it.
	 */
	public static class Result {

		private int index;

		// Bound to the snake_case JSON property "relevance_score".
		@JsonProperty("relevance_score")
		private double relevanceScore;

		/**
		 * Sets the document index.
		 * @param index the index to store
		 */
		public void setIndex(int index) {
			this.index = index;
		}

		/**
		 * Returns the document index.
		 * @return the stored index
		 */
		public int getIndex() {
			return this.index;
		}

		/**
		 * Sets the relevance score.
		 * @param relevanceScore the score to store
		 */
		public void setRelevanceScore(double relevanceScore) {
			this.relevanceScore = relevanceScore;
		}

		/**
		 * Returns the relevance score.
		 * @return the stored score
		 */
		public double getRelevanceScore() {
			return this.relevanceScore;
		}

	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.springframework.ai.rag.postretrieval.rerank;

import java.util.List;

import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;

/**
 * The only supported entrypoint for rerank functionality in Spring AI RAG. This component
 * delegates reranking logic to CohereReranker, using the provided API key.
 *
 * <p>
 * The number of documents to keep is read from the query context under the key
 * {@code "topN"} and defaults to 3.
 *
 * @author KoreaNirsa
 */
public class RerankerPostProcessor implements DocumentPostProcessor {

	private static final String TOP_N_CONTEXT_KEY = "topN";

	private static final int DEFAULT_TOP_N = 3;

	private final CohereReranker reranker;

	/**
	 * Creates a post-processor that reranks through a new {@link CohereReranker}.
	 * @param cohereApi the Cohere API configuration holding the API key
	 */
	RerankerPostProcessor(CohereApi cohereApi) {
		this.reranker = new CohereReranker(cohereApi);
	}

	/**
	 * Processes the retrieved documents by applying semantic reranking using the Cohere
	 * API.
	 * @param query the user's input query
	 * @param documents the list of documents to be reranked
	 * @return the list produced by {@link CohereReranker#rerank(String, List, int)} for
	 * the query text, the documents and the resolved top-N value
	 */
	@Override
	public List<Document> process(Query query, List<Document> documents) {
		int topN = extractTopN(query);
		return this.reranker.rerank(query.text(), documents, topN);
	}

	/**
	 * Extracts the top-N value from the query context. If not present or invalid (not a
	 * number, or a number below 1), it defaults to 3.
	 * @param query the query containing optional context parameters
	 * @return the number of top documents to return, always at least 1
	 */
	private int extractTopN(Query query) {
		Object value = query.context().get(TOP_N_CONTEXT_KEY);
		// The reranker rejects values below 1 with an exception, so treat them as
		// invalid here and fall back to the default instead.
		if (value instanceof Number num && num.intValue() >= 1) {
			return num.intValue();
		}
		return DEFAULT_TOP_N;
	}

}