49 changes: 49 additions & 0 deletions examples/emotion/README.md
@@ -0,0 +1,49 @@
# Emotion

## Authors

**Armando Fortes**

Homepage: https://atfortes.github.io/

Contact: [email protected]

## Task Description

Emotion is a dataset of English Twitter messages labeled with six basic emotions: anger, fear, joy, love, sadness, and surprise. We follow the train-validation-test split configuration from [Hugging Face](https://huggingface.co/datasets/emotion): 16,000 samples for training, 2,000 for validation, and 2,000 for testing. The goal of the task is: given an English Twitter message, classify whether it shows sadness, joy, love, anger, fear, or surprise.
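
For a quick sanity check, the splits can be inspected directly with the `datasets` library (a minimal sketch, assuming the `emotion` dataset loads from the Hugging Face Hub as above):

```python
from datasets import load_dataset

# Load all three splits of the emotion dataset.
dataset = load_dataset("emotion")
print({split: len(dataset[split]) for split in dataset})
# {'train': 16000, 'validation': 2000, 'test': 2000}
print(dataset["train"][0])  # {'text': '...', 'label': 0}  (labels are 0-5)
```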

We perform prompt-based fine-tuning on the `glm-roberta-large` model and use prompt templates from [promptsource](https://github.com/bigscience-workshop/promptsource).
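
As an illustration, here is a minimal sketch of how a promptsource template verbalizes a sample (the sample shown is hypothetical; the template name matches the default used in `finetune.py`):

```python
from promptsource.templates import DatasetTemplates

prompt = DatasetTemplates("emotion")["select_emotion_label_from_list"]
sample = {"text": "i feel like celebrating today", "label": 1}  # hypothetical sample

inputs_pretokenized, target = prompt.apply(sample)  # templated input and verbalized label
choices = prompt.answer_choices.split(" ||| ")      # the six emotion labels
print(inputs_pretokenized)
print(choices)
```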

## Running Commands

You can run `python finetune.py --help` to see all supported configuration options. Running with the default configuration, as in the following command, reproduces the [reported results](#results).

```bash
python finetune.py
```
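
All defaults can be overridden on the command line (the flags below come from the argument parser in `finetune.py`); for example, to try a smaller batch size and a different learning rate:

```bash
python finetune.py --batch_size 8 --learning_rate 2e-5 --epoch_num 5
```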

## Results

Using the above command, the model checkpoint from the best-performing epoch on the validation set is used to measure performance on the test set. Accordingly, test accuracy for `glm-roberta-large` on the `emotion` dataset increases from **25.85%** before fine-tuning to **93.35%** after fine-tuning, with a corresponding validation accuracy of **94.45%**.

## Reference

```bibtex
@inproceedings{saravia-etal-2018-carer,
title = "{CARER}: Contextualized Affect Representations for Emotion Recognition",
author = "Saravia, Elvis and
Liu, Hsien-Chi Toby and
Huang, Yen-Hao and
Wu, Junlin and
Chen, Yi-Shin",
booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing",
month = oct # "-" # nov,
year = "2018",
address = "Brussels, Belgium",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/D18-1404",
doi = "10.18653/v1/D18-1404",
pages = "3687--3697",
abstract = "Emotions are expressed in nuanced ways, which varies by collective or individual experiences, knowledge, and beliefs. Therefore, to understand emotion, as conveyed through text, a robust mechanism capable of capturing and modeling different linguistic nuances and phenomena is needed. We propose a semi-supervised, graph-based algorithm to produce rich structural descriptors which serve as the building blocks for constructing contextualized affect representations from text. The pattern-based representations are further enriched with word embeddings and evaluated through several emotion recognition tasks. Our experimental results demonstrate that the proposed method outperforms state-of-the-art techniques on emotion recognition tasks.",
}
```
35 changes: 35 additions & 0 deletions examples/emotion/dataset.py
@@ -0,0 +1,35 @@
import torch
from tqdm import tqdm
from datasets import load_dataset
from promptsource.templates import DatasetTemplates


class MultipleChoiceDataset(torch.utils.data.Dataset):
def __init__(self, dataset_name, split, prompt_name, tokenizer):
super(MultipleChoiceDataset, self).__init__()
self.dataset_name = dataset_name
self.split = split
self.prompt = DatasetTemplates(self.dataset_name)[prompt_name]
self.tokenizer = tokenizer

self.data = []
if '/' in self.dataset_name:
iters = load_dataset(self.dataset_name.split('/')[0], self.dataset_name.split('/')[1], split=self.split)
else:
iters = load_dataset(self.dataset_name, split=self.split)
for sample in tqdm(iters):
self.data.append(dict(zip(
['inputs_pretokenized', 'choices_pretokenized', 'label'],
self.prompting_single_sample(sample)
)))

def prompting_single_sample(self, sample):
inputs_pretokenized, _ = tuple(self.prompt.apply(sample))
choices_pretokenized = self.prompt.answer_choices.split(' ||| ')
return inputs_pretokenized + f" {self.tokenizer.mask_token}", choices_pretokenized, sample['label']

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index]
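

# Minimal usage sketch (commented out; assumes a GLM tokenizer exposing
# `mask_token`, e.g. the one loaded in finetune.py):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('BAAI/glm-roberta-large', trust_remote_code=True)
#   ds = MultipleChoiceDataset('emotion', 'validation', 'select_emotion_label_from_list', tokenizer)
#   print(ds[0]['inputs_pretokenized'])  # prompt text ending in the mask token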
30 changes: 30 additions & 0 deletions examples/emotion/eval_utils.py
@@ -0,0 +1,30 @@
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from multiple_choice_utils import cond_log_prob, flatten_labels


def evaluate(model, tokenizer, data_loader, split):
valid_loss = 0.
valid_labels = []
valid_preds = []

model.eval()

with torch.no_grad():
for _, sample in tqdm(enumerate(data_loader, start=1), desc=split, total=len(data_loader)):
logits = cond_log_prob(model, tokenizer, sample["inputs_pretokenized"], flatten_labels(sample['choices_pretokenized']))
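            # `logits` holds one log-probability per answer choice, so nll_loss
            # over them amounts to cross-entropy on the verbalized label choices.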

labels = sample["label"].cuda()
loss = F.nll_loss(logits, labels)
valid_loss += loss.item()
valid_preds.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist())
valid_labels.extend(np.array(sample["label"]).tolist())

valid_loss = valid_loss / len(data_loader)
    valid_acc = accuracy_score(valid_labels, valid_preds)  # accuracy_score expects (y_true, y_pred)
    print(f"[{split.upper()}] loss={valid_loss:.4f}, acc={valid_acc:.4f}")

return valid_loss, valid_acc
61 changes: 61 additions & 0 deletions examples/emotion/finetune.py
@@ -0,0 +1,61 @@
import torch
import argparse
import warnings
from train_utils import train
from eval_utils import evaluate
from dataset import MultipleChoiceDataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-mt', '--model_type', type=str, default='BAAI/glm-roberta-large')
parser.add_argument('-dn', '--dataset_name', type=str, default='emotion')
parser.add_argument('-pn', '--prompt_name', type=str, default='select_emotion_label_from_list')
parser.add_argument('-bs', '--batch_size', type=int, default=16)
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-5)
parser.add_argument('-en', '--epoch_num', type=int, default=10)
parser.add_argument('-es', '--early_stopping', type=int, default=2)
parser.add_argument('-cd', '--ckpt_dir', type=str, default='./')
args = parser.parse_args()
print(args)

# Load model
tokenizer = AutoTokenizer.from_pretrained(args.model_type, trust_remote_code=True, revision='main')
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_type, trust_remote_code=True, revision='main').cuda()

# Load data
train_dataset = MultipleChoiceDataset(args.dataset_name, 'train', args.prompt_name, tokenizer)
valid_dataset = MultipleChoiceDataset(args.dataset_name, 'validation', args.prompt_name, tokenizer)
test_dataset = MultipleChoiceDataset(args.dataset_name, 'test', args.prompt_name, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

# Configure training model, optimizer, and scheduler
model = model.float()
model.train()
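    # With drop_last=True, each epoch runs len(train_dataset) // batch_size steps;
    # warm up over the first 6% of all steps, then decay the LR linearly to zero.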
num_training_steps = args.epoch_num * (len(train_dataset) // args.batch_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
scheduler = get_linear_schedule_with_warmup(optimizer,
num_warmup_steps=int(num_training_steps * 0.06),
num_training_steps=num_training_steps)

print('Performance on test set BEFORE fine-tuning:')
evaluate(model, tokenizer, test_loader, 'test')

print('TRAINING...')
ckpt_path = args.ckpt_dir + \
f"{args.model_type.split('/')[1] if '/' in args.model_type else args.model_type}-" + \
f"{args.dataset_name.split('/')[1] if '/' in args.dataset_name else args.dataset_name}.ckpt"
model = train(model, tokenizer, train_loader, valid_loader, optimizer, scheduler, ckpt_path,
args.epoch_num, args.early_stopping)

print('Performance on test set AFTER fine-tuning:')
evaluate(model, tokenizer, test_loader, 'test')

if __name__ == '__main__':
warnings.filterwarnings('ignore')
main()
121 changes: 121 additions & 0 deletions examples/emotion/multiple_choice_utils.py
@@ -0,0 +1,121 @@
'''
Acknowledgement: Code adapted from Aohan Zeng and Xiao Liu.
'''

import torch
import numpy as np
import torch.nn.functional as F
from typing import List
from scipy.linalg import block_diag


def flatten_labels(compacted_labels):
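    # The default collate turns a batch of per-sample choice lists into a
    # [num_classes][batch_size] structure; transpose it back to one
    # choice list per sample.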
batch_size = len(compacted_labels[0])
num_of_classes = len(compacted_labels)
return [[compacted_labels[i][idx] for i in range(num_of_classes)] for idx in range(batch_size)]


def build_multiple_choice_sample(tokenizer, context, choices):
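    # GLM-style blank infilling: the context attends to itself bidirectionally,
    # each candidate choice is appended as a causally-masked block that can
    # attend to the whole context, and all choice tokens share the mask token's
    # position (block_position_id numbers tokens within the blank).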
context_id = tokenizer(context)['input_ids']

division = len(context_id)
mask_position = context_id.index(tokenizer.mask_token_id)

token = np.array(context_id, dtype=np.int64)
attention_mask = [np.ones((division, division), dtype=np.int64)]
position_id = np.arange(division, dtype=np.int64)
block_position_id = np.zeros(division, dtype=np.int64)

choice_target_id = []
choice_id = []

for choice_str in choices:
choice = np.array(tokenizer(choice_str)['input_ids'][1:-1], dtype=np.int64)

choice_id.append(choice)
choice_target_id.append(np.arange(len(token), len(token) + len(choice), dtype=np.int64))
attention_mask.append(np.tril(np.ones((len(choice), len(choice)), dtype=np.int64)))

token = np.concatenate((token, [tokenizer.sop_token_id], choice[:-1]))
position_id = np.concatenate((position_id, [mask_position] * len(choice)))
block_position_id = np.concatenate((block_position_id, np.arange(1, 1 + len(choice), dtype=np.int64)))

attention_mask = block_diag(*attention_mask)
attention_mask[division:, :division] = 1

return {
"token": token,
"position_id": np.stack((position_id, block_position_id)),
"attention_mask": attention_mask,
"choices": choice_id,
"choice_target_ids": choice_target_id
}


def pad_batch(tokens, position_ids, attention_mask, max_seq_length):
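    # Pad everything to max_seq_length; pad_width=((0, pad_length),) broadcasts
    # over both axes of the square attention mask.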
pad_length = max_seq_length - len(tokens)
attention_mask = np.pad(
attention_mask,
pad_width=((0, pad_length),),
mode="constant",
constant_values=0,
)
tokens = np.concatenate((tokens, np.zeros(pad_length, dtype=np.int64)))
position_ids = np.concatenate((position_ids, position_ids[..., -1:].repeat(pad_length, -1)), axis=-1)
return tokens, position_ids, attention_mask


def collate_fn(samples):
TILE = 16
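    # Round the shared padded length up to a multiple of TILE for friendlier kernel shapes.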
length_to_pad = (max(map(lambda spl: len(spl["token"]), samples)) + TILE - 1) // TILE * TILE

token_batch, position_id_batch, attention_mask_batch = [], [], []
choices_batch, choice_target_ids_batch = [], []

for sample in samples:
token, position_id, attention_mask = pad_batch(
sample["token"], sample["position_id"], sample["attention_mask"], length_to_pad
)
token_batch.append(token)
position_id_batch.append(position_id)
attention_mask_batch.append(attention_mask)
choices_batch.append(sample["choices"])
choice_target_ids_batch.append(sample["choice_target_ids"])

return {
"tokens": torch.tensor(np.array(token_batch), dtype=torch.int64),
"position_ids": torch.tensor(np.array(position_id_batch), dtype=torch.int64),
"attention_mask": torch.tensor(np.array(attention_mask_batch), dtype=torch.int64),
"choices": choices_batch,
"choice_target_ids": choice_target_ids_batch,
}


def cond_log_prob(model, tokenizer, context: List[str], choices: List[List[str]]) -> List[List[float]]:
"""
    Compute the conditional log-probability for one or more continuation/infilling options.
    :return: The log-probability of each option.
"""
if not isinstance(context, list):
context = [context]
choices = [choices]
choices = [[(' ' + choice) for choice in choice_pair] for choice_pair in choices] # Feature of SentencePiece tokenizer

samples = [build_multiple_choice_sample(tokenizer, ctx, ch) for ctx, ch in zip(context, choices)]

batch = collate_fn(samples)

logits = model.forward(input_ids=batch['tokens'].cuda(),
attention_mask=batch['attention_mask'].cuda().unsqueeze(1),
position_ids=batch['position_ids'].cuda())['logits']

log_probs = []
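    # For each sample, gather the log-probability of every choice token at its
    # target position and sum them: one total log-probability per choice.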

for output, choices, choice_target_ids in zip(F.log_softmax(logits, dim=-1), batch['choices'], batch['choice_target_ids']):
log_probs_single = []
for choice, choice_target_id in zip(choices, choice_target_ids):
tmp = output[choice_target_id, choice]
log_probs_single.append(tmp.sum())
log_probs.append(torch.stack(log_probs_single))

return torch.stack(log_probs)
7 changes: 7 additions & 0 deletions examples/emotion/requirements.txt
@@ -0,0 +1,7 @@
torch
transformers
scipy
datasets
promptsource
scikit-learn
sentencepiece
tqdm
53 changes: 53 additions & 0 deletions examples/emotion/train_utils.py
@@ -0,0 +1,53 @@
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from eval_utils import evaluate
from multiple_choice_utils import cond_log_prob, flatten_labels


def train(model, tokenizer, train_loader, valid_loader, optimizer, scheduler, ckpt_path, epoch_num, early_stopping=-1):

best_acc = 0.
early_stopping_counter = early_stopping

for e in range(1, epoch_num + 1):
print(f"EPOCH {e}")
train_loss_value = 0.
tqdm_vars = {"lr": np.nan, "loss": np.nan}
tbar = tqdm(enumerate(train_loader, start=1), desc="train", total=len(train_loader),
postfix=tqdm_vars)

model.train()

for _, sample in tbar:
logits = cond_log_prob(model, tokenizer, sample["inputs_pretokenized"], flatten_labels(sample['choices_pretokenized']))
labels = sample["label"].cuda()
loss = F.nll_loss(logits, labels)
train_loss_value += loss.item()

loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()

tqdm_vars["lr"] = optimizer.state_dict()["param_groups"][0]["lr"]
tqdm_vars["loss"] = train_loss_value
tbar.set_postfix(tqdm_vars)
train_loss_value = 0.

_, valid_acc = evaluate(model, tokenizer, valid_loader, 'valid')

        # Always checkpoint the best validation accuracy, so the torch.load
        # below returns the best model even when early stopping is disabled.
        if valid_acc > best_acc:
            best_acc = valid_acc
            early_stopping_counter = early_stopping
            torch.save(model, ckpt_path)
        elif early_stopping >= 0:
            early_stopping_counter -= 1
            if early_stopping_counter <= 0:
                print('EARLY STOPPING...')
                break

return torch.load(ckpt_path)