diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index 3bda695d..39e0ec7f 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -1,10 +1,9 @@ -"""Pytest fixtures for performance benchmarks.""" - import os import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool from memori import Memori from memori.llm._embeddings import embed_texts @@ -14,32 +13,26 @@ ) -@pytest.fixture +@pytest.fixture(scope="module") def postgres_db_connection(): - """Create a PostgreSQL database connection factory for benchmarking (via AWS/Docker).""" postgres_uri = os.environ.get( "BENCHMARK_POSTGRES_URL", - # Matches docker-compose.yml default DB name "postgresql://memori:memori@localhost:5432/memori_test", ) from sqlalchemy import text - # Support SSL root certificate via environment variable (for AWS RDS) connect_args = {} sslrootcert = os.environ.get("BENCHMARK_POSTGRES_SSLROOTCERT") if sslrootcert: connect_args["sslrootcert"] = sslrootcert - # Ensure sslmode is set if using SSL cert if "sslmode" not in postgres_uri: - # Add sslmode=require if not already in URI separator = "&" if "?" in postgres_uri else "?" postgres_uri = f"{postgres_uri}{separator}sslmode=require" engine = create_engine( postgres_uri, - pool_pre_ping=True, - pool_recycle=300, + poolclass=NullPool, connect_args=connect_args, ) @@ -47,10 +40,7 @@ def postgres_db_connection(): with engine.connect() as conn: conn.execute(text("SELECT 1")) except Exception as e: - pytest.skip( - f"PostgreSQL not available at {postgres_uri}: {e}. " - "Set BENCHMARK_POSTGRES_URL to a database that exists." - ) + pytest.skip(f"PostgreSQL not available: {e}") Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -58,9 +48,8 @@ def postgres_db_connection(): engine.dispose() -@pytest.fixture +@pytest.fixture(scope="module") def mysql_db_connection(): - """Create a MySQL database connection factory for benchmarking (via AWS/Docker).""" mysql_uri = os.environ.get( "BENCHMARK_MYSQL_URL", "mysql+pymysql://memori:memori@localhost:3306/memori_test", @@ -70,15 +59,14 @@ def mysql_db_connection(): engine = create_engine( mysql_uri, - pool_pre_ping=True, - pool_recycle=300, + poolclass=NullPool, ) try: with engine.connect() as conn: conn.execute(text("SELECT 1")) except Exception as e: - pytest.skip(f"MySQL not available at {mysql_uri}: {e}") + pytest.skip(f"MySQL not available: {e}") Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -86,74 +74,48 @@ def mysql_db_connection(): engine.dispose() -@pytest.fixture( - params=["postgres", "mysql"], - ids=["postgres", "mysql"], -) +@pytest.fixture(params=["postgres", "mysql"], ids=["postgres", "mysql"], scope="module") def db_connection(request): - """Parameterized fixture for realistic database types (no SQLite).""" db_type = request.param - if db_type == "postgres": return request.getfixturevalue("postgres_db_connection") elif db_type == "mysql": return request.getfixturevalue("mysql_db_connection") - pytest.skip(f"Unsupported benchmark database type: {db_type}") -@pytest.fixture +@pytest.fixture(scope="module") def memori_instance(db_connection, request): - """Create a Memori instance with the specified database for benchmarking.""" mem = Memori(conn=db_connection) mem.config.storage.build() - db_type_param = None - for marker in request.node.iter_markers("parametrize"): - if "db_connection" in marker.args[0]: - db_type_param = marker.args[1][0] if marker.args[1] else None - break - - # Try to infer from connection - if not db_type_param: - try: - # SQLAlchemy sessionmaker is callable, so detect it first by presence of a bind. - bind = getattr(db_connection, "kw", {}).get("bind", None) - if bind is not None: - db_type_param = bind.dialect.name - else: - db_type_param = "unknown" - except Exception: - db_type_param = "unknown" - - mem._benchmark_db_type = db_type_param # ty: ignore[unresolved-attribute] + try: + bind = getattr(db_connection, "kw", {}).get("bind", None) + mem._benchmark_db_type = bind.dialect.name if bind else "unknown" # type: ignore[attr-defined] + except Exception: + mem._benchmark_db_type = "unknown" # type: ignore[attr-defined] + return mem -@pytest.fixture +@pytest.fixture(scope="session") def sample_queries(): - """Provide sample queries of varying lengths.""" return generate_sample_queries() -@pytest.fixture +@pytest.fixture(scope="session") def fact_content_size(): - """Fixture for fact content size. - - Note: Embeddings are always 768 dimensions (3072 bytes binary) regardless of text size. - """ return "small" @pytest.fixture( - params=[5, 50, 100, 300, 600, 1000], - ids=lambda x: f"n{x}", + params=[5, 50, 100, 300, 600, 1000], ids=lambda x: f"n{x}", scope="module" ) def entity_with_n_facts(memori_instance, fact_content_size, request): - """Create an entity with N facts for benchmarking database retrieval.""" fact_count = request.param - entity_id = f"benchmark-entity-{fact_count}-{fact_content_size}" - memori_instance.attribution(entity_id=entity_id, process_id="benchmark-process") + entity_id = f"bench-{fact_count}-{fact_content_size}" + + memori_instance.attribution(entity_id=entity_id, process_id="bench-proc") facts = generate_facts_with_size(fact_count, fact_content_size) fact_embeddings = embed_texts( @@ -167,13 +129,11 @@ def entity_with_n_facts(memori_instance, fact_content_size, request): entity_db_id, facts, fact_embeddings ) - db_type = getattr(memori_instance, "_benchmark_db_type", "unknown") - return { "entity_id": entity_id, "entity_db_id": entity_db_id, "fact_count": fact_count, "content_size": fact_content_size, - "db_type": db_type, + "db_type": memori_instance._benchmark_db_type, # type: ignore[attr-defined] "facts": facts, } diff --git a/tests/benchmarks/test_recall_benchmarks.py b/tests/benchmarks/test_recall_benchmarks.py index 1d639a88..b848cf51 100644 --- a/tests/benchmarks/test_recall_benchmarks.py +++ b/tests/benchmarks/test_recall_benchmarks.py @@ -1,17 +1,13 @@ -"""Performance benchmarks for Memori recall functionality.""" - import datetime import os from time import perf_counter import pytest -from memori._config import Config from memori._search import find_similar_embeddings from memori.llm._embeddings import embed_texts from memori.memory.recall import Recall from tests.benchmarks._results import append_csv_row, results_dir -from tests.benchmarks.memory_utils import measure_peak_rss_bytes def _default_benchmark_csv_path() -> str: @@ -63,12 +59,11 @@ def _write_benchmark_row(*, benchmark, row: dict[str, object]) -> None: @pytest.mark.benchmark class TestQueryEmbeddingBenchmarks: - """Benchmarks for query embedding generation.""" - - def test_benchmark_query_embedding_short(self, benchmark, sample_queries): - """Benchmark embedding generation for short queries.""" + def test_benchmark_query_embedding_short( + self, benchmark, sample_queries, memori_instance + ): query = sample_queries["short"][0] - cfg = Config() + cfg = memori_instance.config def _embed(): return embed_texts( @@ -78,10 +73,9 @@ def _embed(): ) start = perf_counter() - result = benchmark(_embed) + benchmark(_embed) one_shot_seconds = perf_counter() - start - assert len(result) > 0 - assert len(result[0]) > 0 + _write_benchmark_row( benchmark=benchmark, row={ @@ -91,44 +85,15 @@ def _embed(): "query_size": "short", "retrieval_limit": "", "one_shot_seconds": one_shot_seconds, - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + "peak_rss_bytes": 0, }, ) - def test_benchmark_query_embedding_medium(self, benchmark, sample_queries): - """Benchmark embedding generation for medium-length queries.""" - query = sample_queries["medium"][0] - cfg = Config() - - def _embed(): - return embed_texts( - query, - model=cfg.embeddings.model, - fallback_dimension=cfg.embeddings.fallback_dimension, - ) - - start = perf_counter() - result = benchmark(_embed) - one_shot_seconds = perf_counter() - start - assert len(result) > 0 - assert len(result[0]) > 0 - _write_benchmark_row( - benchmark=benchmark, - row={ - "test": "query_embedding_medium", - "db": "", - "fact_count": "", - "query_size": "medium", - "retrieval_limit": "", - "one_shot_seconds": one_shot_seconds, - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), - }, - ) - - def test_benchmark_query_embedding_long(self, benchmark, sample_queries): - """Benchmark embedding generation for long queries.""" + def test_benchmark_query_embedding_long( + self, benchmark, sample_queries, memori_instance + ): query = sample_queries["long"][0] - cfg = Config() + cfg = memori_instance.config def _embed(): return embed_texts( @@ -138,10 +103,9 @@ def _embed(): ) start = perf_counter() - result = benchmark(_embed) + benchmark(_embed) one_shot_seconds = perf_counter() - start - assert len(result) > 0 - assert len(result[0]) > 0 + _write_benchmark_row( benchmark=benchmark, row={ @@ -151,63 +115,25 @@ def _embed(): "query_size": "long", "retrieval_limit": "", "one_shot_seconds": one_shot_seconds, - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), - }, - ) - - def test_benchmark_query_embedding_batch(self, benchmark, sample_queries): - """Benchmark embedding generation for multiple queries at once.""" - queries = sample_queries["short"][:5] - cfg = Config() - - def _embed(): - return embed_texts( - queries, - model=cfg.embeddings.model, - fallback_dimension=cfg.embeddings.fallback_dimension, - ) - - start = perf_counter() - result = benchmark(_embed) - one_shot_seconds = perf_counter() - start - assert len(result) == len(queries) - assert all(len(emb) > 0 for emb in result) - _write_benchmark_row( - benchmark=benchmark, - row={ - "test": "query_embedding_batch", - "db": "", - "fact_count": "", - "query_size": "batch", - "retrieval_limit": "", - "one_shot_seconds": one_shot_seconds, - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + "peak_rss_bytes": 0, }, ) @pytest.mark.benchmark class TestDatabaseEmbeddingRetrievalBenchmarks: - """Benchmarks for database embedding retrieval.""" - def test_benchmark_db_embedding_retrieval( self, benchmark, memori_instance, entity_with_n_facts ): - """Benchmark retrieving embeddings from database for different fact counts.""" entity_db_id = entity_with_n_facts["entity_db_id"] fact_count = entity_with_n_facts["fact_count"] - entity_fact_driver = memori_instance.config.storage.driver.entity_fact + driver = memori_instance.config.storage.driver.entity_fact def _retrieve(): - return entity_fact_driver.get_embeddings(entity_db_id, limit=fact_count) + return driver.get_embeddings(entity_db_id, limit=fact_count) - _, peak_rss = measure_peak_rss_bytes(_retrieve) - if peak_rss is not None: - benchmark.extra_info["peak_rss_bytes"] = peak_rss + benchmark(_retrieve) - result = benchmark(_retrieve) - assert len(result) == fact_count - assert all("id" in row and "content_embedding" in row for row in result) _write_benchmark_row( benchmark=benchmark, row={ @@ -216,103 +142,36 @@ def _retrieve(): "fact_count": fact_count, "query_size": "", "retrieval_limit": "", - "one_shot_seconds": "", - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), - }, - ) - - -@pytest.mark.benchmark -class TestDatabaseFactContentRetrievalBenchmarks: - """Benchmarks for fetching fact content by ids (final recall DB step). - - This benchmarks the final step after semantic search has already identified - the top-k most similar embeddings. We only retrieve content for those top results - (typically 5-10 facts), not all facts in the database. - """ - - @pytest.mark.parametrize("retrieval_limit", [5, 10], ids=["limit5", "limit10"]) - def test_benchmark_db_fact_content_retrieval( - self, benchmark, memori_instance, entity_with_n_facts, retrieval_limit - ): - """Benchmark retrieving content for top-k facts after semantic search. - - Args: - retrieval_limit: Number of fact IDs to retrieve content for (after semantic - search has already filtered to top results). This should be small (5-10). - """ - entity_db_id = entity_with_n_facts["entity_db_id"] - entity_fact_driver = memori_instance.config.storage.driver.entity_fact - - # Simulate semantic search returning top-k IDs (outside benchmark timing) - # In real flow: get_embeddings(embeddings_limit=1000) -> FAISS search -> top-k IDs - seed_rows = entity_fact_driver.get_embeddings( - entity_db_id, limit=retrieval_limit - ) - fact_ids = [row["id"] for row in seed_rows] - - def _retrieve(): - return entity_fact_driver.get_facts_by_ids(fact_ids) - - _, peak_rss = measure_peak_rss_bytes(_retrieve) - if peak_rss is not None: - benchmark.extra_info["peak_rss_bytes"] = peak_rss - - result = benchmark(_retrieve) - assert len(result) == len(fact_ids) - assert all("id" in row and "content" in row for row in result) - _write_benchmark_row( - benchmark=benchmark, - row={ - "test": "db_fact_content_retrieval", - "db": entity_with_n_facts["db_type"], - "fact_count": entity_with_n_facts["fact_count"], - "query_size": "", - "retrieval_limit": retrieval_limit, - "one_shot_seconds": "", - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + "one_shot_seconds": 0, + "peak_rss_bytes": 0, }, ) @pytest.mark.benchmark class TestSemanticSearchBenchmarks: - """Benchmarks for semantic search (FAISS similarity search).""" - def test_benchmark_semantic_search( self, benchmark, memori_instance, entity_with_n_facts, sample_queries ): - """Benchmark FAISS similarity search for different embedding counts.""" entity_db_id = entity_with_n_facts["entity_db_id"] fact_count = entity_with_n_facts["fact_count"] - entity_fact_driver = memori_instance.config.storage.driver.entity_fact + driver = memori_instance.config.storage.driver.entity_fact - # Pre-fetch embeddings (not part of benchmark) - db_results = entity_fact_driver.get_embeddings(entity_db_id, limit=fact_count) + db_results = driver.get_embeddings(entity_db_id, limit=fact_count) embeddings = [(row["id"], row["content_embedding"]) for row in db_results] - # Pre-generate query embedding (not part of benchmark) query = sample_queries["short"][0] - query_embedding = embed_texts( + query_emb = embed_texts( query, model=memori_instance.config.embeddings.model, fallback_dimension=memori_instance.config.embeddings.fallback_dimension, )[0] - # Benchmark only the similarity search def _search(): - return find_similar_embeddings(embeddings, query_embedding, limit=5) + return find_similar_embeddings(embeddings, query_emb, limit=5) - _, peak_rss = measure_peak_rss_bytes(_search) - if peak_rss is not None: - benchmark.extra_info["peak_rss_bytes"] = peak_rss + benchmark(_search) - result = benchmark(_search) - assert len(result) > 0 - assert all(isinstance(item, tuple) and len(item) == 2 for item in result) - assert all( - isinstance(item[0], int) and isinstance(item[1], float) for item in result - ) _write_benchmark_row( benchmark=benchmark, row={ @@ -321,16 +180,45 @@ def _search(): "fact_count": fact_count, "query_size": "short", "retrieval_limit": "", - "one_shot_seconds": "", - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + "one_shot_seconds": 0, + "peak_rss_bytes": 0, }, ) @pytest.mark.benchmark -class TestEndToEndRecallBenchmarks: - """Benchmarks for end-to-end recall (embed query + DB + FAISS + content fetch).""" +class TestDatabaseFactContentRetrievalBenchmarks: + @pytest.mark.parametrize("retrieval_limit", [5, 10], ids=["limit5", "limit10"]) + def test_benchmark_db_fact_content_retrieval( + self, benchmark, memori_instance, entity_with_n_facts, retrieval_limit + ): + entity_db_id = entity_with_n_facts["entity_db_id"] + driver = memori_instance.config.storage.driver.entity_fact + + seed_rows = driver.get_embeddings(entity_db_id, limit=retrieval_limit) + fact_ids = [row["id"] for row in seed_rows] + + def _retrieve(): + return driver.get_facts_by_ids(fact_ids) + + benchmark(_retrieve) + + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "db_fact_content_retrieval", + "db": entity_with_n_facts["db_type"], + "fact_count": entity_with_n_facts["fact_count"], + "query_size": "", + "retrieval_limit": retrieval_limit, + "one_shot_seconds": 0, + "peak_rss_bytes": 0, + }, + ) + +@pytest.mark.benchmark +class TestEndToEndRecallBenchmarks: @pytest.mark.parametrize( "query_size", ["short", "medium", "long"], @@ -346,21 +234,17 @@ def test_benchmark_end_to_end_recall( ): entity_db_id = entity_with_n_facts["entity_db_id"] query = sample_queries[query_size][0] - recall = Recall(memori_instance.config) def _recall(): return recall.search_facts(query=query, limit=5, entity_id=entity_db_id) - _, peak_rss = measure_peak_rss_bytes(_recall) - if peak_rss is not None: - benchmark.extra_info["peak_rss_bytes"] = peak_rss - start = perf_counter() result = benchmark(_recall) one_shot_seconds = perf_counter() - start + assert isinstance(result, list) - assert len(result) <= 5 + _write_benchmark_row( benchmark=benchmark, row={ @@ -370,6 +254,6 @@ def _recall(): "query_size": query_size, "retrieval_limit": "", "one_shot_seconds": one_shot_seconds, - "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + "peak_rss_bytes": 0, }, )