diff --git a/pyproject.toml b/pyproject.toml
index 68cc372..2e102b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,7 @@ rankllm = [
     "nmslib-metabrainz; python_version >= '3.10'",
     "rank-llm; python_version >= '3.10'"
 ]
+fastembed = ["fastembed"]
 dev = ["ruff", "isort", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"]
 
 [project.urls]
diff --git a/rerankers/models/__init__.py b/rerankers/models/__init__.py
index 15d81d4..6922cd9 100644
--- a/rerankers/models/__init__.py
+++ b/rerankers/models/__init__.py
@@ -58,3 +58,10 @@
     AVAILABLE_RANKERS["MonoVLMRanker"] = MonoVLMRanker
 except ImportError:
     pass
+
+try:
+    from rerankers.models.fastembed_ranker import FastEmbedRanker
+
+    AVAILABLE_RANKERS["FastEmbedRanker"] = FastEmbedRanker
+except ImportError:
+    pass
diff --git a/rerankers/models/fastembed_ranker.py b/rerankers/models/fastembed_ranker.py
new file mode 100644
index 0000000..2de9efd
--- /dev/null
+++ b/rerankers/models/fastembed_ranker.py
@@ -0,0 +1,33 @@
+from fastembed.rerank.cross_encoder import TextCrossEncoder
+
+from rerankers.models.ranker import BaseRanker
+from rerankers.results import RankedResults, Result
+from rerankers.utils import prep_docs
+
+
+class FastEmbedRanker(BaseRanker):
+    """Cross-encoder reranker backed by fastembed's TextCrossEncoder."""
+
+    def __init__(self, model_name_or_path, verbose=None):
+        # NOTE(review): verbose is accepted for signature parity with the
+        # other rankers but is currently unused here.
+        self.model = TextCrossEncoder(model_name=model_name_or_path)
+
+    def rank(self, query, docs):
+        """Score all docs against query and return them best-first."""
+        docs = prep_docs(docs)
+        scores = list(self.model.rerank(query, [d.text for d in docs]))
+        # Document indices ordered by descending relevance score.
+        indices = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
+
+        ranked_results = [
+            Result(document=docs[idx], score=scores[idx], rank=i + 1)
+            for i, idx in enumerate(indices)
+        ]
+
+        return RankedResults(results=ranked_results, query=query, has_scores=True)
+
+    def score(self, query, doc):
+        """Return the relevance score of a single document for query."""
+        score = list(self.model.rerank(query, [doc]))[0]
+        return score
diff --git a/rerankers/reranker.py b/rerankers/reranker.py
index 9e9bb02..9a8434a 100644
--- a/rerankers/reranker.py
+++ b/rerankers/reranker.py
@@ -38,8 +38,12 @@
     },
     "monovlm": {
         "en": "lightonai/MonoQwen2-VL-v0.1",
-        "other": "lightonai/MonoQwen2-VL-v0.1"
-    }
+        "other": "lightonai/MonoQwen2-VL-v0.1",
+    },
+    "fastembed": {
+        "en": "Xenova/ms-marco-MiniLM-L-6-v2",
+        "other": "Xenova/ms-marco-MiniLM-L-6-v2",
+    },
 }
 
 DEPS_MAPPING = {
@@ -52,7 +56,8 @@
     "FlashRankRanker": "flashrank",
     "RankLLMRanker": "rankllm",
     "LLMLayerWiseRanker": "transformers",
-    "MonoVLMRanker": "transformers"
+    "MonoVLMRanker": "transformers",
+    "FastEmbedRanker": "fastembed",
 }
 
 PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "pinecone", "text-embeddings-inference"]
@@ -91,7 +96,8 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
         "flashrank": "FlashRankRanker",
         "rankllm": "RankLLMRanker",
         "llm-layerwise": "LLMLayerWiseRanker",
-        "monovlm": "MonoVLMRanker"
+        "monovlm": "MonoVLMRanker",
+        "fastembed": "FastEmbedRanker",
         }
         return model_mapping.get(explicit_model_type, explicit_model_type)
     else:
@@ -115,7 +121,8 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
         "zephyr": "RankLLMRanker",
         "bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker",
         "monovlm": "MonoVLMRanker",
-        "monoqwen2-vl": "MonoVLMRanker"
+        "monoqwen2-vl": "MonoVLMRanker",
+        "Xenova/ms-marco-MiniLM-L-6-v2": "FastEmbedRanker",
     }
     for key, value in model_mapping.items():
         if key in model_name: