From 13bafb01c6fb7c8596395b73e14aa7aea753e608 Mon Sep 17 00:00:00 2001
From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com>
Date: Fri, 31 Jan 2025 10:37:02 -0500
Subject: [PATCH] Add support for GLiNER models, closes #862

---
 src/python/txtai/pipeline/text/entity.py | 78 ++++++++++++++++++-
 test/python/testoptional.py              |  5 ++
 .../testpipeline/testtext/testentity.py  |  9 +++
 3 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/src/python/txtai/pipeline/text/entity.py b/src/python/txtai/pipeline/text/entity.py
index ebd88c0cd..61dde68bc 100644
--- a/src/python/txtai/pipeline/text/entity.py
+++ b/src/python/txtai/pipeline/text/entity.py
@@ -2,6 +2,17 @@
 Entity module
 """
 
+# Conditional import
+try:
+    from gliner import GLiNER
+
+    GLINER = True
+except ImportError:
+    GLINER = False
+
+from transformers.utils import cached_file
+
+from ...models import Models
 from ..hfpipeline import HFPipeline
 
 
@@ -11,7 +22,18 @@ class Entity(HFPipeline):
     """
 
     def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
-        super().__init__("token-classification", path, quantize, gpu, model, **kwargs)
+        # Create a new entity pipeline
+        self.gliner = self.isgliner(path)
+        if self.gliner:
+            if not GLINER:
+                raise ImportError('GLiNER is not available - install "pipeline" extra to enable')
+
+            # GLiNER entity pipeline
+            self.pipeline = GLiNER.from_pretrained(path)
+            self.pipeline = self.pipeline.to(Models.device(Models.deviceid(gpu)))
+        else:
+            # Standard entity pipeline
+            super().__init__("token-classification", path, quantize, gpu, model, **kwargs)
 
     def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=False, workers=0):
         """
@@ -30,7 +52,7 @@ def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=Fal
         """
 
         # Run token classification pipeline
-        results = self.pipeline(text, aggregation_strategy=aggregate, num_workers=workers)
+        results = self.execute(text, labels, aggregate, workers)
 
         # Convert results to a list if necessary
         if isinstance(text, str):
@@ -50,6 +72,58 @@ def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=Fal
 
         return outputs[0] if isinstance(text, str) else outputs
 
+    def isgliner(self, path):
+        """
+        Tests if path is a GLiNER model.
+
+        Args:
+            path: model path
+
+        Returns:
+            True if this is a GLiNER model, False otherwise
+        """
+
+        try:
+            # Test if this model has a gliner_config.json file
+            return cached_file(path_or_repo_id=path, filename="gliner_config.json") is not None
+
+        # Ignore this error - invalid repo or directory
+        except OSError:
+            pass
+
+        return False
+
+    def execute(self, text, labels, aggregate, workers):
+        """
+        Runs the entity extraction pipeline.
+
+        Args:
+            text: text|list
+            labels: list of entity type labels to accept, defaults to None which accepts all
+            aggregate: method to combine multi token entities - options are "simple" (default), "first", "average" or "max"
+            workers: number of concurrent workers to use for processing data, defaults to None
+
+        Returns:
+            list of entities and labels
+        """
+
+        if self.gliner:
+            # Extract entities with GLiNER. Use default CoNLL-2003 labels when not otherwise provided.
+            results = self.pipeline.batch_predict_entities(
+                text if isinstance(text, list) else [text], labels if labels else ["person", "organization", "location"]
+            )
+
+            # Map results to same format as Transformers token classifier
+            entities = []
+            for result in results:
+                entities.append([{"word": x["text"], "entity_group": x["label"], "score": x["score"]} for x in result])
+
+            # Return extracted entities
+            return entities if isinstance(text, list) else entities[0]
+
+        # Standard Transformers token classification pipeline
+        return self.pipeline(text, aggregation_strategy=aggregate, num_workers=workers)
+
     def accept(self, etype, labels):
         """
         Determines if entity type is in valid entity type.
diff --git a/test/python/testoptional.py b/test/python/testoptional.py
index 74ebfc2d6..d54980e3b 100644
--- a/test/python/testoptional.py
+++ b/test/python/testoptional.py
@@ -27,6 +27,7 @@ def setUpClass(cls):
             "docling.document_converter",
             "duckdb",
             "fastapi",
+            "gliner",
             "grand-cypher",
             "grand-graph",
             "hnswlib",
@@ -186,6 +187,7 @@ def testPipeline(self):
             AudioMixer,
             AudioStream,
             Caption,
+            Entity,
             FileToHTML,
             HFOnnx,
             HFTrainer,
@@ -213,6 +215,9 @@ def testPipeline(self):
         with self.assertRaises(ImportError):
             Caption()
 
+        with self.assertRaises(ImportError):
+            Entity("neuml/gliner-bert-tiny")
+
         with self.assertRaises(ImportError):
             FileToHTML(backend="docling")
 
diff --git a/test/python/testpipeline/testtext/testentity.py b/test/python/testpipeline/testtext/testentity.py
index f722f040f..9868f7ed2 100644
--- a/test/python/testpipeline/testtext/testentity.py
+++ b/test/python/testpipeline/testtext/testentity.py
@@ -52,3 +52,12 @@ def testEntityTypes(self):
         # Run entity extraction
         entities = self.entity("Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", labels=["PER"])
         self.assertFalse(entities)
+
+    def testGliner(self):
+        """
+        Test entity pipeline with a GLiNER model
+        """
+
+        entity = Entity("neuml/gliner-bert-tiny")
+        entities = entity("My name is John Smith.", flatten=True)
+        self.assertEqual(entities, ["John Smith"])
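
For reference, a minimal usage sketch of the new code path, not part of the patch: it mirrors the testGliner test above and assumes the gliner extra is installed and the neuml/gliner-bert-tiny model used in the tests is available; the second input sentence and its labels are purely illustrative.

```python
from txtai.pipeline import Entity

# Passing a GLiNER model path routes Entity through GLiNER.batch_predict_entities
# instead of the standard Transformers token classification pipeline
entity = Entity("neuml/gliner-bert-tiny")

# Flattened output, as asserted in testGliner: ["John Smith"]
print(entity("My name is John Smith.", flatten=True))

# labels overrides the default ["person", "organization", "location"] set;
# without flatten, each result is an (entity, label, score) tuple (illustrative input)
print(entity("NeuML is based in Maryland", labels=["organization", "location"]))
```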