Skip to content

Commit

Permalink
Added text labeling interface using zero shot classifier, closes #30
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Oct 5, 2020
1 parent 816e652 commit bd35445
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/python/txtai/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Labels module
"""

import torch

from transformers import pipeline

class Labels:
    """
    Applies labels to text sections using a zero shot classifier.
    """

    def __init__(self, path=None):
        """
        Creates a new Labels instance.

        Args:
            path: path to transformer model, if not provided uses a default model
        """

        # Use the first CUDA device when one is available, otherwise run on CPU (-1)
        device = 0 if torch.cuda.is_available() else -1

        self.classifier = pipeline("zero-shot-classification", model=path, tokenizer=path,
                                   device=device)

    def __call__(self, section, labels):
        """
        Applies a zero shot classifier to a text section using a list of labels.

        Args:
            section: text section, or a list of text sections
            labels: list of labels

        Returns:
            list of (label, score) for section. When the classifier returns one result
            per section (a list of results), a list of those lists is returned instead,
            in the same order.
        """

        result = self.classifier(section, labels)

        # A single result is a dict; a batch comes back as a list of dicts. Dispatching on
        # the result type keeps the single-section return value unchanged.
        if isinstance(result, dict):
            return self._pairs(result)

        return [self._pairs(entry) for entry in result]

    @staticmethod
    def _pairs(result):
        """
        Pairs each label in a classifier result with its score.

        Args:
            result: dict with parallel "labels" and "scores" sequences

        Returns:
            list of (label, score)
        """

        return list(zip(result["labels"], result["scores"]))
20 changes: 20 additions & 0 deletions test/python/testlabels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Labels module tests
"""

import unittest

from txtai.labels import Labels

class TestLabels(unittest.TestCase):
    """
    Labels tests
    """

    def testLabel(self):
        """
        Tests labeling
        """

        labels = Labels()

        # Labels are passed as a list, matching the documented Labels.__call__ contract
        result = labels("This is a test", ["test"])

        # One candidate label in, exactly one (label, score) pair out
        self.assertEqual(len(result), 1)

        label, score = result[0]
        self.assertEqual(label, "test")
        self.assertTrue(0.0 <= score <= 1.0)

0 comments on commit bd35445

Please sign in to comment.