Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
test: add integration test for retrieval (#4)
Browse files Browse the repository at this point in the history
Signed-off-by: Panos Vagenas <[email protected]>
  • Loading branch information
vagenas authored Aug 28, 2024
1 parent 53f32b9 commit c0f6666
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/actions/setup-poetry/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ runs:
python-version: ${{ inputs.python-version }}
cache: 'poetry'
- name: Install dependencies
run: poetry install
run: poetry install --all-extras
shell: bash
4 changes: 3 additions & 1 deletion .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@ jobs:
- uses: ./.github/actions/setup-poetry
with:
python-version: ${{ matrix.python-version }}
- name: Run styling check
- name: Run checks
run: poetry run pre-commit run --all-files
- name: Run integration tests
run: poetry run pytest tests/integration
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,3 @@ jobs:
# - uses: ./.github/actions/setup-poetry
# - name: Build docs
# run: poetry run mkdocs build --verbose --clean

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:
files: '\.py$'
- id: pytest
name: Pytest
entry: poetry run pytest tests
entry: poetry run pytest tests/unit
pass_filenames: false
language: system
files: '\.py$'
Expand Down
11 changes: 9 additions & 2 deletions quackling/llama_index/node_parsers/hier_node_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ class HierarchicalNodeParser(NodeParser):
include_metadata: bool = Field(
default=False, description="Whether or not to consider metadata when splitting."
)
id_gen_seed: int | None = Field(
default=None,
description="ID generation seed; should typically be left to default `None`, which seeds on current timestamp; only set if you want the instance to generate a reproducible ID sequence e.g. for testing", # noqa: 501
)

def _parse_nodes(
self,
nodes: Sequence[BaseNode],
show_progress: bool = False,
id_gen_seed: int | None = None,
**kwargs: Any,
) -> list[BaseNode]:
# based on llama_index.core.node_parser.interface.TextSplitter
Expand All @@ -47,7 +50,11 @@ def _parse_nodes(
excl_meta_embed = NodeMetadata.ExcludedKeys.EMBED
excl_meta_llm = NodeMetadata.ExcludedKeys.LLM

seed = id_gen_seed if id_gen_seed is not None else datetime.now().timestamp()
seed = (
self.id_gen_seed
if self.id_gen_seed is not None
else datetime.now().timestamp()
)
rd = Random()
rd.seed(seed)

Expand Down
7 changes: 7 additions & 0 deletions tests/integration/data/0_out_retrieval_results.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"root": [
"4 ANNOTATION CAMPAIGN\nThe complete annotation guideline is over 100 pages long and a detailed description is obviously out of scope for this paper. Nevertheless, it will be made publicly available alongside with DocLayNet for future reference.",
"1 INTRODUCTION\n(4) Redundant Annotations : A fraction of the pages in the DocLayNet data set carry more than one human annotation.",
"4 ANNOTATION CAMPAIGN\nPhase 4: Production annotation. The previously selected 80K pages were annotated with the defined 11 class labels by 32 annotators. This production phase took around three months to complete. All annotations were created online through CCS, which visualises the programmatic PDF text-cells as an overlay on the page. The page annotation are obtained by drawing rectangular bounding-boxes, as shown in Figure 3. With regard to the annotation practices, we implemented a few constraints and capabilities on the tooling level. First, we only allow non-overlapping, vertically oriented, rectangular boxes. For the large majority of documents, this constraint was sufficient and it speeds up the annotation considerably in comparison with arbitrary segmentation shapes. Second, annotator staff were not able to see each other's annotations. This was enforced by design to avoid any bias in the annotation, which could skew the numbers of the inter-annotator agreement (see Table 1). We wanted"
]
}
49 changes: 49 additions & 0 deletions tests/integration/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
from tempfile import TemporaryDirectory

from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.milvus import MilvusVectorStore

from quackling.llama_index.node_parsers.hier_node_parser import HierarchicalNodeParser
from quackling.llama_index.readers.docling_reader import DoclingReader


def test_retrieval():
FILE_PATH = "https://arxiv.org/pdf/2206.01062" # DocLayNet paper
QUERY = "How many pages were human annotated?"
TOP_K = 3
HF_EMBED_MODEL_ID = "BAAI/bge-small-en-v1.5"
ID_GEN_SEED = 42
MILVUS_DB_FNAME = "milvus_demo.db"
MILVUS_COLL_NAME = "quackling_test_coll"

reader = DoclingReader(parse_type=DoclingReader.ParseType.JSON)
node_parser = HierarchicalNodeParser(id_gen_seed=ID_GEN_SEED)
embed_model = HuggingFaceEmbedding(model_name=HF_EMBED_MODEL_ID)

with TemporaryDirectory() as tmp_dir:
vector_store = MilvusVectorStore(
uri=f"{tmp_dir}/{MILVUS_DB_FNAME}",
collection_name=MILVUS_COLL_NAME,
dim=len(embed_model.get_text_embedding("hi")),
overwrite=True,
)
docs = reader.load_data(file_path=[FILE_PATH])
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents=docs,
storage_context=storage_context,
transformations=[node_parser],
embed_model=embed_model,
)
retriever = index.as_retriever(
similarity_top_k=TOP_K,
vector_store_query_mode=VectorStoreQueryMode.DEFAULT,
)
retr_res = retriever.retrieve(QUERY)
act_data = dict(root=[n.text for n in retr_res])
with open("tests/integration/data/0_out_retrieval_results.json") as f:
exp_data = json.load(fp=f)
assert exp_data == act_data
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@


def test_chunk_without_metadata():
with open("tests/data/0_inp_dl_doc.json") as f:
with open("tests/unit/data/0_inp_dl_doc.json") as f:
data_json = f.read()
dl_doc = DLDocument.model_validate_json(data_json)
chunker = HierarchicalChunker(include_metadata=False)
chunks = chunker.chunk(dl_doc=dl_doc)
act_data = dict(root=[n.model_dump() for n in chunks])
with open("tests/data/0_out_chunks_wout_meta.json") as f:
with open("tests/unit/data/0_out_chunks_wout_meta.json") as f:
exp_data = json.load(fp=f)
assert exp_data == act_data


def test_chunk_with_metadata():
with open("tests/data/0_inp_dl_doc.json") as f:
with open("tests/unit/data/0_inp_dl_doc.json") as f:
data_json = f.read()
dl_doc = DLDocument.model_validate_json(data_json)
chunker = HierarchicalChunker(include_metadata=True)
chunks = chunker.chunk(dl_doc=dl_doc)
act_data = dict(root=[n.model_dump() for n in chunks])
with open("tests/data/0_out_chunks_with_meta.json") as f:
with open("tests/unit/data/0_out_chunks_with_meta.json") as f:
exp_data = json.load(fp=f)
assert exp_data == act_data
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@


def test_node_parse():
with open("tests/data/1_inp_li_doc.json") as f:
with open("tests/unit/data/1_inp_li_doc.json") as f:
data_json = f.read()
li_doc = LIDocument.from_json(data_json)
node_parser = HierarchicalNodeParser()
nodes = node_parser._parse_nodes(nodes=[li_doc], id_gen_seed=42)
node_parser = HierarchicalNodeParser(id_gen_seed=42)
nodes = node_parser._parse_nodes(nodes=[li_doc])
act_data = dict(root=[n.dict() for n in nodes])
with open("tests/data/1_out_nodes.json") as f:
with open("tests/unit/data/1_out_nodes.json") as f:
exp_data = json.load(fp=f)
assert exp_data == act_data

0 comments on commit c0f6666

Please sign in to comment.