5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,5 +5,10 @@ All notable changes to the Memori Python SDK will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Fixed

- Fixed multi-turn conversation ingestion for AzureOpenAI and OpenAI clients. Previously, only the first conversation turn was being recorded. Now `conversation_id` is resolved early in the request lifecycle, ensuring all conversation turns are properly ingested into the same conversation. (Fixes #83)

[3.0.0]: https://github.com/MemoriLabs/Memori/releases/tag/v3.0.0
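
For orientation, here is a minimal sketch of the multi-turn flow this entry describes, using the standard OpenAI client. The assumption that the client has been instrumented by Memori, and the model name, are illustrative and not part of this changeset:

# Hedged sketch of a two-turn exchange. Before this fix, only the first
# call's turn was ingested; with conversation_id resolved early, both
# turns should land in the same stored conversation. Assumes the client
# below has been wrapped/instrumented by the Memori SDK.
from openai import OpenAI

client = OpenAI()  # assumed to be instrumented by Memori

history = [{"role": "user", "content": "Hello"}]
first = client.chat.completions.create(model="gpt-4o-mini", messages=history)
history.append({"role": "assistant", "content": first.choices[0].message.content})

history.append({"role": "user", "content": "What's the weather?"})
client.chat.completions.create(model="gpt-4o-mini", messages=history)
# Expected: both turns share one conversation_id in Memori's storage.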
72 changes: 69 additions & 3 deletions memori/llm/_base.py
@@ -11,6 +11,7 @@
import asyncio
import copy
import json
import logging
from typing import TYPE_CHECKING

from google.protobuf import json_format
@@ -374,12 +375,63 @@ def inject_recalled_facts(self, kwargs: dict) -> dict:
        return kwargs

    def inject_conversation_messages(self, kwargs: dict) -> dict:
        if self.config.cache.conversation_id is None:
            return kwargs

        """
        Inject previous conversation messages into the current request.

        This method ensures conversation_id is resolved early (before message injection)
        to support multi-turn conversations. On subsequent API calls, it retrieves the
        existing conversation_id from cache or creates it if needed, enabling proper
        conversation continuity.

        Note: This fix addresses an issue where only the first conversation turn was
        being recorded. By resolving conversation_id early, we ensure all turns
        are properly ingested into the same conversation.
        """
        if self.config.storage is None or self.config.storage.driver is None:
            return kwargs

        # Ensure session_id and conversation_id are available before injecting messages.
        # This allows us to retrieve previous messages even on subsequent calls.
        # Fixes an issue where conversation_id wasn't available early enough, causing
        # only the first turn to be recorded (see GitHub issue #83).
        if self.config.cache.conversation_id is None:
            # First ensure session_id exists
            if self.config.cache.session_id is None:
                if self.config.entity_id is not None:
                    entity_id = self.config.storage.driver.entity.create(
                        self.config.entity_id
                    )
                    if entity_id is not None:
                        self.config.cache.entity_id = entity_id
                if self.config.process_id is not None:
                    process_id = self.config.storage.driver.process.create(
                        self.config.process_id
                    )
                    if process_id is not None:
                        self.config.cache.process_id = process_id

                session_id = self.config.storage.driver.session.create(
                    self.config.session_id,
                    self.config.cache.entity_id,
                    self.config.cache.process_id,
                )
                if session_id is not None:
                    self.config.cache.session_id = session_id

            # Now try to get the existing conversation for this session.
            if self.config.cache.session_id is not None:
                # conversation.create returns the existing conversation_id if within
                # the timeout, or creates a new one, so this ensures we have a
                # conversation_id.
                existing_conv = self.config.storage.driver.conversation.create(
                    self.config.cache.session_id,
                    self.config.session_timeout_minutes,
                )
                if existing_conv is not None:
                    self.config.cache.conversation_id = existing_conv
        # If still None, we'll create it in the Writer later.
        if self.config.cache.conversation_id is None:
            return kwargs

        messages = self.config.storage.driver.conversation.messages.read(
            self.config.cache.conversation_id
        )
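
To make the resolution order concrete, here is a minimal, self-contained sketch of the entity, process, session, and conversation chain the new code walks before injecting messages. The stub driver and the resolve_conversation_id helper below are illustrative assumptions, not the real storage layer:

import itertools

_ids = itertools.count(1)

class StubTable:
    def create(self, *args):
        # Real drivers get-or-create and may return an existing id;
        # this stub just mints a fresh one to show the call order.
        return next(_ids)

class StubDriver:
    entity = StubTable()
    process = StubTable()
    session = StubTable()
    conversation = StubTable()

def resolve_conversation_id(cache: dict, driver, timeout_minutes: int = 30):
    # Resolve ids front to back so conversation_id exists before injection.
    if cache.get("session_id") is None:
        cache["entity_id"] = driver.entity.create("entity-key")
        cache["process_id"] = driver.process.create("process-key")
        cache["session_id"] = driver.session.create(
            "session-key", cache["entity_id"], cache["process_id"]
        )
    if cache.get("conversation_id") is None:
        cache["conversation_id"] = driver.conversation.create(
            cache["session_id"], timeout_minutes
        )
    return cache["conversation_id"]

cache = {}
first = resolve_conversation_id(cache, StubDriver())
second = resolve_conversation_id(cache, StubDriver())
assert first == second  # later turns reuse the cached conversation_id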
@@ -520,6 +572,8 @@ def _strip_memori_context_from_messages(self, messages: list) -> list:
    def handle_post_response(self, kwargs, start_time, raw_response):
        from memori.memory._manager import Manager as MemoryManager

        logger = logging.getLogger(__name__)

        if "model" in kwargs:
            self.config.llm.version = kwargs["model"]

@@ -533,6 +587,18 @@ def handle_post_response(self, kwargs, start_time, raw_response):
            self._format_response(self.get_response_content(raw_response)),
        )

        conv_id = self.config.cache.conversation_id
        msg_count = len(
            payload.get("conversation", {}).get("query", {}).get("messages", [])
        )
        resp_count = len(
            payload.get("conversation", {}).get("response", {}).get("choices", [])
        )
        logger.debug(
            f"Ingesting conversation turn: conversation_id={conv_id}, "
            f"messages_count={msg_count}, responses_count={resp_count}"
        )

        MemoryManager(self.config).execute(payload)

        if self.config.augmentation is not None:
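
The new debug log pulls its counts from the payload with chained .get() calls. A small standalone sketch of that counting, assuming the payload follows the conversation/query/response shape shown above:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("memori.sketch")

payload = {
    "conversation": {
        "query": {"messages": [{"role": "user", "content": "Hello"}]},
        "response": {
            "choices": [{"message": {"role": "assistant", "content": "Hi there!"}}]
        },
    }
}

# Chained .get() calls tolerate missing keys and default to empty containers.
msg_count = len(payload.get("conversation", {}).get("query", {}).get("messages", []))
resp_count = len(payload.get("conversation", {}).get("response", {}).get("choices", []))
logger.debug(
    "Ingesting conversation turn: conversation_id=%s, messages_count=%d, "
    "responses_count=%d", 123, msg_count, resp_count
)

One design note: %-style arguments defer string formatting until the record is actually emitted, while the f-strings in the diff format eagerly even when DEBUG is disabled; both work.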
22 changes: 22 additions & 0 deletions memori/memory/_writer.py
@@ -9,6 +9,7 @@
"""

import json
import logging
import time

from sqlalchemy.exc import OperationalError
@@ -19,6 +20,8 @@
MAX_RETRIES = 3
RETRY_BACKOFF_BASE = 0.1

logger = logging.getLogger(__name__)


class Writer:
    def __init__(self, config: Config):
@@ -74,13 +77,22 @@ def _execute_transaction(self, payload: dict) -> None:
            self.config.cache.process_id,
        )

        # Ensure conversation_id exists; it may have been set earlier in
        # inject_conversation_messages. If not, create or get it now.
        # conversation.create is idempotent and returns the existing conversation
        # if within the timeout, so it is safe to call multiple times.
        self._ensure_cached_id(
            "conversation_id",
            self.config.storage.driver.conversation.create,
            self.config.cache.session_id,
            self.config.session_timeout_minutes,
        )

        logger.debug(
            f"Writing to conversation_id={self.config.cache.conversation_id}, "
            f"session_id={self.config.cache.session_id}"
        )

        llm = LlmRegistry().adapter(
            payload["conversation"]["client"]["provider"],
            payload["conversation"]["client"]["title"],
@@ -93,6 +105,11 @@ def _execute_transaction(self, payload: dict) -> None:
                content = message["content"]
                if isinstance(content, dict | list):
                    content = json.dumps(content)
                conv_id = self.config.cache.conversation_id
                logger.debug(
                    f"Writing {message['role']} message to "
                    f"conversation_id={conv_id}"
                )
                self.config.storage.driver.conversation.message.create(
                    self.config.cache.conversation_id,
                    message["role"],
@@ -103,6 +120,11 @@ def _execute_transaction(self, payload: dict) -> None:
        responses = llm.get_formatted_response(payload)
        if responses:
            for response in responses:
                conv_id = self.config.cache.conversation_id
                logger.debug(
                    f"Writing {response['role']} response to "
                    f"conversation_id={conv_id}"
                )
                self.config.storage.driver.conversation.message.create(
                    self.config.cache.conversation_id,
                    response["role"],
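
The Writer leans on conversation.create being idempotent within the session timeout. A toy in-memory version of that get-or-create contract, purely illustrative and not the real driver:

import time

_conversations = {}  # session_id -> (conversation_id, last_seen)
_next_id = 1

def conversation_create(session_id, timeout_minutes):
    """Return the session's live conversation, or start a new one."""
    global _next_id
    now = time.time()
    entry = _conversations.get(session_id)
    if entry is not None and now - entry[1] < timeout_minutes * 60:
        _conversations[session_id] = (entry[0], now)  # refresh the activity window
        return entry[0]  # idempotent within the timeout
    conversation_id = _next_id
    _next_id += 1
    _conversations[session_id] = (conversation_id, now)
    return conversation_id

assert conversation_create(1, 30) == conversation_create(1, 30)  # same id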
93 changes: 93 additions & 0 deletions tests/memory/test_memory_writer.py
@@ -123,3 +123,96 @@ def test_execute_skips_system_messages(config, mocker):
    assert calls[0][0][3] == "Hello"
    assert calls[1][0][1] == "assistant"
    assert calls[1][0][3] == "Hi there!"


def test_execute_multiple_turns_ingests_all_messages(config, mocker):
    """Test that multiple conversation turns ingest all user and assistant messages."""
    # Mock conversation.create to return a consistent conversation_id
    conversation_id = 123
    config.storage.driver.conversation.create.return_value = conversation_id
    config.cache.conversation_id = None  # Start with no conversation_id

    # First turn: user message + assistant response
    mock_messages_turn1 = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]
    config.storage.adapter.execute.return_value.mappings.return_value.fetchall.return_value = (
        mock_messages_turn1
    )

    payload_turn1 = {
        "conversation": {
            "client": {"provider": None, "title": OPENAI_LLM_PROVIDER},
            "query": {"messages": [{"content": "Hello", "role": "user"}]},
            "response": {
                "choices": [{"message": {"content": "Hi there!", "role": "assistant"}}]
            },
        }
    }

    Writer(config).execute(payload_turn1)

    # Verify the first turn was written
    assert config.cache.conversation_id == conversation_id
    assert config.storage.driver.conversation.message.create.call_count == 2

    calls_turn1 = config.storage.driver.conversation.message.create.call_args_list
    assert calls_turn1[0][0][1] == "user"
    assert calls_turn1[0][0][3] == "Hello"
    assert calls_turn1[1][0][1] == "assistant"
    assert calls_turn1[1][0][3] == "Hi there!"

    # Reset mocks for the second turn
    config.storage.driver.conversation.message.create.reset_mock()

    # Second turn: new user message + assistant response.
    # The conversation has previous messages injected, but only the new
    # messages should be written.
    mock_messages_turn2 = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "What's the weather?"},
        {"role": "assistant", "content": "I don't have access to weather data."},
    ]
    config.storage.adapter.execute.return_value.mappings.return_value.fetchall.return_value = (
        mock_messages_turn2
    )

    # Simulate that previous messages were injected (so they're excluded from writing)
    payload_turn2 = {
        "conversation": {
            "client": {"provider": None, "title": OPENAI_LLM_PROVIDER},
            "query": {
                "messages": [
                    {"content": "Hello", "role": "user"},
                    {"content": "Hi there!", "role": "assistant"},
                    {"content": "What's the weather?", "role": "user"},
                ],
                "_memori_injected_count": 2,  # First 2 messages were injected
            },
            "response": {
                "choices": [
                    {
                        "message": {
                            "content": "I don't have access to weather data.",
                            "role": "assistant",
                        }
                    }
                ]
            },
        }
    }

    Writer(config).execute(payload_turn2)

    # Verify the second turn was written (only new messages, not injected ones)
    assert config.cache.conversation_id == conversation_id
    assert config.storage.driver.conversation.message.create.call_count == 2

    calls_turn2 = config.storage.driver.conversation.message.create.call_args_list
    assert calls_turn2[0][0][1] == "user"
    assert calls_turn2[0][0][3] == "What's the weather?"
    assert calls_turn2[1][0][1] == "assistant"
    assert calls_turn2[1][0][3] == "I don't have access to weather data."
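
Judging from the `_memori_injected_count` marker this test sets, the Writer appears to skip messages that were injected from history so only the new turn is written. A hedged sketch of that exclusion contract; the new_messages helper is hypothetical, not the Writer's actual code:

def new_messages(query):
    """Drop injected-history messages (and system prompts) before writing."""
    injected = query.get("_memori_injected_count", 0)
    return [
        m for m in query.get("messages", [])[injected:]
        if m.get("role") != "system"
    ]

query = {
    "messages": [
        {"content": "Hello", "role": "user"},
        {"content": "Hi there!", "role": "assistant"},
        {"content": "What's the weather?", "role": "user"},
    ],
    "_memori_injected_count": 2,
}
assert new_messages(query) == [{"content": "What's the weather?", "role": "user"}]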