-
Notifications
You must be signed in to change notification settings - Fork 667
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added text labeling interface using zero shot classifier, closes #30
- Loading branch information
1 parent
816e652
commit bd35445
Showing
2 changed files
with
58 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Labels module | ||
""" | ||
|
||
import torch | ||
|
||
from transformers import pipeline | ||
|
||
class Labels(object): | ||
""" | ||
Applies labels to text sections using a zero shot classifier. | ||
""" | ||
|
||
def __init__(self, path=None): | ||
""" | ||
Creates a new Labels instance. | ||
Args: | ||
path: path to transformer model, if not provided uses a default model | ||
""" | ||
|
||
self.classifier = pipeline("zero-shot-classification", model=path, tokenizer=path, | ||
device=0 if torch.cuda.is_available() else -1) | ||
|
||
def __call__(self, section, labels): | ||
""" | ||
Applies a zero shot classifier to a text section using a list of labels. | ||
Args: | ||
section: text section | ||
labels: list of labels | ||
Returns: | ||
list of (label, score) for section | ||
""" | ||
|
||
result = self.classifier(section, labels) | ||
return list(zip(result["labels"], result["scores"])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
""" | ||
Labels module tests | ||
""" | ||
|
||
import unittest | ||
|
||
from txtai.labels import Labels | ||
|
||
class TestLabels(unittest.TestCase): | ||
""" | ||
Labels tests | ||
""" | ||
|
||
def testLabel(self): | ||
""" | ||
Tests labeling | ||
""" | ||
|
||
labels = Labels() | ||
self.assertIsNotNone(labels("This is a test", "test")) |