diff --git a/CHANGELOG.md b/CHANGELOG.md index 3049aa86..d8a18552 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,5 +5,10 @@ All notable changes to the Memori Python SDK will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Fixed + +- Fixed multi-turn conversation ingestion for AzureOpenAI and OpenAI clients. Previously, only the first conversation turn was being recorded. Now `conversation_id` is resolved early in the request lifecycle, ensuring all conversation turns are properly ingested into the same conversation. (Fixes #83) [3.0.0]: https://github.com/MemoriLabs/Memori/releases/tag/v3.0.0 diff --git a/memori/llm/_base.py b/memori/llm/_base.py index a94aa51a..ba405205 100644 --- a/memori/llm/_base.py +++ b/memori/llm/_base.py @@ -11,6 +11,7 @@ import asyncio import copy import json +import logging from typing import TYPE_CHECKING from google.protobuf import json_format @@ -374,12 +375,63 @@ def inject_recalled_facts(self, kwargs: dict) -> dict: return kwargs def inject_conversation_messages(self, kwargs: dict) -> dict: - if self.config.cache.conversation_id is None: - return kwargs - + """ + Inject previous conversation messages into the current request. + + This method ensures conversation_id is resolved early (before message injection) + to support multi-turn conversations. On subsequent API calls, it retrieves the + existing conversation_id from cache or creates it if needed, enabling proper + conversation continuity. + + Note: This fix addresses issues where only the first conversation turn was + being recorded. By resolving conversation_id early, we ensure all turns + are properly ingested into the same conversation. 
+ """ if self.config.storage is None or self.config.storage.driver is None: return kwargs + # Ensure session_id and conversation_id are available before injecting messages + # This allows us to retrieve previous messages even on subsequent calls. + # Fixes issue where conversation_id wasn't available early enough, causing + # only the first turn to be recorded (see GitHub issue #83). + if self.config.cache.conversation_id is None: + # First ensure session_id exists + if self.config.cache.session_id is None: + if self.config.entity_id is not None: + entity_id = self.config.storage.driver.entity.create( + self.config.entity_id + ) + if entity_id is not None: + self.config.cache.entity_id = entity_id + if self.config.process_id is not None: + process_id = self.config.storage.driver.process.create( + self.config.process_id + ) + if process_id is not None: + self.config.cache.process_id = process_id + + session_id = self.config.storage.driver.session.create( + self.config.session_id, + self.config.cache.entity_id, + self.config.cache.process_id, + ) + if session_id is not None: + self.config.cache.session_id = session_id + + # Now try to get existing conversation for this session + if self.config.cache.session_id is not None: + # conversation.create returns existing conversation_id if within timeout, + # or creates a new one. This ensures we have a conversation_id. 
+ existing_conv = self.config.storage.driver.conversation.create( + self.config.cache.session_id, + self.config.session_timeout_minutes, + ) + if existing_conv is not None: + self.config.cache.conversation_id = existing_conv + # If still None, we'll create it in the Writer later + if self.config.cache.conversation_id is None: + return kwargs + messages = self.config.storage.driver.conversation.messages.read( self.config.cache.conversation_id ) @@ -520,6 +572,8 @@ def _strip_memori_context_from_messages(self, messages: list) -> list: def handle_post_response(self, kwargs, start_time, raw_response): from memori.memory._manager import Manager as MemoryManager + logger = logging.getLogger(__name__) + if "model" in kwargs: self.config.llm.version = kwargs["model"] @@ -533,6 +587,18 @@ def handle_post_response(self, kwargs, start_time, raw_response): self._format_response(self.get_response_content(raw_response)), ) + conv_id = self.config.cache.conversation_id + msg_count = len( + payload.get("conversation", {}).get("query", {}).get("messages", []) + ) + resp_count = len( + payload.get("conversation", {}).get("response", {}).get("choices", []) + ) + logger.debug( + f"Ingesting conversation turn: conversation_id={conv_id}, " + f"messages_count={msg_count}, responses_count={resp_count}" + ) + MemoryManager(self.config).execute(payload) if self.config.augmentation is not None: diff --git a/memori/memory/_writer.py b/memori/memory/_writer.py index efeb0612..04277203 100644 --- a/memori/memory/_writer.py +++ b/memori/memory/_writer.py @@ -9,6 +9,7 @@ """ import json +import logging import time from sqlalchemy.exc import OperationalError @@ -19,6 +20,8 @@ MAX_RETRIES = 3 RETRY_BACKOFF_BASE = 0.1 +logger = logging.getLogger(__name__) + class Writer: def __init__(self, config: Config): @@ -74,6 +77,10 @@ def _execute_transaction(self, payload: dict) -> None: self.config.cache.process_id, ) + # Ensure conversation_id exists - may have been set earlier in + # 
inject_conversation_messages. If not, create/get it now. + # conversation.create is idempotent and returns existing conversation + # if within timeout, so it's safe to call multiple times. self._ensure_cached_id( "conversation_id", self.config.storage.driver.conversation.create, @@ -81,6 +88,11 @@ def _execute_transaction(self, payload: dict) -> None: self.config.session_timeout_minutes, ) + logger.debug( + f"Writing to conversation_id={self.config.cache.conversation_id}, " + f"session_id={self.config.cache.session_id}" + ) + llm = LlmRegistry().adapter( payload["conversation"]["client"]["provider"], payload["conversation"]["client"]["title"], @@ -93,6 +105,11 @@ def _execute_transaction(self, payload: dict) -> None: content = message["content"] if isinstance(content, dict | list): content = json.dumps(content) + conv_id = self.config.cache.conversation_id + logger.debug( + f"Writing {message['role']} message to " + f"conversation_id={conv_id}" + ) self.config.storage.driver.conversation.message.create( self.config.cache.conversation_id, message["role"], @@ -103,6 +120,11 @@ def _execute_transaction(self, payload: dict) -> None: responses = llm.get_formatted_response(payload) if responses: for response in responses: + conv_id = self.config.cache.conversation_id + logger.debug( + f"Writing {response['role']} response to " + f"conversation_id={conv_id}" + ) self.config.storage.driver.conversation.message.create( self.config.cache.conversation_id, response["role"], diff --git a/tests/memory/test_memory_writer.py b/tests/memory/test_memory_writer.py index 663558ed..89b98552 100644 --- a/tests/memory/test_memory_writer.py +++ b/tests/memory/test_memory_writer.py @@ -123,3 +123,96 @@ def test_execute_skips_system_messages(config, mocker): assert calls[0][0][3] == "Hello" assert calls[1][0][1] == "assistant" assert calls[1][0][3] == "Hi there!" 
def test_execute_multiple_turns_ingests_all_messages(config, mocker):
    """Check that two consecutive turns are both written to one conversation.

    Turn 1 starts with no cached ``conversation_id``; the Writer is expected
    to resolve it through ``conversation.create`` and write the user message
    plus the assistant response. Turn 2 re-sends the first exchange flagged
    as injected (``_memori_injected_count``), so only the new user message
    and the new assistant response should be written.

    ``mocker`` is requested as a fixture but not used directly in the body.
    """
    # conversation.create always hands back the same id, so both turns must
    # resolve to this conversation.
    conversation_id = 123
    config.storage.driver.conversation.create.return_value = conversation_id
    config.cache.conversation_id = None  # Start with no conversation_id

    message_create = config.storage.driver.conversation.message.create
    # NOTE(review): the Writer code path exercised here appears to go through
    # storage.driver only; this adapter stub is kept from the original test
    # and may be unnecessary -- confirm before removing.
    fetchall = (
        config.storage.adapter.execute.return_value.mappings.return_value.fetchall
    )

    # ---- First turn: user message + assistant response ----
    fetchall.return_value = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]

    payload_turn1 = {
        "conversation": {
            "client": {"provider": None, "title": OPENAI_LLM_PROVIDER},
            "query": {"messages": [{"content": "Hello", "role": "user"}]},
            "response": {
                "choices": [
                    {"message": {"content": "Hi there!", "role": "assistant"}}
                ]
            },
        }
    }

    Writer(config).execute(payload_turn1)

    # Verify first turn was written
    assert config.cache.conversation_id == conversation_id
    assert message_create.call_count == 2

    calls_turn1 = message_create.call_args_list
    # Positional args: [0] conversation id, [1] role, [3] content.
    assert [(call[0][1], call[0][3]) for call in calls_turn1] == [
        ("user", "Hello"),
        ("assistant", "Hi there!"),
    ]
    # Every write must target the resolved conversation.
    assert all(call[0][0] == conversation_id for call in calls_turn1)

    # Reset the mock so the second turn's calls are counted on their own.
    message_create.reset_mock()

    # ---- Second turn: new user message + assistant response ----
    # Previous messages are present in the payload as injected context, but
    # only the new messages should be written.
    fetchall.return_value = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "What's the weather?"},
        {"role": "assistant", "content": "I don't have access to weather data."},
    ]

    payload_turn2 = {
        "conversation": {
            "client": {"provider": None, "title": OPENAI_LLM_PROVIDER},
            "query": {
                "messages": [
                    {"content": "Hello", "role": "user"},
                    {"content": "Hi there!", "role": "assistant"},
                    {"content": "What's the weather?", "role": "user"},
                ],
                "_memori_injected_count": 2,  # First 2 messages were injected
            },
            "response": {
                "choices": [
                    {
                        "message": {
                            "content": "I don't have access to weather data.",
                            "role": "assistant",
                        }
                    }
                ]
            },
        }
    }

    Writer(config).execute(payload_turn2)

    # Verify second turn was written (only new messages, not injected ones)
    assert config.cache.conversation_id == conversation_id
    assert message_create.call_count == 2

    calls_turn2 = message_create.call_args_list
    assert [(call[0][1], call[0][3]) for call in calls_turn2] == [
        ("user", "What's the weather?"),
        ("assistant", "I don't have access to weather data."),
    ]
    # The second turn must land in the same conversation as the first.
    assert all(call[0][0] == conversation_id for call in calls_turn2)