diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..658ce2a --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,71 @@ +name: CI + +on: + pull_request: + branches: + - main + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.10" + + - name: Install project dependencies + run: pip install -r requirements-ci.txt + + - name: Run Ruff lint + run: ruff check . + + format: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.10" + + - name: Install project dependencies + run: pip install -r requirements-ci.txt + + - name: Run Ruff format + run: ruff format . + + - name: Check for uncommitted changes (formatting) + run: git diff --exit-code + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python_version: + - "3.9" + - "3.10" + - "3.11" + - "3.12" + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python_version }} + + - name: Install project dependencies + run: pip install -r requirements-ci.txt + + - name: Run tests + run: pytest + diff --git a/README.md b/README.md index 33e48b6..c8b9cd3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # rerankers -![Python Versions](https://img.shields.io/badge/Python-3.8_3.9_3.10_3.11-blue) +![Python Versions](https://img.shields.io/badge/Python-3.9_3.10_3.11_3.12-blue) [![Downloads](https://static.pepy.tech/badge/rerankers/month)](https://pepy.tech/project/rerankers) [![Twitter Follow](https://img.shields.io/twitter/follow/bclavie?style=social)](https://twitter.com/bclavie) diff --git a/examples/langchain_integration.ipynb b/examples/langchain_integration.ipynb index 248f794..77d186c 100644 --- a/examples/langchain_integration.ipynb +++ b/examples/langchain_integration.ipynb @@ -262,7 +262,6 @@ ], "source": [ "from langchain.retrievers import ContextualCompressionRetriever\n", - "from langchain_openai import OpenAI\n", "\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "texts = text_splitter.split_documents(documents)\n", @@ -274,7 +273,7 @@ "\n", "\n", "compressed_docs = compression_retriever.get_relevant_documents(\n", - " \"What did the president say about the minimum wage?\"\n", + " \"What did the president say about the minimum wage?\"\n", ")\n", "\n", "pretty_print_docs(compressed_docs)" diff --git a/examples/overview.ipynb b/examples/overview.ipynb index 3e98b95..b20b35d 100644 --- a/examples/overview.ipynb +++ b/examples/overview.ipynb @@ -39,6 +39,7 @@ "\n", "# Let's do this manually, so we don't need to install pydotenv\n", "import os\n", + "\n", "with open(\".env\", \"r\") as file:\n", " for line in file:\n", " if line.strip() and not line.startswith(\"#\"):\n", @@ -62,7 +63,10 @@ "outputs": [], "source": [ "query = \"Gone with the wind is an absolute masterpiece\"\n", - "docs = [\"Gone with the wind is a masterclass in bad storytelling.\", \"Gone with the wind is an all-time classic\"]" + "docs = [\n", + " \"Gone with the wind is a masterclass in bad storytelling.\",\n", + " \"Gone with the wind is an all-time classic\",\n", + "]" ] }, { @@ -120,7 +124,7 @@ } ], "source": [ - "ranker = Reranker('cross-encoder')\n" + "ranker = 
Reranker(\"cross-encoder\")" ] }, { @@ -147,7 +151,7 @@ } ], "source": [ - "ranker = Reranker('cross-encoder', verbose=0)" + "ranker = Reranker(\"cross-encoder\", verbose=0)" ] }, { @@ -264,7 +268,11 @@ } ], "source": [ - "results = ranker.rank(query=query, docs=docs, doc_ids=[\"The Not-So Similar Document\", \"The Similar Document\"])\n", + "results = ranker.rank(\n", + " query=query,\n", + " docs=docs,\n", + " doc_ids=[\"The Not-So Similar Document\", \"The Similar Document\"],\n", + ")\n", "results.top_k(1)" ] }, @@ -317,7 +325,7 @@ } ], "source": [ - "ranker = Reranker('cross-encoder/ms-marco-MiniLM-L-6-v2', model_type=\"cross-encoder\")" + "ranker = Reranker(\"cross-encoder/ms-marco-MiniLM-L-6-v2\", model_type=\"cross-encoder\")" ] }, { @@ -378,7 +386,7 @@ } ], "source": [ - "ranker = Reranker('cross-encoder', lang='fr')" + "ranker = Reranker(\"cross-encoder\", lang=\"fr\")" ] }, { @@ -409,13 +417,14 @@ } ], "source": [ - "ranker = Reranker('cross-encoder/ms-marco-MiniLM-L-6-v2',\n", + "ranker = Reranker(\n", + " \"cross-encoder/ms-marco-MiniLM-L-6-v2\",\n", " model_type=\"cross-encoder\",\n", - " verbose = 1, # How verbose the reranker will be. Defaults to 1, setting it to 0 will suppress most messages.\n", - " dtype = None, # Which dtype the model should use. If None will figure out if your platform + model combo supports fp16 and use it if so, other fp32.\n", - " device = None, # Which device the model should use. If None will figure out what the most powerful supported platform available is (cuda > mps > cpu)\n", - " batch_size = 16, # The batch size the model will use. Defaults to 16\n", - " )" + " verbose=1, # How verbose the reranker will be. Defaults to 1, setting it to 0 will suppress most messages.\n", + " dtype=None, # Which dtype the model should use. If None will figure out if your platform + model combo supports fp16 and use it if so, other fp32.\n", + " device=None, # Which device the model should use. If None will figure out what the most powerful supported platform available is (cuda > mps > cpu)\n", + " batch_size=16, # The batch size the model will use. 
Defaults to 16\n", + ")" ] }, { @@ -463,9 +472,9 @@ ], "source": [ "# Jina\n", - "ranker = Reranker(\"jina\", api_key=os.environ['JINA_API_KEY'])\n", + "ranker = Reranker(\"jina\", api_key=os.environ[\"JINA_API_KEY\"])\n", "results = ranker.rank(query=query, docs=docs)\n", - "results.top_k(1)\n" + "results.top_k(1)" ] }, { @@ -495,7 +504,7 @@ ], "source": [ "# Cohere\n", - "ranker = Reranker(\"cohere\", api_key = os.environ['COHERE_API_KEY'])\n", + "ranker = Reranker(\"cohere\", api_key=os.environ[\"COHERE_API_KEY\"])\n", "results = ranker.rank(query=query, docs=docs)\n", "results.top_k(1)" ] @@ -538,11 +547,15 @@ ], "source": [ "# Cohere\n", - "ranker = Reranker(\"cohere\", lang=\"en\", api_key = os.environ['COHERE_API_KEY'])\n", - "ranker.rank(query=\"Tell me about lord of the rings\",\n", - " docs=[\"Dune is an incredibly confusing masterpiece in worldbuilding...\",\n", - " \"The silmarillion is a prequel to the Lord of The Rings...\",\n", - " \"Green Lantern uses a powerful ring to rule over his planet...\"])" + "ranker = Reranker(\"cohere\", lang=\"en\", api_key=os.environ[\"COHERE_API_KEY\"])\n", + "ranker.rank(\n", + " query=\"Tell me about lord of the rings\",\n", + " docs=[\n", + " \"Dune is an incredibly confusing masterpiece in worldbuilding...\",\n", + " \"The silmarillion is a prequel to the Lord of The Rings...\",\n", + " \"Green Lantern uses a powerful ring to rule over his planet...\",\n", + " ],\n", + ")" ] }, { @@ -570,7 +583,11 @@ "source": [ "# wrap in a try/except, as we don't have a fine-tuned model to use for this example!\n", "try:\n", - " ranker = Reranker(\"my-finetuned-model-name\", api_provider=\"cohere\", api_key = os.environ['COHERE_API_KEY'])\n", + " ranker = Reranker(\n", + " \"my-finetuned-model-name\",\n", + " api_provider=\"cohere\",\n", + " api_key=os.environ[\"COHERE_API_KEY\"],\n", + " )\n", "except:\n", " pass" ] @@ -648,7 +665,7 @@ } ], "source": [ - "ranker = Reranker('t5')\n", + "ranker = Reranker(\"t5\")\n", "results = ranker.rank(query=query, docs=docs)\n", "results.top_k(2)" ] @@ -675,7 +692,7 @@ } ], "source": [ - "ranker = Reranker(\"unicamp-dl/ptt5-base-pt-msmarco-10k-v2\", model_type='t5', verbose=0)" + "ranker = Reranker(\"unicamp-dl/ptt5-base-pt-msmarco-10k-v2\", model_type=\"t5\", verbose=0)" ] }, { @@ -712,16 +729,17 @@ } ], "source": [ - "ranker = Reranker('t5',\n", + "ranker = Reranker(\n", + " \"t5\",\n", " model_type=\"t5\",\n", - " verbose = 1, # How verbose the reranker will be. Defaults to 1, setting it to 0 will suppress most messages.\n", - " dtype = None, # Which dtype the model should use. If None will figure out if your platform + model combo supports fp16 and use it if so, other fp32.\n", - " device = None, # Which device the model should use. If None will figure out what the most powerful supported platform available is (cuda > mps > cpu)\n", - " batch_size = 16, # The batch size the model will use. Defaults to 16\n", - " token_false = \"auto\", # The output token corresponding to non-relevance.\n", - " token_true = \"auto\", # The output token corresponding to relevance.\n", - " return_logits = False, # Whether to return a normalised score or the raw logit for `token_true`.\n", - " )\n" + " verbose=1, # How verbose the reranker will be. Defaults to 1, setting it to 0 will suppress most messages.\n", + " dtype=None, # Which dtype the model should use. If None will figure out if your platform + model combo supports fp16 and use it if so, other fp32.\n", + " device=None, # Which device the model should use. 
If None will figure out what the most powerful supported platform available is (cuda > mps > cpu)\n", + " batch_size=16, # The batch size the model will use. Defaults to 16\n", + " token_false=\"auto\", # The output token corresponding to non-relevance.\n", + " token_true=\"auto\", # The output token corresponding to relevance.\n", + " return_logits=False, # Whether to return a normalised score or the raw logit for `token_true`.\n", + ")" ] }, { @@ -762,7 +780,8 @@ ], "source": [ "from rerankers import Reranker\n", - "ranker = Reranker(\"rankgpt\", api_key = os.environ['OPENAI_API_KEY'])\n" + "\n", + "ranker = Reranker(\"rankgpt\", api_key=os.environ[\"OPENAI_API_KEY\"])" ] }, { @@ -810,7 +829,8 @@ ], "source": [ "from rerankers import Reranker\n", - "ranker = Reranker(\"rankgpt3\", api_key = os.environ['OPENAI_API_KEY'])\n", + "\n", + "ranker = Reranker(\"rankgpt3\", api_key=os.environ[\"OPENAI_API_KEY\"])\n", "results = ranker.rank(query=query, docs=docs)\n", "results.top_k(1)" ] @@ -840,16 +860,21 @@ "source": [ "# LiteLLM uses env variables\n", "import os\n", - "os.environ['AZURE_API_KEY'] = \"\"\n", + "\n", + "os.environ[\"AZURE_API_KEY\"] = \"\"\n", "os.environ[\"AZURE_API_BASE\"] = \"\"\n", "os.environ[\"AZURE_API_VERSION\"] = \"\"\n", "deployment_name = \"my-azure-gpt-deployment\"\n", "\n", "# Just like Cohere's finetuned rankers above -- we try/except this as we're not actually running an Azure OpenAI model in this example!\n", "try:\n", - " ranker = Reranker(f\"azure/{deployment_name}\", model_type=\"rankgpt\", api_key = os.environ['AZURE_API_KEY'])\n", + " ranker = Reranker(\n", + " f\"azure/{deployment_name}\",\n", + " model_type=\"rankgpt\",\n", + " api_key=os.environ[\"AZURE_API_KEY\"],\n", + " )\n", "except:\n", - " pass\n" + " pass" ] }, { @@ -893,7 +918,8 @@ ], "source": [ "from rerankers import Reranker\n", - "ranker = Reranker('rankllm', api_key=os.environ['OPENAI_API_KEY'])\n", + "\n", + "ranker = Reranker(\"rankllm\", api_key=os.environ[\"OPENAI_API_KEY\"])\n", "ranker.rank(query=query, docs=docs)" ] }, @@ -935,7 +961,9 @@ } ], "source": [ - "ranker = Reranker('gpt-4-turbo', model_type='rankllm', api_key=os.environ['OPENAI_API_KEY'])\n", + "ranker = Reranker(\n", + " \"gpt-4-turbo\", model_type=\"rankllm\", api_key=os.environ[\"OPENAI_API_KEY\"]\n", + ")\n", "ranker.rank(query=query, docs=docs)" ] }, @@ -990,7 +1018,8 @@ ], "source": [ "from rerankers import Reranker\n", - "ranker = Reranker('colbert')\n", + "\n", + "ranker = Reranker(\"colbert\")\n", "results = ranker.rank(query=query, docs=docs)\n", "results.top_k(1)" ] @@ -1021,7 +1050,7 @@ } ], "source": [ - "ranker = Reranker('antoinelouis/colbertv2-camembert-L4-mmarcoFR', model_type=\"colbert\")" + "ranker = Reranker(\"antoinelouis/colbertv2-camembert-L4-mmarcoFR\", model_type=\"colbert\")" ] }, { @@ -1052,15 +1081,16 @@ } ], "source": [ - "ranker = Reranker('colbert',\n", + "ranker = Reranker(\n", + " \"colbert\",\n", " model_type=\"colbert\",\n", - " verbose = 1, # How verbose the reranker will be. Defaults to 1, setting it to 0 will suppress most messages.\n", - " dtype = None, # Which dtype the model should use. If None will figure out if your platform + model combo supports fp16 and use it if so, other fp32.\n", - " device = None, # Which device the model should use. If None will figure out what the most powerful supported platform available is (cuda > mps > cpu)\n", - " batch_size = 16, # The batch size the model will use. Defaults to 16\n", - " query_token = \"[unused0]\", # A ColBERT-specific argument. 
The token that your model prepends to queries.\n", - " document_token = \"[unused1]\", # A ColBERT-specific argument. The token that your model prepends to documents.\n", - " )" + " verbose=1, # How verbose the reranker will be. Defaults to 1, setting it to 0 will suppress most messages.\n", + " dtype=None, # Which dtype the model should use. If None will figure out if your platform + model combo supports fp16 and use it if so, other fp32.\n", + " device=None, # Which device the model should use. If None will figure out what the most powerful supported platform available is (cuda > mps > cpu)\n", + " batch_size=16, # The batch size the model will use. Defaults to 16\n", + " query_token=\"[unused0]\", # A ColBERT-specific argument. The token that your model prepends to queries.\n", + " document_token=\"[unused1]\", # A ColBERT-specific argument. The token that your model prepends to documents.\n", + ")" ] }, { @@ -1101,7 +1131,7 @@ } ], "source": [ - "ranker = Reranker('flashrank') # Defaults to MiniLM-L12-v2\n", + "ranker = Reranker(\"flashrank\") # Defaults to MiniLM-L12-v2\n", "results = ranker.rank(query=query, docs=docs)\n", "results.top_k(1)" ] @@ -1152,9 +1182,9 @@ } ], "source": [ - "ranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')\n", + "ranker = Reranker(\"ms-marco-TinyBERT-L-2-v2\", model_type=\"flashrank\")\n", "results = ranker.rank(query=query, docs=docs)\n", - "results\n" + "results" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 1fab24f..a9d9eb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,3 @@ - [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" @@ -10,6 +9,20 @@ packages = [ "rerankers.integrations", ] +[tool.ruff] +exclude = [ + "examples/*" +] + +[tool.mypy] +python_version = "3.12" +disallow_untyped_defs = true +namespace_packages = true + +[[tool.mypy.overrides]] +module = ["requests"] +ignore_missing_imports = true + [project] name = "rerankers" @@ -73,7 +86,7 @@ rankllm = [ "nmslib-metabrainz; python_version >= '3.10'", "rank-llm; python_version >= '3.10'" ] -dev = ["ruff", "isort", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"] +dev = ["ruff", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"] [project.urls] -"Homepage" = "https://github.com/answerdotai/rerankers" \ No newline at end of file +"Homepage" = "https://github.com/answerdotai/rerankers" diff --git a/requirements-ci.txt b/requirements-ci.txt new file mode 100644 index 0000000..23dbb10 --- /dev/null +++ b/requirements-ci.txt @@ -0,0 +1,4 @@ +ruff +pytest + +.[transformers] diff --git a/rerankers/models/__init__.py b/rerankers/models/__init__.py index cd0439d..5842c0e 100644 --- a/rerankers/models/__init__.py +++ b/rerankers/models/__init__.py @@ -1,4 +1,6 @@ -AVAILABLE_RANKERS = {} +from typing import Any + +AVAILABLE_RANKERS: dict[str, Any] = {} try: from rerankers.models.transformer_ranker import TransformerRanker diff --git a/rerankers/models/api_rankers.py b/rerankers/models/api_rankers.py index bf3a667..495b30d 100644 --- a/rerankers/models/api_rankers.py +++ b/rerankers/models/api_rankers.py @@ -16,26 +16,28 @@ "mixedbread.ai": "https://api.mixedbread.ai/v1/reranking", } -DOCUMENT_KEY_MAPPING = { - "mixedbread.ai": "input", - "text-embeddings-inference":"texts" -} +DOCUMENT_KEY_MAPPING = {"mixedbread.ai": "input", "text-embeddings-inference": "texts"} RETURN_DOCUMENTS_KEY_MAPPING = { - "mixedbread.ai":"return_input", - "text-embeddings-inference":"return_text" + "mixedbread.ai": 
"return_input", + "text-embeddings-inference": "return_text", } RESULTS_KEY_MAPPING = { "voyage": "data", "mixedbread.ai": "data", - "text-embeddings-inference": None -} -SCORE_KEY_MAPPING = { - "mixedbread.ai": "score", - "text-embeddings-inference":"score" + "text-embeddings-inference": None, } +SCORE_KEY_MAPPING = {"mixedbread.ai": "score", "text-embeddings-inference": "score"} + class APIRanker(BaseRanker): - def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1, url: str = None): + def __init__( + self, + model: str, + api_key: str, + api_provider: str, + verbose: int = 1, + url: str = None, + ): self.api_key = api_key self.model = model self.api_provider = api_provider.lower() @@ -48,7 +50,6 @@ def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1 } self.url = url if url else URLS[self.api_provider] - def _get_document_text(self, r: dict) -> str: if self.api_provider == "voyage": return r["document"] @@ -60,14 +61,16 @@ def _get_document_text(self, r: dict) -> str: return r["document"]["text"] def _get_score(self, r: dict) -> float: - score_key = SCORE_KEY_MAPPING.get(self.api_provider,"relevance_score") + score_key = SCORE_KEY_MAPPING.get(self.api_provider, "relevance_score") return r[score_key] def _parse_response( - self, response: dict, docs: List[Document], + self, + response: dict, + docs: List[Document], ) -> RankedResults: ranked_docs = [] - results_key = RESULTS_KEY_MAPPING.get(self.api_provider,"results") + results_key = RESULTS_KEY_MAPPING.get(self.api_provider, "results") print(response) for i, r in enumerate(response[results_key] if results_key else response): @@ -95,13 +98,14 @@ def rank( results = self._parse_response(response.json(), docs) return RankedResults(results=results, query=query, has_scores=True) - def _format_payload(self, query: str, docs: List[str]) -> str: top_key = ( "top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k" ) - documents_key = DOCUMENT_KEY_MAPPING.get(self.api_provider,"documents") - return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING.get(self.api_provider,"return_documents") + documents_key = DOCUMENT_KEY_MAPPING.get(self.api_provider, "documents") + return_documents_key = RETURN_DOCUMENTS_KEY_MAPPING.get( + self.api_provider, "return_documents" + ) payload = { "model": self.model, diff --git a/rerankers/models/colbert_ranker.py b/rerankers/models/colbert_ranker.py index 57f02d9..025f1c2 100644 --- a/rerankers/models/colbert_ranker.py +++ b/rerankers/models/colbert_ranker.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from transformers import BertPreTrainedModel, BertModel, AutoModel, AutoTokenizer +from transformers import BertPreTrainedModel, BertModel, AutoTokenizer from typing import List, Optional, Union from math import ceil diff --git a/rerankers/models/ranker.py b/rerankers/models/ranker.py index 6c8f45a..95f53a4 100644 --- a/rerankers/models/ranker.py +++ b/rerankers/models/ranker.py @@ -33,10 +33,10 @@ async def rank_async( docs: List[str], doc_ids: Optional[Union[List[str], str]] = None, ) -> RankedResults: - - loop = get_event_loop() - return await loop.run_in_executor(None, partial(self.rank, query, docs, doc_ids)) + return await loop.run_in_executor( + None, partial(self.rank, query, docs, doc_ids) + ) def as_langchain_compressor(self, k: int = 10): try: diff --git a/rerankers/models/t5ranker.py b/rerankers/models/t5ranker.py index 30a4e18..5be6c45 100644 --- a/rerankers/models/t5ranker.py +++ b/rerankers/models/t5ranker.py @@ -14,8 +14,6 @@ 
from rerankers.documents import Document -import torch - from rerankers.results import RankedResults, Result from rerankers.utils import ( vprint, @@ -89,7 +87,7 @@ def __init__( token_false: str = "auto", token_true: str = "auto", return_logits: bool = False, - inputs_template: str = "Query: {query} Document: {text} Relevant:" + inputs_template: str = "Query: {query} Document: {text} Relevant:", ): """ Implementation of the key functions from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py @@ -197,8 +195,7 @@ def _get_scores( total=ceil(len(docs) / batch_size), ): queries_documents = [ - self.inputs_template.format(query=query, text=text) - for text in batch + self.inputs_template.format(query=query, text=text) for text in batch ] tokenized = self.tokenizer( queries_documents, diff --git a/rerankers/reranker.py b/rerankers/reranker.py index 5c7599b..4e68f46 100644 --- a/rerankers/reranker.py +++ b/rerankers/reranker.py @@ -52,7 +52,9 @@ PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"] -def _get_api_provider(model_name: str, model_type: Optional[str] = None) -> str: +def _get_api_provider( + model_name: str, model_type: Optional[str] = None +) -> Optional[str]: if model_type in PROVIDERS or any(provider in model_name for provider in PROVIDERS): return model_type or next( (provider for provider in PROVIDERS if provider in model_name), None @@ -137,7 +139,7 @@ def _get_defaults( model_type: Optional[str] = None, lang: str = "en", verbose: int = 1, -) -> str: +) -> tuple[Optional[str], Optional[str]]: if model_name in DEFAULTS.keys(): print(f"Loading default {model_name} model for language {lang}") try: diff --git a/tests/consistency_notebooks/test_colbert.ipynb b/tests/consistency_notebooks/test_colbert.ipynb index 229b6b4..46be64d 100644 --- a/tests/consistency_notebooks/test_colbert.ipynb +++ b/tests/consistency_notebooks/test_colbert.ipynb @@ -58,8 +58,8 @@ "source": [ "import srsly\n", "\n", - "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", - "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", + "corpus = [x for x in srsly.read_jsonl(\"./data/scifact/corpus.jsonl\")]\n", + "queries = [x for x in srsly.read_jsonl(\"./data/scifact/queries.jsonl\")]\n", "\n", "corpus[0]" ] @@ -79,7 +79,7 @@ } ], "source": [ - "ranker = Reranker('colbert', device='cuda', verbose=0)" + "ranker = Reranker(\"colbert\", device=\"cuda\", verbose=0)" ] }, { @@ -88,8 +88,7 @@ "metadata": {}, "outputs": [], "source": [ - "top100 = srsly.read_json('./data/scifact/scifact_top_100.json')\n", - "\n" + "top100 = srsly.read_json(\"./data/scifact/scifact_top_100.json\")" ] }, { @@ -98,7 +97,7 @@ "metadata": {}, "outputs": [], "source": [ - "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" + "corpus_map = {x[\"_id\"]: f\"{x['title']} {x['text']}\" for x in corpus}" ] }, { @@ -115,15 +114,16 @@ } ], "source": [ - "qrels_dict = dict(qrels)\n", - "queries = [q for q in queries if q['_id'] in qrels_dict]\n", "from tqdm import tqdm\n", "\n", + "qrels_dict = dict(qrels)\n", + "queries = [q for q in queries if q[\"_id\"] in qrels_dict]\n", + "\n", "scores = {}\n", "for q in tqdm(queries):\n", - " doc_ids = top100[q['_id']]\n", + " doc_ids = top100[q[\"_id\"]]\n", " docs = [corpus_map[x] for x in doc_ids]\n", - " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" + " scores[q[\"_id\"]] = ranker.rank(q[\"text\"], docs, doc_ids=doc_ids)" ] }, { @@ -154,12 +154,15 @@ ], "source": [ 
"from ranx import evaluate\n", - "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", - "litterature_result = 0.693 # From ColBERTv2 Paper https://arxiv.org/abs/2112.01488\n", + "\n", + "evaluation_score = evaluate(qrels, run, \"ndcg@10\")\n", + "litterature_result = 0.693 # From ColBERTv2 Paper https://arxiv.org/abs/2112.01488\n", "if abs(evaluation_score - litterature_result) > 0.01:\n", - " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n", + " print(\n", + " f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\"\n", + " )\n", "else:\n", - " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n" + " print(\"Score is within 0.01 NDCG@10 of the reported score!\")" ] } ], diff --git a/tests/consistency_notebooks/test_crossenc.ipynb b/tests/consistency_notebooks/test_crossenc.ipynb index 5c0f604..b0fac80 100644 --- a/tests/consistency_notebooks/test_crossenc.ipynb +++ b/tests/consistency_notebooks/test_crossenc.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "from ranx import Qrels, Run\n" + "from ranx import Qrels, Run" ] }, { @@ -63,7 +63,8 @@ ], "source": [ "from rerankers import Reranker\n", - "ranker = Reranker('castorini/monot5-base-msmarco-10k', device='cuda', batch_size=128)" + "\n", + "ranker = Reranker(\"castorini/monot5-base-msmarco-10k\", device=\"cuda\", batch_size=128)" ] }, { @@ -88,8 +89,8 @@ "source": [ "import srsly\n", "\n", - "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", - "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", + "corpus = [x for x in srsly.read_jsonl(\"./data/scifact/corpus.jsonl\")]\n", + "queries = [x for x in srsly.read_jsonl(\"./data/scifact/queries.jsonl\")]\n", "\n", "corpus[0]" ] @@ -122,7 +123,7 @@ } ], "source": [ - "ranker = Reranker('mixedbread-ai/mxbai-rerank-base-v1', device='cuda')" + "ranker = Reranker(\"mixedbread-ai/mxbai-rerank-base-v1\", device=\"cuda\")" ] }, { @@ -131,7 +132,7 @@ "metadata": {}, "outputs": [], "source": [ - "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n" + "top100 = srsly.read_json(\"data/scifact/scifact_top_100.json\")" ] }, { @@ -140,7 +141,7 @@ "metadata": {}, "outputs": [], "source": [ - "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" + "corpus_map = {x[\"_id\"]: f\"{x['title']} {x['text']}\" for x in corpus}" ] }, { @@ -157,15 +158,16 @@ } ], "source": [ - "qrels_dict = dict(qrels)\n", - "queries = [q for q in queries if q['_id'] in qrels_dict]\n", "from tqdm import tqdm\n", "\n", + "qrels_dict = dict(qrels)\n", + "queries = [q for q in queries if q[\"_id\"] in qrels_dict]\n", + "\n", "scores = {}\n", "for q in tqdm(queries):\n", - " doc_ids = top100[q['_id']]\n", + " doc_ids = top100[q[\"_id\"]]\n", " docs = [corpus_map[x] for x in doc_ids]\n", - " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" + " scores[q[\"_id\"]] = ranker.rank(q[\"text\"], docs, doc_ids=doc_ids)" ] }, { @@ -177,7 +179,7 @@ "scores_dict = {}\n", "for q_id, ranked_results in scores.items():\n", " top_10_results = ranked_results.top_k(10)\n", - " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}\n" + " scores_dict[q_id] = {result.doc_id: result.score for result in top_10_results}" ] }, { @@ -224,12 +226,15 @@ ], "source": [ "from ranx import evaluate\n", - "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", - "litterature_result = 0.724 # from MXBAI 
https://docs.google.com/spreadsheets/d/15ELkSMFv-oHa5TRiIjDvhIstH9dlc3pnZeO-iGz4Ld4/edit#gid=0\n", + "\n", + "evaluation_score = evaluate(qrels, run, \"ndcg@10\")\n", + "litterature_result = 0.724 # from MXBAI https://docs.google.com/spreadsheets/d/15ELkSMFv-oHa5TRiIjDvhIstH9dlc3pnZeO-iGz4Ld4/edit#gid=0\n", "if abs(evaluation_score - litterature_result) > 0.01:\n", - " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n", + " print(\n", + " f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\"\n", + " )\n", "else:\n", - " print(f\"Score is within 0.01NDCG@10 of the reported score!\")\n" + " print(\"Score is within 0.01NDCG@10 of the reported score!\")" ] } ], diff --git a/tests/consistency_notebooks/test_inranker.ipynb b/tests/consistency_notebooks/test_inranker.ipynb index 119bdb8..466f2f8 100644 --- a/tests/consistency_notebooks/test_inranker.ipynb +++ b/tests/consistency_notebooks/test_inranker.ipynb @@ -6,8 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "from ranx import Qrels, Run\n", - "\n" + "from ranx import Qrels, Run" ] }, { @@ -59,8 +58,8 @@ "source": [ "import srsly\n", "\n", - "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", - "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", + "corpus = [x for x in srsly.read_jsonl(\"./data/scifact/corpus.jsonl\")]\n", + "queries = [x for x in srsly.read_jsonl(\"./data/scifact/queries.jsonl\")]\n", "\n", "corpus[0]" ] @@ -87,7 +86,7 @@ } ], "source": [ - "ranker = Reranker('unicamp-dl/InRanker-base', device='cuda', batch_size=32, verbose=0)" + "ranker = Reranker(\"unicamp-dl/InRanker-base\", device=\"cuda\", batch_size=32, verbose=0)" ] }, { @@ -96,8 +95,7 @@ "metadata": {}, "outputs": [], "source": [ - "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n", - "\n" + "top100 = srsly.read_json(\"data/scifact/scifact_top_100.json\")" ] }, { @@ -106,7 +104,7 @@ "metadata": {}, "outputs": [], "source": [ - "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" + "corpus_map = {x[\"_id\"]: f\"{x['title']} {x['text']}\" for x in corpus}" ] }, { @@ -123,15 +121,16 @@ } ], "source": [ - "qrels_dict = dict(qrels)\n", - "queries = [q for q in queries if q['_id'] in qrels_dict]\n", "from tqdm import tqdm\n", "\n", + "qrels_dict = dict(qrels)\n", + "queries = [q for q in queries if q[\"_id\"] in qrels_dict]\n", + "\n", "scores = {}\n", "for q in tqdm(queries):\n", - " doc_ids = top100[q['_id']]\n", + " doc_ids = top100[q[\"_id\"]]\n", " docs = [corpus_map[x] for x in doc_ids]\n", - " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" + " scores[q[\"_id\"]] = ranker.rank(q[\"text\"], docs, doc_ids=doc_ids)" ] }, { @@ -162,12 +161,15 @@ ], "source": [ "from ranx import evaluate\n", - "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", - "litterature_result = 0.7618 # From InRanker Paper https://arxiv.org/pdf/2401.06910.pdf\n", + "\n", + "evaluation_score = evaluate(qrels, run, \"ndcg@10\")\n", + "litterature_result = 0.7618 # From InRanker Paper https://arxiv.org/pdf/2401.06910.pdf\n", "if abs(evaluation_score - litterature_result) > 0.01:\n", - " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n", + " print(\n", + " f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\"\n", + " )\n", "else:\n", - " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n" + " print(\"Score is 
within 0.01 NDCG@10 of the reported score!\")" ] } ], diff --git a/tests/consistency_notebooks/test_t5.ipynb b/tests/consistency_notebooks/test_t5.ipynb index 421f037..351ab0c 100644 --- a/tests/consistency_notebooks/test_t5.ipynb +++ b/tests/consistency_notebooks/test_t5.ipynb @@ -58,8 +58,8 @@ "source": [ "import srsly\n", "\n", - "corpus = [x for x in srsly.read_jsonl('./data/scifact/corpus.jsonl')]\n", - "queries = [x for x in srsly.read_jsonl('./data/scifact/queries.jsonl')]\n", + "corpus = [x for x in srsly.read_jsonl(\"./data/scifact/corpus.jsonl\")]\n", + "queries = [x for x in srsly.read_jsonl(\"./data/scifact/queries.jsonl\")]\n", "\n", "corpus[0]" ] @@ -86,7 +86,9 @@ } ], "source": [ - "ranker = Reranker('castorini/monot5-base-msmarco-10k', device='cuda', batch_size=128, verbose=0)" + "ranker = Reranker(\n", + " \"castorini/monot5-base-msmarco-10k\", device=\"cuda\", batch_size=128, verbose=0\n", + ")" ] }, { @@ -95,8 +97,7 @@ "metadata": {}, "outputs": [], "source": [ - "top100 = srsly.read_json('data/scifact/scifact_top_100.json')\n", - "\n" + "top100 = srsly.read_json(\"data/scifact/scifact_top_100.json\")" ] }, { @@ -105,7 +106,7 @@ "metadata": {}, "outputs": [], "source": [ - "corpus_map = {x['_id']: f\"{x['title']} {x['text']}\" for x in corpus}" + "corpus_map = {x[\"_id\"]: f\"{x['title']} {x['text']}\" for x in corpus}" ] }, { @@ -122,15 +123,16 @@ } ], "source": [ - "qrels_dict = dict(qrels)\n", - "queries = [q for q in queries if q['_id'] in qrels_dict]\n", "from tqdm import tqdm\n", "\n", + "qrels_dict = dict(qrels)\n", + "queries = [q for q in queries if q[\"_id\"] in qrels_dict]\n", + "\n", "scores = {}\n", "for q in tqdm(queries):\n", - " doc_ids = top100[q['_id']]\n", + " doc_ids = top100[q[\"_id\"]]\n", " docs = [corpus_map[x] for x in doc_ids]\n", - " scores[q['_id']] = ranker.rank(q['text'], docs, doc_ids=doc_ids)\n" + " scores[q[\"_id\"]] = ranker.rank(q[\"text\"], docs, doc_ids=doc_ids)" ] }, { @@ -161,12 +163,15 @@ ], "source": [ "from ranx import evaluate\n", - "evaluation_score = evaluate(qrels, run, 'ndcg@10')\n", - "litterature_result = 0.734 # From RankGPT Paper https://arxiv.org/pdf/2304.09542.pdf\n", + "\n", + "evaluation_score = evaluate(qrels, run, \"ndcg@10\")\n", + "litterature_result = 0.734 # From RankGPT Paper https://arxiv.org/pdf/2304.09542.pdf\n", "if abs(evaluation_score - litterature_result) > 0.01:\n", - " print(f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\")\n", + " print(\n", + " f\"Score {evaluation_score:0.3f} differs by more than 0.01 from the the reported score.\"\n", + " )\n", "else:\n", - " print(f\"Score is within 0.01 NDCG@10 of the reported score!\")\n" + " print(\"Score is within 0.01 NDCG@10 of the reported score!\")" ] } ], diff --git a/tests/test_crossenc.py b/tests/test_crossenc.py index 9637708..e149d11 100644 --- a/tests/test_crossenc.py +++ b/tests/test_crossenc.py @@ -1,10 +1,9 @@ from unittest.mock import patch -import torch -from rerankers import Reranker from rerankers.models.transformer_ranker import TransformerRanker from rerankers.results import Result, RankedResults from rerankers.documents import Document + @patch("rerankers.models.transformer_ranker.TransformerRanker.rank") def test_transformer_ranker_rank(mock_rank): query = "Gone with the wind is an absolute masterpiece" @@ -15,12 +14,17 @@ def test_transformer_ranker_rank(mock_rank): expected_results = RankedResults( results=[ Result( - document=Document(id=1, text="Gone with the wind is an all-time classic"), + 
document=Document( + doc_id=1, text="Gone with the wind is an all-time classic" + ), score=1.6181640625, rank=1, ), Result( - document=Document(id=0, text="Gone with the wind is a masterclass in bad storytelling."), + document=Document( + doc_id=0, + text="Gone with the wind is a masterclass in bad storytelling.", + ), score=0.88427734375, rank=2, ), diff --git a/tests/test_results.py b/tests/test_results.py index 2fbeb22..f60be44 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -6,8 +6,8 @@ def test_ranked_results_functions(): results = RankedResults( results=[ - Result(document=Document(id=0, text="Doc 0"), score=0.9, rank=2), - Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1), + Result(document=Document(doc_id=0, text="Doc 0"), score=0.9, rank=2), + Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1), ], query="Test Query", has_scores=True, @@ -20,7 +20,7 @@ def test_ranked_results_functions(): def test_result_attributes(): - result = Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1) + result = Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1) assert result.doc_id == 1 assert result.text == "Doc 1" assert result.score == 0.95
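
Note: below is a minimal sketch (not part of the patch) of the `Document` / `Result` / `RankedResults` usage exercised by the updated tests above, for anyone adapting downstream code to the `doc_id` keyword. It assumes this branch is installed locally; the printed values are illustrative.

```python
from rerankers.documents import Document
from rerankers.results import RankedResults, Result

# Mirrors tests/test_results.py: the updated tests construct Document with
# `doc_id` (they previously passed `id`) alongside the document text.
results = RankedResults(
    results=[
        Result(document=Document(doc_id=0, text="Doc 0"), score=0.9, rank=2),
        Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1),
    ],
    query="Test Query",
    has_scores=True,
)

# top_k yields the highest-scoring results; each Result exposes
# doc_id, text and score, as asserted in the tests.
for r in results.top_k(1):
    print(r.doc_id, r.score, r.text)  # expected: 1 0.95 Doc 1
```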