Skip to content

Commit 1ca8cc6

Browse files
committed
fix(redis): Implement ChatMemoryRepository interface and fix test connectivity
Refactor Redis-based chat memory implementation to: - Implement ChatMemoryRepository interface as requested in PR #2295 - Fix Redis connection issues in integration tests reported in PR #2982 - Optimize conversation ID lookup with server-side deduplication - Add configurable result limits to avoid Redis cursor size limitations - Implement robust fallback mechanism for query failures - Enhance support for metadata, toolcalls, and media in messages - Add comprehensive test coverage with reliable Redis connections Signed-off-by: Brian Sam-Bodden <[email protected]>
1 parent 00a4e2e commit 1ca8cc6

File tree

10 files changed

+552
-137
lines changed

10 files changed

+552
-137
lines changed

auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-redis/src/test/java/org/springframework/ai/vectorstore/redis/autoconfigure/RedisVectorStoreAutoConfigurationIT.java

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@
4949
* @author Soby Chacko
5050
* @author Christian Tzolov
5151
* @author Thomas Vitale
52+
* @author Brian Sam-Bodden
5253
*/
5354
@Testcontainers
5455
class RedisVectorStoreAutoConfigurationIT {
@@ -57,10 +58,13 @@ class RedisVectorStoreAutoConfigurationIT {
5758
static RedisStackContainer redisContainer = new RedisStackContainer(
5859
RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG));
5960

61+
// Use host and port explicitly since getRedisURI() might not be consistent
6062
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
6163
.withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class, RedisVectorStoreAutoConfiguration.class))
6264
.withUserConfiguration(Config.class)
63-
.withPropertyValues("spring.data.redis.url=" + redisContainer.getRedisURI())
65+
.withPropertyValues(
66+
"spring.data.redis.host=" + redisContainer.getHost(),
67+
"spring.data.redis.port=" + redisContainer.getFirstMappedPort())
6468
.withPropertyValues("spring.ai.vectorstore.redis.initialize-schema=true")
6569
.withPropertyValues("spring.ai.vectorstore.redis.index=myIdx")
6670
.withPropertyValues("spring.ai.vectorstore.redis.prefix=doc:");
@@ -148,5 +152,4 @@ public EmbeddingModel embeddingModel() {
148152
}
149153

150154
}
151-
152-
}
155+
}

vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java

+188-7
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,21 @@
2020
import org.slf4j.Logger;
2121
import org.slf4j.LoggerFactory;
2222
import org.springframework.ai.chat.memory.ChatMemory;
23+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
2324
import org.springframework.ai.chat.messages.AssistantMessage;
2425
import org.springframework.ai.chat.messages.Message;
2526
import org.springframework.ai.chat.messages.MessageType;
2627
import org.springframework.ai.chat.messages.UserMessage;
28+
import org.springframework.ai.content.Media;
29+
import org.springframework.ai.content.MediaContent;
2730
import org.springframework.util.Assert;
2831
import redis.clients.jedis.JedisPooled;
2932
import redis.clients.jedis.Pipeline;
3033
import redis.clients.jedis.json.Path2;
3134
import redis.clients.jedis.search.*;
35+
import redis.clients.jedis.search.aggr.AggregationBuilder;
36+
import redis.clients.jedis.search.aggr.AggregationResult;
37+
import redis.clients.jedis.search.aggr.Reducers;
3238
import redis.clients.jedis.search.schemafields.NumericField;
3339
import redis.clients.jedis.search.schemafields.SchemaField;
3440
import redis.clients.jedis.search.schemafields.TagField;
@@ -37,17 +43,20 @@
3743
import java.time.Duration;
3844
import java.time.Instant;
3945
import java.util.ArrayList;
46+
import java.util.HashMap;
47+
import java.util.HashSet;
4048
import java.util.List;
4149
import java.util.Map;
50+
import java.util.Set;
4251
import java.util.concurrent.atomic.AtomicLong;
4352

