6 changes: 4 additions & 2 deletions nlpaug/augmenter/word/context_word_embs.py
@@ -32,14 +32,16 @@ def init_context_word_embs_model(model_path, model_type, device, force_reload=Fa
             model = nml.Roberta(model_path, device=device, top_k=top_k, silence=silence, batch_size=batch_size)
         elif model_type == 'bert':
             model = nml.Bert(model_path, device=device, top_k=top_k, silence=silence, batch_size=batch_size)
+        elif model_type == 'electra':
+            model = nml.Electra(model_path, device=device, top_k=top_k, silence=silence, batch_size=batch_size)
         else:
-            raise ValueError('Model type value is unexpected. Only support bert and roberta models.')
+            raise ValueError('Model type value is unexpected. Only support bert, electra, and roberta models.')
     else:
         if model_type in ['distilbert', 'bert', 'roberta', 'bart']:
             model = nml.FmTransformers(model_path, model_type=model_type, device=device, batch_size=batch_size,
                 top_k=top_k, silence=silence)
         else:
-            raise ValueError('Model type value is unexpected. Only support bert and roberta models.')
+            raise ValueError('Model type value is unexpected. Only support bert, electra, and roberta models.')

     CONTEXT_WORD_EMBS_MODELS[model_name] = model
     return model
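For context, a minimal usage sketch of the new branch, assuming the public ContextualWordEmbsAug augmenter accepts model_type='electra' in the same way it accepts the existing types; the checkpoint name and sample sentence are illustrative only and are not part of this diff.

import nlpaug.augmenter.word as naw

# Assumed usage: substitution should now route through the 'electra' branch added above.
aug = naw.ContextualWordEmbsAug(
    model_path='google/electra-small-discriminator',  # default path of the new Electra class
    model_type='electra',
    action='substitute')
print(aug.augment('The quick brown fox jumps over the lazy dog'))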
3 changes: 2 additions & 1 deletion nlpaug/model/lang_models/__init__.py
@@ -8,8 +8,9 @@
 from nlpaug.model.lang_models.fairseq import *
 from nlpaug.model.lang_models.t5 import *
 from nlpaug.model.lang_models.bart import *
+from nlpaug.model.lang_models.electra import *
 from nlpaug.model.lang_models.fill_mask_transformers import *
 from nlpaug.model.lang_models.machine_translation_transformers import *
 from nlpaug.model.lang_models.summarization_transformers import *
 from nlpaug.model.lang_models.lambada import *
-from nlpaug.model.lang_models.text_generation_transformers import *
+from nlpaug.model.lang_models.text_generation_transformers import *
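Since context_word_embs.py refers to the model as nml.Electra, the wildcard import above is what makes the class reachable from the package namespace; a quick, assumed sanity check:

import nlpaug.model.lang_models as nml

# The dispatcher in context_word_embs.py calls nml.Electra, so the new
# wildcard import must expose the class at package level.
assert hasattr(nml, 'Electra')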
121 changes: 121 additions & 0 deletions nlpaug/model/lang_models/electra.py
@@ -0,0 +1,121 @@

import logging

try:
    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer
except ImportError:
    # No installation required if not using this function
    pass

from nlpaug.model.lang_models import LanguageModels
from nlpaug.util.selection.filtering import *


class Electra(LanguageModels):
    # https://arxiv.org/abs/2003.10555
    START_TOKEN = '[CLS]'
    SEPARATOR_TOKEN = '[SEP]'
    MASK_TOKEN = '[MASK]'
    PAD_TOKEN = '[PAD]'
    UNKNOWN_TOKEN = '[UNK]'
    SUBWORD_PREFIX = '##'

    def __init__(self, model_path='google/electra-small-discriminator', temperature=1.0, top_k=None, top_p=None, batch_size=32,
                 device='cuda', silence=True):
        super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p, batch_size=batch_size, silence=silence)
        try:
            from transformers import AutoModelForMaskedLM, AutoTokenizer
        except ModuleNotFoundError:
            raise ModuleNotFoundError('Missed transformers library. Install transformers by `pip install transformers`')

        self.model_path = model_path

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.mask_id = self.token2id(self.MASK_TOKEN)
        self.pad_id = self.token2id(self.PAD_TOKEN)
        if silence:
            # Transformers throws a warning regarding weight initialization. It is expected.
            orig_log_level = logging.getLogger('transformers.' + 'modeling_utils').getEffectiveLevel()
            logging.getLogger('transformers.' + 'modeling_utils').setLevel(logging.ERROR)
            self.model = AutoModelForMaskedLM.from_pretrained(model_path)
            logging.getLogger('transformers.' + 'modeling_utils').setLevel(orig_log_level)
        else:
            self.model = AutoModelForMaskedLM.from_pretrained(model_path)

        self.model.to(self.device)
        self.model.eval()

    def get_max_num_token(self):
        return self.model.config.max_position_embeddings - 2 * 5

    def is_skip_candidate(self, candidate):
        return candidate.startswith(self.SUBWORD_PREFIX)

    def token2id(self, token):
        # Issue 181: TokenizerFast has convert_tokens_to_ids but not _convert_token_to_id
        if 'TokenizerFast' in self.tokenizer.__class__.__name__:
            # New transformers API
            return self.tokenizer.convert_tokens_to_ids(token)
        else:
            # Old transformers API
            return self.tokenizer._convert_token_to_id(token)

    def id2token(self, _id):
        return self.tokenizer._convert_id_to_token(_id)

    def get_model(self):
        return self.model

    def get_tokenizer(self):
        return self.tokenizer

    def get_subword_prefix(self):
        return self.SUBWORD_PREFIX

    def get_mask_token(self):
        return self.MASK_TOKEN

    def predict(self, texts, target_words=None, n=1):
        results = []
        # Prepare inputs
        for i in range(0, len(texts), self.batch_size):
            token_inputs = [self.tokenizer.encode(text) for text in texts[i:i+self.batch_size]]
            if target_words is None:
                target_words = [None] * len(token_inputs)

            # Pad token
            max_token_size = max([len(t) for t in token_inputs])
            for j, token_input in enumerate(token_inputs):
                for _ in range(max_token_size - len(token_input)):
                    token_inputs[j].append(self.pad_id)

            target_poses = []
            for tokens in token_inputs:
                target_poses.append(tokens.index(self.mask_id))
            # segment_inputs = [[0] * len(tokens) for tokens in token_inputs]
            mask_inputs = [[1] * len(tokens) for tokens in token_inputs]  # 1: real token, 0: padding token

            # Convert to feature
            token_inputs = torch.tensor(token_inputs).to(self.device)
            # segment_inputs = torch.tensor(segment_inputs).to(self.device)
            mask_inputs = torch.tensor(mask_inputs).to(self.device)

            # Prediction
            with torch.no_grad():
                outputs = self.model(input_ids=token_inputs, attention_mask=mask_inputs)

            # Selection
            for output, target_pos, target_token in zip(outputs[0], target_poses, target_words):
                target_token_logits = output[target_pos]

                seed = {'temperature': self.temperature, 'top_k': self.top_k, 'top_p': self.top_p}
                target_token_logits = self.control_randomness(target_token_logits, seed)
                target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed)
                if len(target_token_idxes) != 0:
                    new_tokens = self.pick(target_token_logits, target_token_idxes, target_word=target_token, n=10)
                    results.append([t[0] for t in new_tokens])
                else:
                    results.append([''])

        return results
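To exercise the new class directly, outside the augmenter, here is a hedged sketch of a predict call. The checkpoint, device, and sentence are assumptions; the ELECTRA generator checkpoint is chosen because it is trained with a masked-language-modelling head.

import nlpaug.model.lang_models as nml

# Assumed checkpoint: the small ELECTRA generator, which ships with a masked-LM head.
model = nml.Electra(model_path='google/electra-small-generator', top_k=10, device='cpu')

# predict() expects each input text to already contain the mask token.
candidates = model.predict(['The quick brown [MASK] jumps over the lazy dog.'])
print(candidates[0])  # candidate replacements for the masked position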