Added transformer_text_generation #276
Open: jmamou wants to merge 11 commits into GEM-benchmark:main from jmamou:transformer_text_generation
Changes from 6 commits (11 commits in total, all by jmamou)
6459dd3 Added transformer_text_generation
503246a Added transformer_text_generation
6fcb273 Added transformer_text_generation
9c33ae7 Added transformer_text_generation
9b13207 add name, email and affiliation
be1f274 Added limitations and top n generation to README
b4bf00f add docstrings, init call to super class, heavy transformation, add k…
ff9cda6 adding evaluation
7f64553 adding evaluation
7891328 Add transformer_text_generation to test/mapper.py
706a813 Merge branch 'main' into transformer_text_generation
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # Transformer-based Text Generation | ||
| We use a generative pretrained language model (e.g., GPT-2) to generate the next word(s) in a sequence based on the preceding word(s). | ||
|
|
||
| Author Name: Jonathan Mamou | ||
|
|
||
| Email: jonathan.mamou@intel.com | ||
|
|
||
| Affiliation: Intel Labs | ||
|
|
||
| ## What type of a transformation is this? | ||
| Given a generative pretrained language model, we generate the next word(s) in a sequence from a prefix of the original text. For the IMDB and SST-2 datasets, the code points to the fine-tuned models ```jmamou/gpt2-medium-IMDB``` and ```jmamou/gpt2-medium-SST-2```, respectively (available on the Hugging Face Model Hub). | ||
| To generate a sample that preserves the label of the original sample, we fine-tuned ```gpt2-medium``` for labeled text generation. If the sentiment label of the original text is not provided, we first run sentiment classification to obtain a pseudo-label, using ```textattack/roberta-base-imdb``` and ```textattack/roberta-base-SST-2```, respectively. | ||
|
|
||
| In addition, we support general-purpose unlabeled text generation using the pretrained ```gpt2-xl``` model. | ||
|
|
||
| Note that the ```num_return_sequences``` parameter of the ```generate``` method should be set to _n_ to return the top _n_ predictions (instead of one). | ||
|
|
||
|
|
||
|
|
||
| ## What tasks does it intend to benefit? | ||
| This transformation would benefit all tasks that take a sentence, paragraph, or document as input, such as text classification, text generation, and text tagging. In addition, this approach has been successfully used to augment data for distillation. | ||
|
|
||
| ```python evaluate.py -t TransformerTextGeneration -task TEXT_CLASSIFICATION``` | ||
| ```model_name = "textattack/roberta-base-SST-2" -d sst2``` | ||
|
|
||
| ## What are the limitations of this transformation? | ||
| To get high-quality augmented data, we need a generative pretrained language model trained or fine-tuned on data from the target domain. | ||
|
|
||
| ## Previous Work | ||
| Our approach follows previous work on data augmentation for distillation: | ||
| ```bibtex | ||
| @inproceedings{tang-etal-2019-natural, | ||
| title = "Natural Language Generation for Effective Knowledge Distillation", | ||
| author = "Tang, Raphael and | ||
| Lu, Yao and | ||
| Lin, Jimmy", | ||
| booktitle = "Proceedings of the 2nd Workshop on Deep Learning Approaches for Low-Resource NLP (DeepLo 2019)", | ||
| month = nov, | ||
| year = "2019", | ||
| address = "Hong Kong, China", | ||
| publisher = "Association for Computational Linguistics", | ||
| url = "https://aclanthology.org/D19-6122", | ||
| doi = "10.18653/v1/D19-6122", | ||
| pages = "202--208", | ||
| abstract = "Knowledge distillation can effectively transfer knowledge from BERT, a deep language representation model, to traditional, shallow word embedding-based neural networks, helping them approach or exceed the quality of other heavyweight language representation models. As shown in previous work, critical to this distillation procedure is the construction of an unlabeled transfer dataset, which enables effective knowledge transfer. To create transfer set examples, we propose to sample from pretrained language models fine-tuned on task-specific text. Unlike previous techniques, this directly captures the purpose of the transfer set. We hypothesize that this principled, general approach outperforms rule-based techniques. On four datasets in sentiment classification, sentence similarity, and linguistic acceptability, we show that our approach improves upon previous methods. We outperform OpenAI GPT, a deep pretrained transformer, on three of the datasets, while using a single-layer bidirectional LSTM that runs at least ten times faster.", | ||
| } | ||
| ``` | ||
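
The labeled-generation scheme described in the README can be sketched without loading any model. The snippet below is a minimal illustration, assuming the prompt layout used in `transformation.py` (label, a tab, then the first part of the text) and the default `</s>` end-of-sequence marker. The helper names `build_prompt` and `clean_generation` are hypothetical and not part of the PR, and `clean_generation` splits on the first tab only, which is a small deviation from the PR code.

```python
import math
from typing import Optional


def build_prompt(text: str, label: Optional[str], prefix_ratio: float = 0.5) -> str:
    """Keep the leading share of the words; prepend `label` and a tab if a label is given."""
    words = text.split()
    prefix_length = math.ceil(len(words) * prefix_ratio)
    prefix = " ".join(words[:prefix_length])
    return prefix if label is None else f"{label}\t{prefix}"


def clean_generation(generated: str, eos: str = "</s>", labeled: bool = True) -> str:
    """Drop the label field, cut at the first EOS marker, and replace newlines with periods."""
    if labeled and "\t" in generated:
        # Assumption: split on the first tab only, so later tabs stay in the text.
        generated = generated.split("\t", 1)[1]
    eos_index = generated.find(eos) if eos else -1
    if eos_index > -1:
        generated = generated[:eos_index]
    return generated.strip().replace("\n", ".")


prompt = build_prompt("Andrew finally returned the French book", label="1")
print(repr(prompt))
print(clean_generation(prompt + " home.</s>ignored"))
```

Passing `label=None` corresponds to the unlabeled `gpt2-xl` path, where the prompt is just the text prefix.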
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .transformation import * |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| { | ||
| "type": "transformer_text_generation", | ||
| "test_cases": [ | ||
| { | ||
| "class": "TransformerTextGeneration", | ||
| "inputs": { | ||
| "sentence": "Andrew finally returned the French book to Chris that I bought last week" | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "sentence": "Andrew finally returned the French book to its native land with a masterful work of quiet power." | ||
| } | ||
| ] | ||
| }, | ||
| { | ||
| "class": "TransformerTextGeneration", | ||
| "inputs": { | ||
| "sentence": "Sentences with gapping, such as Paul likes coffee and Mary tea, lack an overt predicate to indicate the relation between two or more arguments." | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "sentence": "Sentences with gapping, such as Paul likes coffee and Mary tea, lack the punch of their titles." | ||
| } | ||
| ] | ||
| }, | ||
| { | ||
| "class": "TransformerTextGeneration", | ||
| "inputs": { | ||
| "sentence": "Alice in Wonderland is a 2010 American live-action/animated dark fantasy adventure film" | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "sentence": "Alice in Wonderland is a 2010 american musical that maintains an almost constant state of heightened alert." | ||
| } | ||
| ] | ||
| }, | ||
| { | ||
| "class": "TransformerTextGeneration", | ||
| "inputs": { | ||
| "sentence": "Ujjal Dev Dosanjh served as 33rd Premier of British Columbia from 2000 to 2001" | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "sentence": "Ujjal Dev Dosanjh served as 33rd Premier of Gujarat." | ||
| } | ||
| ] | ||
| }, | ||
| { | ||
| "class": "TransformerTextGeneration", | ||
| "inputs": { | ||
| "sentence": "Neuroplasticity is a continuous processing allowing short-term, medium-term, and long-term remodeling of the neuronosynaptic organization." | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "sentence": "Neuroplasticity is a continuous processing allowing short-term, medium-term, and long-range memory to play out simultaneously, without the need for episodic rewiring." | ||
| } | ||
| ] | ||
| } | ||
| ] | ||
| } |
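
One way to sanity-check test cases like these is to confirm that each output keeps at least the expected prefix of its input. The sketch below assumes the default `prefix_ratio` of 0.5 from `generate` and whitespace tokenization. The helper `shared_prefix_words` is hypothetical, and the single case is copied from the JSON above.

```python
import json
import math

# One test case copied from the JSON file above.
CASE = json.loads("""
{
  "class": "TransformerTextGeneration",
  "inputs": {"sentence": "Andrew finally returned the French book to Chris that I bought last week"},
  "outputs": [{"sentence": "Andrew finally returned the French book to its native land with a masterful work of quiet power."}]
}
""")


def shared_prefix_words(original: str, generated: str) -> int:
    """Count how many leading whitespace-separated words the two texts share."""
    count = 0
    for a, b in zip(original.split(), generated.split()):
        if a != b:
            break
        count += 1
    return count


source = CASE["inputs"]["sentence"]
expected_prefix = math.ceil(len(source.split()) * 0.5)  # assumed default prefix_ratio
for output in CASE["outputs"]:
    kept = shared_prefix_words(source, output["sentence"])
    print(f"kept {kept} words, expected at least {expected_prefix}")
```

The shared count can exceed the prefix length when the model happens to regenerate the original next word, so the check is a lower bound rather than an equality.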
167 changes: 167 additions & 0 deletions
transformations/transformer_text_generation/transformation.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,167 @@ | ||
| import logging | ||
| import math | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from transformers import GPT2LMHeadModel, pipeline | ||
|
|
||
| from interfaces.SentenceOperation import SentenceOperation | ||
| from tasks.TaskTypes import TaskType | ||
|
|
||
| logging.basicConfig( | ||
|     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | ||
|     datefmt="%m/%d/%Y %H:%M:%S", | ||
|     level=logging.INFO, | ||
| ) | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def set_seed(seed: int, no_cuda: bool): | ||
|     """Seed NumPy and PyTorch (and all CUDA devices unless no_cuda is set).""" | ||
|     np.random.seed(seed) | ||
|     torch.manual_seed(seed) | ||
|     if not no_cuda: | ||
|         torch.cuda.manual_seed_all(seed) | ||
|
|
||
|
|
||
| class TransformerTextGeneration(SentenceOperation): | ||
|     tasks = [ | ||
|         TaskType.TEXT_CLASSIFICATION, | ||
|         TaskType.TEXT_TO_TEXT_GENERATION, | ||
|         TaskType.TEXT_TAGGING, | ||
|     ] | ||
|     languages = ["en"] | ||
|
|
||
|     def __init__( | ||
|         self, | ||
|         eos: str = "</s>", | ||
|         no_cuda: bool = False, | ||
|         dataset="sst2", | ||
|         labeled=True, | ||
|         seed=42, | ||
|     ): | ||
|         set_seed(seed, no_cuda) | ||
|
|
||
|         if labeled: | ||
|             if dataset == "imdb": | ||
|                 model_text_generation: str = "jmamou/gpt2-medium-IMDB" | ||
|                 # model_sentiment_classification = 'aychang/roberta-base-imdb' | ||
|                 model_sentiment_classification = "textattack/roberta-base-imdb" | ||
|             elif dataset == "sst2": | ||
|                 model_text_generation: str = "jmamou/gpt2-medium-SST-2" | ||
|                 model_sentiment_classification = ( | ||
|                     "textattack/roberta-base-SST-2" | ||
|                 ) | ||
|             else: | ||
|                 # fail early instead of hitting a NameError further down | ||
|                 raise ValueError(f"unsupported labeled dataset: {dataset}") | ||
|         else: | ||
|             model_text_generation: str = "gpt2-xl" | ||
|             model_sentiment_classification = None | ||
|
|
||
|         self.eos = eos | ||
|         # use the first GPU only when CUDA is requested and actually available | ||
|         device = 0 if not no_cuda and torch.cuda.is_available() else -1 | ||
| | ||
|         # initialize text generation pipeline | ||
|         model = GPT2LMHeadModel.from_pretrained(model_text_generation) | ||
|         self.text_generator = pipeline( | ||
|             "text-generation", | ||
|             model=model, | ||
|             tokenizer=model_text_generation, | ||
|             device=device, | ||
|         ) | ||
| | ||
|         # if relevant, initialize the sentiment classification pipeline | ||
|         if model_sentiment_classification is not None: | ||
|             self.label_name = "label" | ||
|             self.text_classifier = pipeline( | ||
|                 "sentiment-analysis", | ||
|                 model=model_sentiment_classification, | ||
|                 tokenizer=model_sentiment_classification, | ||
|                 device=device, | ||
|             ) | ||
|         else: | ||
|             self.text_classifier = None | ||
|
|
||
|     def generate( | ||
|         self, | ||
|         sequence: str, | ||
|         num_return_sequences: int = 1, | ||
|         prefix_ratio: float = 0.5, | ||
|         max_length_factor=3, | ||
|         max_prefix_length=400, | ||
|         model_max_length=512, | ||
|         temperature: float = 1.0, | ||
|         repetition_penalty: float = 1.2, | ||
|         k: int = 0, | ||
|         p: float = 0.9, | ||
|     ): | ||
|
|
||
|         logger.info("original text: " + sequence) | ||
|         augmented_texts = [] | ||
|         sequence = sequence.split() | ||
| | ||
|         prefix_length = min( | ||
|             max_prefix_length, math.ceil(len(sequence) * prefix_ratio) | ||
|         ) | ||
|         if self.text_classifier is None: | ||
|             text_inputs = " ".join(sequence[0:prefix_length]) | ||
|         else: | ||
|             # classify the whole text; passing the word list would label each word separately | ||
|             label = self.text_classifier( | ||
|                 " ".join(sequence), truncation=True | ||
|             )[0][self.label_name] | ||
|             label = label.split("_")[1] | ||
|             text_inputs = label + "\t" + " ".join(sequence[0:prefix_length]) | ||
|
|
||
|         max_length = min(model_max_length, len(sequence) * max_length_factor) | ||
|         output_sequences = self.text_generator( | ||
|             text_inputs=text_inputs, | ||
|             temperature=temperature, | ||
|             top_k=k, | ||
|             top_p=p, | ||
|             repetition_penalty=repetition_penalty, | ||
|             do_sample=True, | ||
|             num_return_sequences=num_return_sequences, | ||
|             clean_up_tokenization_spaces=True, | ||
|             return_full_text=True, | ||
|             max_length=max_length, | ||
|             truncation=True, | ||
|         ) | ||
|         for seq in output_sequences: | ||
|             text = seq["generated_text"] | ||
|             if self.text_classifier is not None: | ||
|                 text = text.split("\t")[1] | ||
|             text = text[ | ||
|                 : text.find(self.eos) | ||
|                 if self.eos and text.find(self.eos) > -1 | ||
|                 else None | ||
|             ].strip() | ||
|             text = text.replace("\n", ".") | ||
|             augmented_texts.append(text) | ||
|             logger.info("augmented text: " + text) | ||
|         return augmented_texts | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|     import json | ||
| | ||
|     from TestRunner import convert_to_snake_case | ||
| | ||
|     tf = TransformerTextGeneration() | ||
|     test_cases = [] | ||
|     for sentence in [ | ||
|         "Andrew finally returned the French book to Chris that I bought last week", | ||
|         "Sentences with gapping, such as Paul likes coffee and Mary tea, lack an overt predicate to indicate the relation between two or more arguments.", | ||
|         "Alice in Wonderland is a 2010 American live-action/animated dark fantasy adventure film", | ||
|         "Ujjal Dev Dosanjh served as 33rd Premier of British Columbia from 2000 to 2001", | ||
|         "Neuroplasticity is a continuous processing allowing short-term, medium-term, and long-term remodeling of the neuronosynaptic organization.", | ||
|     ]: | ||
|         test_cases.append( | ||
|             { | ||
|                 "class": tf.name(), | ||
|                 "inputs": {"sentence": sentence}, | ||
|                 "outputs": [{"sentence": o} for o in tf.generate(sentence)], | ||
|             } | ||
|         ) | ||
|     json_file = { | ||
|         "type": convert_to_snake_case(tf.name()), | ||
|         "test_cases": test_cases, | ||
|     } | ||
|     print(json.dumps(json_file)) | ||
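
The length arithmetic in `generate` is easy to check in isolation. The sketch below reproduces the two `min(...)` expressions with the same default values. The function name `length_budget` is hypothetical and not part of the PR. Both quantities are computed from whitespace-separated word counts, whereas the `max_length` passed to the Hugging Face pipeline bounds the number of tokens (prompt plus continuation), so the word-based value is only an approximation of the real budget.

```python
import math
from typing import Tuple


def length_budget(
    text: str,
    prefix_ratio: float = 0.5,
    max_length_factor: int = 3,
    max_prefix_length: int = 400,
    model_max_length: int = 512,
) -> Tuple[int, int]:
    """Return (prefix_length, max_length) as computed from the word count of `text`."""
    n_words = len(text.split())
    # Number of leading words kept as the prompt, capped at max_prefix_length.
    prefix_length = min(max_prefix_length, math.ceil(n_words * prefix_ratio))
    # Upper bound handed to the generator, capped at the model's context size.
    max_length = min(model_max_length, n_words * max_length_factor)
    return prefix_length, max_length


short = "Andrew finally returned the French book to Chris that I bought last week"
print(length_budget(short))
print(length_budget(" ".join(["word"] * 1000)))
```

For the 13-word sentence both caps are inactive, while the 1000-word input hits both `max_prefix_length` and `model_max_length`.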