4453
/**
45-
* Redis implementation of {@link ChatMemory} using Redis Stack (RedisJSON + RediSearch).
46-
* Stores chat messages as JSON documents and uses RediSearch for querying.
54+
* Redis implementation of {@link ChatMemory} using Redis (JSON + Query Engine). Stores
55+
* chat messages as JSON documents and uses the Redis Query Engine for querying.
4756
*
4857
* @author Brian Sam-Bodden
4958
*/
50-
public final class RedisChatMemory implements ChatMemory {
59+
public final class RedisChatMemory implements ChatMemory, ChatMemoryRepository {
5160

5261
private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);
5362

@@ -113,26 +122,79 @@ public List<Message> get(String conversationId, int lastN) {
113122
Assert.isTrue(lastN > 0, "LastN must be greater than 0");
114123

115124
String queryStr = String.format("@conversation_id:{%s}", RediSearchUtil.escape(conversationId));
125+
// Use ascending order (oldest first) to match test expectations
116126
Query query = new Query(queryStr).setSortBy("timestamp", true).limit(0, lastN);
117127

118128
SearchResult result = jedis.ftSearch(config.getIndexName(), query);
119129

130+
if (logger.isDebugEnabled()) {
131+
logger.debug("Redis search for conversation {} returned {} results", conversationId,
132+
result.getDocuments().size());
133+
result.getDocuments().forEach(doc -> {
134+
if (doc.get("$") != null) {
135+
JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class);
136+
logger.debug("Document: {}", json);
137+
}
138+
});
139+
}
140+
120141
List<Message> messages = new ArrayList<>();
121142
result.getDocuments().forEach(doc -> {
122143
if (doc.get("$") != null) {
123144
JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class);
124145
String type = json.get("type").getAsString();
125146
String content = json.get("content").getAsString();
126147

148+
// Convert metadata from JSON to Map if present
149+
Map<String, Object> metadata = new HashMap<>();
150+
if (json.has("metadata") && json.get("metadata").isJsonObject()) {
151+
JsonObject metadataJson = json.getAsJsonObject("metadata");
152+
metadataJson.entrySet().forEach(entry -> {
153+
metadata.put(entry.getKey(), gson.fromJson(entry.getValue(), Object.class));
154+
});
155+
}
156+
127157
if (MessageType.ASSISTANT.toString().equals(type)) {
128-
messages.add(new AssistantMessage(content));
158+
// Handle tool calls if present
159+
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
160+
if (json.has("toolCalls") && json.get("toolCalls").isJsonArray()) {
161+
json.getAsJsonArray("toolCalls").forEach(element -> {
162+
JsonObject toolCallJson = element.getAsJsonObject();
163+
toolCalls.add(new AssistantMessage.ToolCall(
164+
toolCallJson.has("id") ? toolCallJson.get("id").getAsString() : "",
165+
toolCallJson.has("type") ? toolCallJson.get("type").getAsString() : "",
166+
toolCallJson.has("name") ? toolCallJson.get("name").getAsString() : "",
167+
toolCallJson.has("arguments") ? toolCallJson.get("arguments").getAsString() : ""));
168+
});
169+
}
170+
171+
// Handle media if present
172+
List<Media> media = new ArrayList<>();
173+
if (json.has("media") && json.get("media").isJsonArray()) {
174+
// Media deserialization would go here if needed
175+
// Left as empty list for simplicity
176+
}
177+
178+
messages.add(new AssistantMessage(content, metadata, toolCalls, media));
129179
}
130180
else if (MessageType.USER.toString().equals(type)) {
131-
messages.add(new UserMessage(content));
181+
// Create a UserMessage with the builder to properly set metadata
182+
List<Media> userMedia = new ArrayList<>();
183+
if (json.has("media") && json.get("media").isJsonArray()) {
184+
// Media deserialization would go here if needed
185+
}
186+
messages.add(UserMessage.builder().text(content).metadata(metadata).media(userMedia).build());
132187
}
188+
// Add handling for other message types if needed
133189
}
134190
});
135191

192+
if (logger.isDebugEnabled()) {
193+
logger.debug("Returning {} messages for conversation {}", messages.size(), conversationId);
194+
messages.forEach(message -> logger.debug("Message type: {}, content: {}", message.getMessageType(),
195+
message.getText()));
196+
}
197+
136198
return messages;
137199
}
138200

@@ -179,14 +241,133 @@ private String createKey(String conversationId, long timestamp) {
179241
}
180242

