Skip to content
47 changes: 47 additions & 0 deletions transformations/transformer_text_generation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Transformer-based Text Generation
We use generative pretrained language model (e.g., GPT-2) to generate next word(s) in a sequence based on preceding word(s).

Comment thread
jmamou marked this conversation as resolved.
Author Name: Jonathan Mamou

Email: jonathan.mamou@intel.com

Affiliation: Intel Labs

## What type of a transformation is this?
Given a generative pretrained language model, we generate next word(s) in a sequence based on the prefix of the original text. For SST-2 and IMDB datasets, the code points to fine-tuned models, respectively, ```jmamou/gpt2-medium-IMDB``` and ```jmamou/gpt2-medium-SST-2``` (available at HugginFace Model Hub).
In order to generate a sample preserving the label of the original sample, we fine-tuned ```gpt2-medium``` for labeled text generation tasks. If the sentiment label of the orignial text is not provided, we first run sentiment classification to get a pseudo-label using respectively, ```textattack/roberta-base-imdb``` and ```textattack/roberta-base-SST-2```.

In addition, we support general purpose unlabeled text generation using ```gpt2-xl``` pretrained model.

Note that ```num_return_sequences``` parameter of ```generate``` method should be set to _n_ in order to predict top _n_ predictions (instead of one).



## What tasks does it intend to benefit?
This transformation would benefit all tasks which have a sentence/paragraph/document as input like text classification, text generation and text tagging. In addition, this approach has been successfully used to augment data for distillation.

```python evaluate.py -t TransformerTextGeneration -task TEXT_CLASSIFICATION```
```model_name = "textattack/roberta-base-SST-2" -d sst2```

## What are the limitations of this transformation?
In order to get high quality augmented data, we need a generative pretrained language model trained or fine-tuned with data from the domain.

## Previous Work
Our approach follows previous work on data augmentation for distillation
```bibtex
@inproceedings{tang-etal-2019-natural,
title = "Natural Language Generation for Effective Knowledge Distillation",
author = "Tang, Raphael and
Lu, Yao and
Lin, Jimmy",
booktitle = "Proceedings of the 2nd Workshop on Deep Learning Approaches for Low-Resource NLP (DeepLo 2019)",
month = nov,
year = "2019",
address = "Hong Kong, China",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/D19-6122",
doi = "10.18653/v1/D19-6122",
pages = "202--208",
abstract = "Knowledge distillation can effectively transfer knowledge from BERT, a deep language representation model, to traditional, shallow word embedding-based neural networks, helping them approach or exceed the quality of other heavyweight language representation models. As shown in previous work, critical to this distillation procedure is the construction of an unlabeled transfer dataset, which enables effective knowledge transfer. To create transfer set examples, we propose to sample from pretrained language models fine-tuned on task-specific text. Unlike previous techniques, this directly captures the purpose of the transfer set. We hypothesize that this principled, general approach outperforms rule-based techniques. On four datasets in sentiment classification, sentence similarity, and linguistic acceptability, we show that our approach improves upon previous methods. We outperform OpenAI GPT, a deep pretrained transformer, on three of the datasets, while using a single-layer bidirectional LSTM that runs at least ten times faster.",
}
```
1 change: 1 addition & 0 deletions transformations/transformer_text_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .transformation import *
60 changes: 60 additions & 0 deletions transformations/transformer_text_generation/test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"type": "transformer_text_generation",
"test_cases": [
{
"class": "TransformerTextGeneration",
"inputs": {
"sentence": "Andrew finally returned the French book to Chris that I bought last week"
},
"outputs": [
{
"sentence": "Andrew finally returned the French book to its native land with a masterful work of quiet power."
}
]
},
{
"class": "TransformerTextGeneration",
"inputs": {
"sentence": "Sentences with gapping, such as Paul likes coffee and Mary tea, lack an overt predicate to indicate the relation between two or more arguments."
},
"outputs": [
{
"sentence": "Sentences with gapping, such as Paul likes coffee and Mary tea, lack the punch of their titles."
}
]
},
{
"class": "TransformerTextGeneration",
"inputs": {
"sentence": "Alice in Wonderland is a 2010 American live-action/animated dark fantasy adventure film"
},
"outputs": [
{
"sentence": "Alice in Wonderland is a 2010 american musical that maintains an almost constant state of heightened alert."
}
]
},
{
"class": "TransformerTextGeneration",
"inputs": {
"sentence": "Ujjal Dev Dosanjh served as 33rd Premier of British Columbia from 2000 to 2001"
},
"outputs": [
{
"sentence": "Ujjal Dev Dosanjh served as 33rd Premier of Gujarat."
}
]
},
{
"class": "TransformerTextGeneration",
"inputs": {
"sentence": "Neuroplasticity is a continuous processing allowing short-term, medium-term, and long-term remodeling of the neuronosynaptic organization."
},
"outputs": [
{
"sentence": "Neuroplasticity is a continuous processing allowing short-term, medium-term, and long-range memory to play out simultaneously, without the need for episodic rewiring."
}
]
}
]
}
167 changes: 167 additions & 0 deletions transformations/transformer_text_generation/transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import logging
import math

