diff --git a/docs/features/llm.md b/docs/features/llm.md
index 31648144..d38ce39d 100644
--- a/docs/features/llm.md
+++ b/docs/features/llm.md
@@ -95,7 +95,7 @@
 SessionLocal = sessionmaker(bind=engine)
 
 client = ChatOpenAI(model="gpt-4o-mini")
-mem = Memori(conn=SessionLocal).llm.register(client)
+mem = Memori(conn=SessionLocal).llm.register(chatopenai=client)
 mem.attribution(entity_id="user_123", process_id="langchain_agent")
 
 response = client.invoke("Hello")
diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md
index 41af8646..6ddf150c 100644
--- a/docs/troubleshooting.md
+++ b/docs/troubleshooting.md
@@ -485,7 +485,7 @@
 def get_sqlite_connection():
     return sqlite3.connect("memori.db")
 
 chat = ChatOpenAI()
-mem = Memori(conn=get_sqlite_connection).llm.register(chat)
+mem = Memori(conn=get_sqlite_connection).llm.register(chatopenai=chat)
 mem.attribution(entity_id="user-123", process_id="my-app")
 ```
diff --git a/memori/llm/_registry.py b/memori/llm/_registry.py
index 95fcdab1..5d8eab30 100644
--- a/memori/llm/_registry.py
+++ b/memori/llm/_registry.py
@@ -39,6 +39,15 @@ def client(self, client_obj: Any, config) -> BaseClient:
             if matcher(client_obj):
                 return client_class(config)
 
+        module = type(client_obj).__module__
+        if module.startswith("langchain"):
+            class_name = type(client_obj).__name__
+            param_hint = class_name.lower()
+            raise RuntimeError(
+                f"LangChain models require named parameters. "
+                f"Use: llm.register({param_hint}=client) instead of llm.register(client)"
+            )
+
         raise RuntimeError(
             f"Unsupported LLM client type: {type(client_obj).__module__}.{type(client_obj).__name__}"
         )
diff --git a/tests/llm/test_llm_registry.py b/tests/llm/test_llm_registry.py
index fd0bd8a3..bdfc6f11 100644
--- a/tests/llm/test_llm_registry.py
+++ b/tests/llm/test_llm_registry.py
@@ -49,3 +49,22 @@
 def test_llm_adapter_raises_for_unsupported_provider():
     with pytest.raises(RuntimeError, match="Unsupported LLM provider"):
         Registry().adapter("mistral", "mistral")
+
+
+def test_llm_client_raises_helpful_error_for_langchain():
+    """Test that LangChain clients produce a helpful error message."""
+
+    class MockLangChainClient:
+        pass
+
+    MockLangChainClient.__module__ = "langchain_openai.chat_models.base"
+    MockLangChainClient.__name__ = "ChatOpenAI"
+
+    mock_client = MockLangChainClient()
+    mock_config = None
+
+    with pytest.raises(
+        RuntimeError,
+        match=r"LangChain models require named parameters.*llm\.register\(chatopenai=client\)",
+    ):
+        Registry().client(mock_client, mock_config)