181243
private Map<String, Object> createMessageDocument(String conversationId, Message message) {
182-
return Map.of("type", message.getMessageType().toString(), "content", message.getText(), "conversation_id",
183-
conversationId, "timestamp", Instant.now().toEpochMilli());
244+
Map<String, Object> documentMap = new HashMap<>();
245+
documentMap.put("type", message.getMessageType().toString());
246+
documentMap.put("content", message.getText());
247+
documentMap.put("conversation_id", conversationId);
248+
documentMap.put("timestamp", Instant.now().toEpochMilli());
249+
250+
// Store metadata/properties
251+
if (message.getMetadata() != null && !message.getMetadata().isEmpty()) {
252+
documentMap.put("metadata", message.getMetadata());
253+
}
254+
255+
// Handle tool calls for AssistantMessage
256+
if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) {
257+
documentMap.put("toolCalls", assistantMessage.getToolCalls());
258+
}
259+
260+
// Handle media content
261+
if (message instanceof MediaContent mediaContent && !mediaContent.getMedia().isEmpty()) {
262+
documentMap.put("media", mediaContent.getMedia());
263+
}
264+
265+
return documentMap;
184266
}
185267

186268
private String escapeKey(String key) {
187269
return key.replace(":", "\\:");
188270
}
189271

272+
// ChatMemoryRepository implementation
273+
274+
/**
275+
* Finds all unique conversation IDs using Redis aggregation. This method is optimized
276+
* to perform the deduplication on the Redis server side.
277+
* @return a list of unique conversation IDs
278+
*/
279+
@Override
280+
public List<String> findConversationIds() {
281+
try {
282+
// Use Redis aggregation to get distinct conversation_ids
283+
AggregationBuilder aggregation = new AggregationBuilder("*")
284+
.groupBy("@conversation_id", Reducers.count().as("count"))
285+
.limit(0, config.getMaxConversationIds()); // Use configured limit
286+
287+
AggregationResult result = jedis.ftAggregate(config.getIndexName(), aggregation);
288+
289+
List<String> conversationIds = new ArrayList<>();
290+
result.getResults().forEach(row -> {
291+
String conversationId = (String) row.get("conversation_id");
292+
if (conversationId != null) {
293+
conversationIds.add(conversationId);
294+
}
295+
});
296+
297+
if (logger.isDebugEnabled()) {
298+
logger.debug("Found {} unique conversation IDs using Redis aggregation", conversationIds.size());
299+
conversationIds.forEach(id -> logger.debug("Conversation ID: {}", id));
300+
}
301+
302+
return conversationIds;
303+
}
304+
catch (Exception e) {
305+
logger.warn("Error executing Redis aggregation for conversation IDs, falling back to client-side approach",
306+
e);
307+
return findConversationIdsLegacy();
308+
}
309+
}
310+
311+
/**
312+
* Fallback method to find conversation IDs if aggregation fails. This is less
313+
* efficient as it requires fetching all documents and deduplicating on the client
314+
* side.
315+
* @return a list of unique conversation IDs
316+
*/
317+
private List<String> findConversationIdsLegacy() {
318+
// Keep the current implementation as a fallback
319+
String queryStr = "*"; // Match all documents
320+
Query query = new Query(queryStr);
321+
query.limit(0, config.getMaxConversationIds()); // Use configured limit
322+
323+
SearchResult result = jedis.ftSearch(config.getIndexName(), query);
324+
325+
// Use a Set to deduplicate conversation IDs
326+
Set<String> conversationIds = new HashSet<>();
327+
328+
result.getDocuments().forEach(doc -> {
329+
if (doc.get("$") != null) {
330+
JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class);
331+
if (json.has("conversation_id")) {
332+
conversationIds.add(json.get("conversation_id").getAsString());
333+
}
334+
}
335+
});
336+
337+
if (logger.isDebugEnabled()) {
338+
logger.debug("Found {} unique conversation IDs using legacy method", conversationIds.size());
339+
}
340+
341+
return new ArrayList<>(conversationIds);
342+
}
343+
344+
/**
345+
* Finds all messages for a given conversation ID. Uses the configured maximum
346+
* messages per conversation limit to avoid exceeding Redis limits.
347+
* @param conversationId the conversation ID to find messages for
348+
* @return a list of messages for the conversation
349+
*/
350+
@Override
351+
public List<Message> findByConversationId(String conversationId) {
352+
// Reuse existing get method with the configured limit
353+
return get(conversationId, config.getMaxMessagesPerConversation());
354+
}
355+
356+
@Override
357+
public void saveAll(String conversationId, List<Message> messages) {
358+
// First clear any existing messages for this conversation
359+
clear(conversationId);
360+
361+
// Then add all the new messages
362+
add(conversationId, messages);
363+
}
364+
365+
@Override
366+
public void deleteByConversationId(String conversationId) {
367+
// Reuse existing clear method
368+
clear(conversationId);
369+
}
370+
190371
/**
191372
* Builder for RedisChatMemory configuration.
192373
*/

0 commit comments

Comments
 (0)