|
20 | 20 | import org.slf4j.Logger;
|
21 | 21 | import org.slf4j.LoggerFactory;
|
22 | 22 | import org.springframework.ai.chat.memory.ChatMemory;
|
| 23 | +import org.springframework.ai.chat.memory.ChatMemoryRepository; |
23 | 24 | import org.springframework.ai.chat.messages.AssistantMessage;
|
24 | 25 | import org.springframework.ai.chat.messages.Message;
|
25 | 26 | import org.springframework.ai.chat.messages.MessageType;
|
26 | 27 | import org.springframework.ai.chat.messages.UserMessage;
|
| 28 | +import org.springframework.ai.content.Media; |
| 29 | +import org.springframework.ai.content.MediaContent; |
27 | 30 | import org.springframework.util.Assert;
|
28 | 31 | import redis.clients.jedis.JedisPooled;
|
29 | 32 | import redis.clients.jedis.Pipeline;
|
30 | 33 | import redis.clients.jedis.json.Path2;
|
31 | 34 | import redis.clients.jedis.search.*;
|
| 35 | +import redis.clients.jedis.search.aggr.AggregationBuilder; |
| 36 | +import redis.clients.jedis.search.aggr.AggregationResult; |
| 37 | +import redis.clients.jedis.search.aggr.Reducers; |
32 | 38 | import redis.clients.jedis.search.schemafields.NumericField;
|
33 | 39 | import redis.clients.jedis.search.schemafields.SchemaField;
|
34 | 40 | import redis.clients.jedis.search.schemafields.TagField;
|
|
37 | 43 | import java.time.Duration;
|
38 | 44 | import java.time.Instant;
|
39 | 45 | import java.util.ArrayList;
|
| 46 | +import java.util.HashMap; |
| 47 | +import java.util.HashSet; |
40 | 48 | import java.util.List;
|
41 | 49 | import java.util.Map;
|
| 50 | +import java.util.Set; |
42 | 51 | import java.util.concurrent.atomic.AtomicLong;
|
43 | 52 |
|
44 | 53 | /**
|
45 |
| - * Redis implementation of {@link ChatMemory} using Redis Stack (RedisJSON + RediSearch). |
46 |
| - * Stores chat messages as JSON documents and uses RediSearch for querying. |
| 54 | + * Redis implementation of {@link ChatMemory} using Redis (JSON + Query Engine). Stores |
| 55 | + * chat messages as JSON documents and uses the Redis Query Engine for querying. |
47 | 56 | *
|
48 | 57 | * @author Brian Sam-Bodden
|
49 | 58 | */
|
50 |
| -public final class RedisChatMemory implements ChatMemory { |
| 59 | +public final class RedisChatMemory implements ChatMemory, ChatMemoryRepository { |
51 | 60 |
|
52 | 61 | private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);
|
53 | 62 |
|
@@ -113,26 +122,79 @@ public List<Message> get(String conversationId, int lastN) {
|
113 | 122 | Assert.isTrue(lastN > 0, "LastN must be greater than 0");
|
114 | 123 |
|
115 | 124 | String queryStr = String.format("@conversation_id:{%s}", RediSearchUtil.escape(conversationId));
|
| 125 | + // Use ascending order (oldest first) to match test expectations |
116 | 126 | Query query = new Query(queryStr).setSortBy("timestamp", true).limit(0, lastN);
|
117 | 127 |
|
118 | 128 | SearchResult result = jedis.ftSearch(config.getIndexName(), query);
|
119 | 129 |
|
| 130 | + if (logger.isDebugEnabled()) { |
| 131 | + logger.debug("Redis search for conversation {} returned {} results", conversationId, |
| 132 | + result.getDocuments().size()); |
| 133 | + result.getDocuments().forEach(doc -> { |
| 134 | + if (doc.get("$") != null) { |
| 135 | + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); |
| 136 | + logger.debug("Document: {}", json); |
| 137 | + } |
| 138 | + }); |
| 139 | + } |
| 140 | + |
120 | 141 | List<Message> messages = new ArrayList<>();
|
121 | 142 | result.getDocuments().forEach(doc -> {
|
122 | 143 | if (doc.get("$") != null) {
|
123 | 144 | JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class);
|
124 | 145 | String type = json.get("type").getAsString();
|
125 | 146 | String content = json.get("content").getAsString();
|
126 | 147 |
|
| 148 | + // Convert metadata from JSON to Map if present |
| 149 | + Map<String, Object> metadata = new HashMap<>(); |
| 150 | + if (json.has("metadata") && json.get("metadata").isJsonObject()) { |
| 151 | + JsonObject metadataJson = json.getAsJsonObject("metadata"); |
| 152 | + metadataJson.entrySet().forEach(entry -> { |
| 153 | + metadata.put(entry.getKey(), gson.fromJson(entry.getValue(), Object.class)); |
| 154 | + }); |
| 155 | + } |
| 156 | + |
127 | 157 | if (MessageType.ASSISTANT.toString().equals(type)) {
|
128 |
| - messages.add(new AssistantMessage(content)); |
| 158 | + // Handle tool calls if present |
| 159 | + List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
| 160 | + if (json.has("toolCalls") && json.get("toolCalls").isJsonArray()) { |
| 161 | + json.getAsJsonArray("toolCalls").forEach(element -> { |
| 162 | + JsonObject toolCallJson = element.getAsJsonObject(); |
| 163 | + toolCalls.add(new AssistantMessage.ToolCall( |
| 164 | + toolCallJson.has("id") ? toolCallJson.get("id").getAsString() : "", |
| 165 | + toolCallJson.has("type") ? toolCallJson.get("type").getAsString() : "", |
| 166 | + toolCallJson.has("name") ? toolCallJson.get("name").getAsString() : "", |
| 167 | + toolCallJson.has("arguments") ? toolCallJson.get("arguments").getAsString() : "")); |
| 168 | + }); |
| 169 | + } |
| 170 | + |
| 171 | + // Handle media if present |
| 172 | + List<Media> media = new ArrayList<>(); |
| 173 | + if (json.has("media") && json.get("media").isJsonArray()) { |
| 174 | + // Media deserialization would go here if needed |
| 175 | + // Left as empty list for simplicity |
| 176 | + } |
| 177 | + |
| 178 | + messages.add(new AssistantMessage(content, metadata, toolCalls, media)); |
129 | 179 | }
|
130 | 180 | else if (MessageType.USER.toString().equals(type)) {
|
131 |
| - messages.add(new UserMessage(content)); |
| 181 | + // Create a UserMessage with the builder to properly set metadata |
| 182 | + List<Media> userMedia = new ArrayList<>(); |
| 183 | + if (json.has("media") && json.get("media").isJsonArray()) { |
| 184 | + // Media deserialization would go here if needed |
| 185 | + } |
| 186 | + messages.add(UserMessage.builder().text(content).metadata(metadata).media(userMedia).build()); |
132 | 187 | }
|
| 188 | + // Add handling for other message types if needed |
133 | 189 | }
|
134 | 190 | });
|
135 | 191 |
|
| 192 | + if (logger.isDebugEnabled()) { |
| 193 | + logger.debug("Returning {} messages for conversation {}", messages.size(), conversationId); |
| 194 | + messages.forEach(message -> logger.debug("Message type: {}, content: {}", message.getMessageType(), |
| 195 | + message.getText())); |
| 196 | + } |
| 197 | + |
136 | 198 | return messages;
|
137 | 199 | }
|
138 | 200 |
|
@@ -179,14 +241,133 @@ private String createKey(String conversationId, long timestamp) {
|
179 | 241 | }
|
180 | 242 |
|
181 | 243 | private Map<String, Object> createMessageDocument(String conversationId, Message message) {
|
182 |
| - return Map.of("type", message.getMessageType().toString(), "content", message.getText(), "conversation_id", |
183 |
| - conversationId, "timestamp", Instant.now().toEpochMilli()); |
| 244 | + Map<String, Object> documentMap = new HashMap<>(); |
| 245 | + documentMap.put("type", message.getMessageType().toString()); |
| 246 | + documentMap.put("content", message.getText()); |
| 247 | + documentMap.put("conversation_id", conversationId); |
| 248 | + documentMap.put("timestamp", Instant.now().toEpochMilli()); |
| 249 | + |
| 250 | + // Store metadata/properties |
| 251 | + if (message.getMetadata() != null && !message.getMetadata().isEmpty()) { |
| 252 | + documentMap.put("metadata", message.getMetadata()); |
| 253 | + } |
| 254 | + |
| 255 | + // Handle tool calls for AssistantMessage |
| 256 | + if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) { |
| 257 | + documentMap.put("toolCalls", assistantMessage.getToolCalls()); |
| 258 | + } |
| 259 | + |
| 260 | + // Handle media content |
| 261 | + if (message instanceof MediaContent mediaContent && !mediaContent.getMedia().isEmpty()) { |
| 262 | + documentMap.put("media", mediaContent.getMedia()); |
| 263 | + } |
| 264 | + |
| 265 | + return documentMap; |
184 | 266 | }
|
185 | 267 |
|
186 | 268 | private String escapeKey(String key) {
|
187 | 269 | return key.replace(":", "\\:");
|
188 | 270 | }
|
189 | 271 |
|
| 272 | + // ChatMemoryRepository implementation |
| 273 | + |
| 274 | + /** |
| 275 | + * Finds all unique conversation IDs using Redis aggregation. This method is optimized |
| 276 | + * to perform the deduplication on the Redis server side. |
| 277 | + * @return a list of unique conversation IDs |
| 278 | + */ |
| 279 | + @Override |
| 280 | + public List<String> findConversationIds() { |
| 281 | + try { |
| 282 | + // Use Redis aggregation to get distinct conversation_ids |
| 283 | + AggregationBuilder aggregation = new AggregationBuilder("*") |
| 284 | + .groupBy("@conversation_id", Reducers.count().as("count")) |
| 285 | + .limit(0, config.getMaxConversationIds()); // Use configured limit |
| 286 | + |
| 287 | + AggregationResult result = jedis.ftAggregate(config.getIndexName(), aggregation); |
| 288 | + |
| 289 | + List<String> conversationIds = new ArrayList<>(); |
| 290 | + result.getResults().forEach(row -> { |
| 291 | + String conversationId = (String) row.get("conversation_id"); |
| 292 | + if (conversationId != null) { |
| 293 | + conversationIds.add(conversationId); |
| 294 | + } |
| 295 | + }); |
| 296 | + |
| 297 | + if (logger.isDebugEnabled()) { |
| 298 | + logger.debug("Found {} unique conversation IDs using Redis aggregation", conversationIds.size()); |
| 299 | + conversationIds.forEach(id -> logger.debug("Conversation ID: {}", id)); |
| 300 | + } |
| 301 | + |
| 302 | + return conversationIds; |
| 303 | + } |
| 304 | + catch (Exception e) { |
| 305 | + logger.warn("Error executing Redis aggregation for conversation IDs, falling back to client-side approach", |
| 306 | + e); |
| 307 | + return findConversationIdsLegacy(); |
| 308 | + } |
| 309 | + } |
| 310 | + |
| 311 | + /** |
| 312 | + * Fallback method to find conversation IDs if aggregation fails. This is less |
| 313 | + * efficient as it requires fetching all documents and deduplicating on the client |
| 314 | + * side. |
| 315 | + * @return a list of unique conversation IDs |
| 316 | + */ |
| 317 | + private List<String> findConversationIdsLegacy() { |
| 318 | + // Keep the current implementation as a fallback |
| 319 | + String queryStr = "*"; // Match all documents |
| 320 | + Query query = new Query(queryStr); |
| 321 | + query.limit(0, config.getMaxConversationIds()); // Use configured limit |
| 322 | + |
| 323 | + SearchResult result = jedis.ftSearch(config.getIndexName(), query); |
| 324 | + |
| 325 | + // Use a Set to deduplicate conversation IDs |
| 326 | + Set<String> conversationIds = new HashSet<>(); |
| 327 | + |
| 328 | + result.getDocuments().forEach(doc -> { |
| 329 | + if (doc.get("$") != null) { |
| 330 | + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); |
| 331 | + if (json.has("conversation_id")) { |
| 332 | + conversationIds.add(json.get("conversation_id").getAsString()); |
| 333 | + } |
| 334 | + } |
| 335 | + }); |
| 336 | + |
| 337 | + if (logger.isDebugEnabled()) { |
| 338 | + logger.debug("Found {} unique conversation IDs using legacy method", conversationIds.size()); |
| 339 | + } |
| 340 | + |
| 341 | + return new ArrayList<>(conversationIds); |
| 342 | + } |
| 343 | + |
| 344 | + /** |
| 345 | + * Finds all messages for a given conversation ID. Uses the configured maximum |
| 346 | + * messages per conversation limit to avoid exceeding Redis limits. |
| 347 | + * @param conversationId the conversation ID to find messages for |
| 348 | + * @return a list of messages for the conversation |
| 349 | + */ |
| 350 | + @Override |
| 351 | + public List<Message> findByConversationId(String conversationId) { |
| 352 | + // Reuse existing get method with the configured limit |
| 353 | + return get(conversationId, config.getMaxMessagesPerConversation()); |
| 354 | + } |
| 355 | + |
| 356 | + @Override |
| 357 | + public void saveAll(String conversationId, List<Message> messages) { |
| 358 | + // First clear any existing messages for this conversation |
| 359 | + clear(conversationId); |
| 360 | + |
| 361 | + // Then add all the new messages |
| 362 | + add(conversationId, messages); |
| 363 | + } |
| 364 | + |
| 365 | + @Override |
| 366 | + public void deleteByConversationId(String conversationId) { |
| 367 | + // Reuse existing clear method |
| 368 | + clear(conversationId); |
| 369 | + } |
| 370 | + |
190 | 371 | /**
|
191 | 372 | * Builder for RedisChatMemory configuration.
|
192 | 373 | */
|
|
0 commit comments