import numpy as np
import torch
from transformers import GPT2LMHeadModel, pipeline

from interfaces.SentenceOperation import SentenceOperation
from tasks.TaskTypes import TaskType

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)


def set_seed(seed: int, no_cuda: bool):
np.random.seed(seed)
torch.manual_seed(seed)
if not no_cuda:
torch.cuda.manual_seed_all(seed)


class TransformerTextGeneration(SentenceOperation):
Comment thread
jmamou marked this conversation as resolved.
Comment thread
jmamou marked this conversation as resolved.
tasks = [
TaskType.TEXT_CLASSIFICATION,
TaskType.TEXT_TO_TEXT_GENERATION,
TaskType.TEXT_TAGGING,
]
languages = ["en"]

Comment thread
jmamou marked this conversation as resolved.
def __init__(
self,
eos: str = "</s>",
no_cuda: bool = False,
dataset="sst2",
labeled=True,
seed=42,
):

Comment thread
jmamou marked this conversation as resolved.
Outdated
set_seed(seed, no_cuda)

if labeled:
if dataset == "imdb":
model_text_generation: str = "jmamou/gpt2-medium-IMDB"
# model_sentiment_classification = 'aychang/roberta-base-imdb'
Comment thread
jmamou marked this conversation as resolved.
Outdated
model_sentiment_classification = "textattack/roberta-base-imdb"
elif dataset == "sst2":
model_text_generation: str = "jmamou/gpt2-medium-SST-2"
model_sentiment_classification = (
"textattack/roberta-base-SST-2"
)
else:
model_text_generation: str = "gpt2-xl"
model_sentiment_classification = None

self.eos = eos
device = -1 if no_cuda else 0

# initialize text generation pipeline
model = GPT2LMHeadModel.from_pretrained(model_text_generation)
Comment thread
jmamou marked this conversation as resolved.
self.text_generator = pipeline(
"text-generation",
model=model,
tokenizer=model_text_generation,
device=device,
)

# if relevant, initialize the sentiment classification pipeline
if model_sentiment_classification is not None:
Comment thread
jmamou marked this conversation as resolved.
self.label_name = "label"
self.text_classifier = pipeline(
"sentiment-analysis",
model=model_sentiment_classification,
tokenizer=model_sentiment_classification,
device=device,
)
else:
self.text_classifier = None

def generate(
self,
sequence: str,
num_return_sequences: int = 1,
prefix_ratio: float = 0.5,
max_length_factor=3,
max_prefix_length=400,
model_max_length=512,
temperature: float = 1.0,
repetition_penalty: float = 1.2,
k: int = 0,
p: float = 0.9,
):

logger.info("original text: " + sequence)
augmented_texts = []
sequence = sequence.split()

prefix_length = min(
max_prefix_length, math.ceil(len(sequence) * prefix_ratio)
)
if self.text_classifier is None:
text_inputs = " ".join(sequence[0:prefix_length])
else:
label = self.text_classifier(sequence, truncation=True)[0][
self.label_name
]
label = label.split("_")[1]
text_inputs = label + "\t" + " ".join(sequence[0:prefix_length])

max_length = min(model_max_length, len(sequence) * max_length_factor)
output_sequences = self.text_generator(
text_inputs=text_inputs,
temperature=temperature,
top_k=k,
top_p=p,
repetition_penalty=repetition_penalty,
do_sample=True,
num_return_sequences=num_return_sequences,
clean_up_tokenization_spaces=True,
return_full_text=True,
max_length=max_length,
truncation=True,
)
for seq in output_sequences:
text = seq["generated_text"]
if self.text_classifier is not None:
text = text.split("\t")[1]
text = text[
: text.find(self.eos)
if self.eos and text.find(self.eos) > -1
else None
].strip()
text = text.replace("\n", ".")
augmented_texts.append(text)
logger.info("augmented text: " + text)
return augmented_texts


if __name__ == "__main__":
import json

from TestRunner import convert_to_snake_case

tf = TransformerTextGeneration()
test_cases = []
for sentence in [
"Andrew finally returned the French book to Chris that I bought last week",
"Sentences with gapping, such as Paul likes coffee and Mary tea, lack an overt predicate to indicate the relation between two or more arguments.",
"Alice in Wonderland is a 2010 American live-action/animated dark fantasy adventure film",
"Ujjal Dev Dosanjh served as 33rd Premier of British Columbia from 2000 to 2001",
"Neuroplasticity is a continuous processing allowing short-term, medium-term, and long-term remodeling of the neuronosynaptic organization.",
]:
test_cases.append(
{
"class": tf.name(),
"inputs": {"sentence": sentence},
"outputs": [{"sentence": o} for o in tf.generate(sentence)],
}
)
json_file = {
"type": convert_to_snake_case(tf.name()),
"test_cases": test_cases,
}
print(json.dumps(json_file))