Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions transformations/bert_sent_mask_fill/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# BERT Sentence Mask Filling 🦎 + 🎭 → 🐍
This transformation generates similar sentences by applying BERT mask filling to keywords.

Author name: Het Pandya

Author email: hetpandya6797@gmail.com

## What type of a transformation is this?
This transformation augments text by using mask filling to replace keywords with other words. First, spaCy is used to extract the keywords of a sentence, i.e. the words that can be masked. Each keyword is then replaced with a mask token, and the masked sentence is fed to a BERT model, which predicts words to fill the masked position.

## What tasks does it intend to benefit?
This transformation can help generate synthetic data for tasks where only a small number of samples is available.

## What are the limitations of this transformation?
Although the transformation can generate a healthy number of similar samples, they will be much simpler than the outputs of a paraphraser.
2 changes: 2 additions & 0 deletions transformations/bert_sent_mask_fill/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .transformation import *

4 changes: 4 additions & 0 deletions transformations/bert_sent_mask_fill/test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "bert_sent_mask_fill",
"test_cases": []
}
91 changes: 91 additions & 0 deletions transformations/bert_sent_mask_fill/transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import random
import re
from string import punctuation
from typing import List

import numpy as np
import spacy
import torch
from transformers import pipeline

from interfaces.SentenceOperation import SentenceOperation
from tasks.TaskTypes import TaskType

"""
The following transformation augments text using mask filling to replace keywords with other words.
For that, the words that can be masked are found using spacy to extract keywords from a sentence.
Once the keywords are found, they are replaced with a mask and fed to the BERT model to predict a word in place of the masked word.
"""


def set_seed(seed: int) -> None:
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, PyTorch (CPU, plus all CUDA devices
    when CUDA is available) and NumPy's global generator.

    Args:
        seed: Value passed to each generator's seeding function.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    # `np` must be imported at module level (`import numpy as np`); without
    # that import this line raises NameError.
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class BertSentenceMaskFilling(SentenceOperation):
    """Generate similar sentences by masking keywords and refilling them.

    Keywords (tokens with selected part-of-speech tags) are found with a
    spaCy pipeline. Each keyword is replaced in turn by the mask token of a
    fill-mask model, and the model's top predictions for the masked position
    become the augmented sentences.
    """

    tasks = [TaskType.TEXT_CLASSIFICATION, TaskType.TEXT_TO_TEXT_GENERATION]

    heavy = True
    # Could support other languages if there is a pretrained BERT mask fill
    # model and SpaCy model available for the same.
    languages = ["en"]

    # Part-of-speech tags of the tokens that are eligible for masking.
    _KEYWORD_POS_TAGS = ("PROPN", "NOUN", "ADJ", "NUM")

    def __init__(
        self,
        mask_model_name: str = "bert-base-uncased",
        spacy_model_name: str = "en_core_web_sm",
        n_mask_predictions: int = 3,
        seed: int = 0,
    ):
        """Load the fill-mask and spaCy models and seed the RNGs.

        Args:
            mask_model_name: Model name handed to the ``fill-mask`` pipeline.
            spacy_model_name: Name of the spaCy model to load.
            n_mask_predictions: Number of predictions requested per masked
                keyword.
            seed: Seed forwarded to the base class and to ``set_seed``.
        """
        super().__init__(seed=seed)

        self.n_mask_predictions = n_mask_predictions

        # NOTE(review): `self.verbose` is not set in this class; presumably
        # the base class __init__ provides it - confirm.
        if self.verbose:
            print("Loading BERT Mask Fill Model..\n")

        self.mask_augmenter = pipeline("fill-mask", model=mask_model_name)

        if self.verbose:
            print("Loading SpaCy Model..\n")

        self.nlp = spacy.load(spacy_model_name)
        set_seed(seed)

        if self.verbose:
            print("Completed loading BERT Mask Fill and SpaCy Models..\n")

    def extract_keywords(self, sentence):
        """Return the distinct keywords of ``sentence`` in order of appearance.

        A token is a keyword when its part-of-speech tag is one of
        ``_KEYWORD_POS_TAGS``. Stop words and punctuation were only ever
        skipped when their tag was not eligible, so the tag alone decides.

        Args:
            sentence: Text to analyse with the spaCy pipeline.

        Returns:
            List of unique keyword strings, ordered by first occurrence so
            that repeated runs yield the same order.
        """
        keywords = [
            token.text
            for token in self.nlp(sentence)
            if token.pos_ in self._KEYWORD_POS_TAGS
        ]
        # dict.fromkeys de-duplicates while keeping insertion order; a set
        # would make the order (and so the output order) vary between runs.
        return list(dict.fromkeys(keywords))

    @staticmethod
    def _mask_keyword(sentence, keyword, mask_token):
        """Replace the first whole-word occurrence of ``keyword`` with the mask."""
        # Look-arounds instead of \b so keywords that start or end with a
        # non-word character (e.g. "C++") still match.
        pattern = r"(?<!\w)" + re.escape(keyword) + r"(?!\w)"
        # A callable replacement keeps the mask token literal (no escapes).
        masked, n_replaced = re.subn(
            pattern, lambda _match: mask_token, sentence, count=1
        )
        if n_replaced:
            return masked
        # Fall back to plain substring replacement so the sentence always
        # receives a mask when the keyword is only present inside other text.
        return sentence.replace(keyword, mask_token, 1)

    def generate(self, sentence: str) -> List[str]:
        """Create augmented variants of ``sentence``.

        Each keyword is masked once (first whole-word occurrence) and the
        fill-mask model is asked for ``n_mask_predictions`` completions.
        Completions that equal the input, ignoring case, are dropped.

        Args:
            sentence: Sentence to augment.

        Returns:
            List of generated sentences, each passed through
            ``str.capitalize``; empty when no keyword is found.
        """
        mask_token = self.mask_augmenter.tokenizer.mask_token
        original_lower = sentence.lower()
        augmented_sents = []
        for keyword in self.extract_keywords(sentence):
            masked_sent = self._mask_keyword(sentence, keyword, mask_token)
            predictions = self.mask_augmenter(
                masked_sent, top_k=self.n_mask_predictions
            )
            augmented_sents.extend(
                prediction["sequence"].capitalize()
                for prediction in predictions
                if prediction["sequence"].lower() != original_lower
            )
        return augmented_sents