
Commit 8d26e24

Generalizing tests for smarter LLMs (#1149)

1 parent a9c177d commit 8d26e24

File tree

6 files changed: +1803 −1752 lines changed

tests/cassettes/test_partitioning_fn_docs[False].yaml

Lines changed: 635 additions & 617 deletions
Some generated files are not rendered by default.

tests/cassettes/test_partitioning_fn_docs[True].yaml

Lines changed: 1142 additions & 1120 deletions
Some generated files are not rendered by default.

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def agent_stub_session() -> PQASession:
     # > are already imported: paperqa
     from paperqa.types import PQASession

-    return PQASession(question="What is is a self-explanatory model?")
+    return PQASession(question="What is a self-explanatory model?")


 @pytest.fixture

tests/test_agents.py

Lines changed: 11 additions & 5 deletions
@@ -516,17 +516,23 @@ async def test_propagate_options(agent_test_settings: Settings) -> None:
     agent_test_settings.answer.evidence_skip_summary = True

     response = await agent_query(
-        query="What is is a self-explanatory model?",
+        query="What is a self-explanatory model?",
         settings=agent_test_settings,
         agent_type=FAKE_AGENT_TYPE,
     )
     assert response.status == AgentStatus.SUCCESS, "Agent did not succeed"
     result = response.session
     assert len(result.answer) > 200, "Answer did not return any results"
     assert "###" in result.answer, "Answer did not propagate system prompt"
+    assert len(result.contexts) >= 2, "Test expects a few contexts"
     # Subtract 2 to allow tolerance for chunks with leading/trailing whitespace
+    num_contexts_sufficient_length = sum(
+        len(c.context) >= agent_test_settings.parsing.chunk_size - 2
+        for c in result.contexts
+    )
+    # Check most contexts have the expected length
     assert (
-        len(result.contexts[0].context) >= agent_test_settings.parsing.chunk_size - 2
+        num_contexts_sufficient_length >= len(result.contexts) - 1
     ), "Summary was not skipped"


@@ -622,7 +628,7 @@ def files_filter(f) -> bool:

     agent_test_settings.agent.callbacks = callbacks

-    session = PQASession(question="What is is a self-explanatory model?")
+    session = PQASession(question="What is a self-explanatory model?")
     env_state = EnvironmentState(docs=Docs(), session=session)
     built_index = await get_directory_index(settings=agent_test_settings)
     assert await built_index.count, "Index build did not work"

@@ -730,11 +736,11 @@ def new_status(state: EnvironmentState) -> str:
     for context in session.contexts:
         if context.question != new_question:
             assert (
-                context.context[:20] not in response
+                context.context[:30] not in response
             ), "gather_evidence should not return any contexts for the old question"
     assert (
         sum(
-            (1 if (context.context[:20] in response) else 0)
+            (1 if (context.context[:30] in response) else 0)
             for context in session.contexts
             if context.question == new_question
         )
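Aside on the first hunk: the reworked length check now tolerates a single short chunk instead of requiring contexts[0] alone to be full-length. A toy recap of what the assertion accepts, with made-up numbers that are not from the test suite:

# Hypothetical context lengths; 37 models a document's short trailing chunk
chunk_size = 100
context_lengths = [100, 99, 37]

num_sufficient = sum(length >= chunk_size - 2 for length in context_lengths)
# Passes: all but at most one context meets the length bound
assert num_sufficient >= len(context_lengths) - 1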

tests/test_configs.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def test_settings_default_instantiation(tmpdir, subtests: SubTests) -> None:
     # Also let's check our default settings work fine with round-trip JSON serialization
     serde_default_settings = Settings(**default_settings.model_dump(mode="json"))
     for setting in (default_settings, serde_default_settings):
-        assert "gpt-" in setting.llm
+        assert any(x in setting.llm for x in ("gpt-", "claude-"))
         assert setting.answer.evidence_k == 10
         assert HOME_DIR in str(setting.agent.index.index_directory)
         assert ".pqa" in str(setting.agent.index.index_directory)

tests/test_paperqa.py

Lines changed: 13 additions & 8 deletions
@@ -1635,7 +1635,8 @@ async def test_querying_tables(stub_data_dir: Path) -> None:
     assert all(
         [m.data for m in t.media] for t in used_texts
     ), "Expected image data to be present in the used contexts"
-    assert any(x in session.answer for x in ("1.0 mm", "1.0-mm"))
+    # Check for 1.0mm, 1.0-mm, 1.0 mm
+    assert re.search(r"1\.0[ -]?mm", session.answer)
     assert session.cost > 0

     # Filter contexts for HTTP requests, and ensure no images are present
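As a quick sanity check (illustrative, not part of the commit), the new pattern covers all three spellings, including the bare "1.0mm" that the old tuple-based check missed:

import re

# [ -]? allows an optional space or hyphen between number and unit
PATTERN = re.compile(r"1\.0[ -]?mm")

for candidate in ("1.0mm", "1.0 mm", "1.0-mm"):
    assert PATTERN.search(candidate), candidate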
@@ -2239,12 +2240,12 @@ def test_docdetails_doc_id_roundtrip() -> None:
 @pytest.mark.asyncio
 async def test_partitioning_fn_docs(use_partition: bool) -> None:
     settings = Settings.from_name("fast")
-    settings.answer.evidence_k = 2  # limit to only 2
+    settings.answer.evidence_k = 2  # Match positive or negative statement count below

     # imagine we have some special selection we want to
     # embedding rank by itself
     def partition_by_citation(t: Embeddable) -> int:
-        if isinstance(t, Text) and "special" in t.doc.citation:
+        if isinstance(t, Text) and "negative" in t.doc.citation:
             return 1
         return 0

@@ -2257,9 +2258,11 @@ def partition_by_citation(t: Embeddable) -> int:
     ), "We want this test to cover NumpyVectorStore"

     # add docs that we can use our partitioning function on
-    positive_statements_doc = Doc(docname="stub", citation="stub", dockey="stub")
+    positive_statements_doc = Doc(
+        docname="positive", citation="positive", dockey="positive"
+    )
     negative_statements_doc = Doc(
-        docname="special", citation="special", dockey="special"
+        docname="negative", citation="negative", dockey="negative"
     )
     texts = []
     for i, (statement, doc) in enumerate(

@@ -2275,10 +2278,11 @@ def partition_by_citation(t: Embeddable) -> int:
             await settings.get_embedding_model().embed_documents([texts[-1].text])
         )[0]
     await docs.aadd_texts(
-        texts=[t for t in texts if t.doc.docname == "stub"], doc=positive_statements_doc
+        texts=[t for t in texts if t.doc.docname == "positive"],
+        doc=positive_statements_doc,
     )
     await docs.aadd_texts(
-        texts=[t for t in texts if t.doc.docname == "special"],
+        texts=[t for t in texts if t.doc.docname == "negative"],
         doc=negative_statements_doc,
     )

@@ -2330,10 +2334,11 @@ def partition_by_citation(t: Embeddable) -> int:
     # with partitioning, we are forcing them to be interleaved, thus
     # at least one "I don't like X" statements will be in the top 2
     session = await docs.aget_evidence(
-        "What do I like?", settings=settings, partitioning_fn=partitioning_fn
+        "What do I like or dislike?", settings=settings, partitioning_fn=partitioning_fn
     )
     assert docs.texts_index.texts == docs.texts == texts

+    assert session.contexts, "Test requires contexts to be made"
     if use_partition:
         assert any(
             "don't" in c.text.text for c in session.contexts
