diff --git a/nlpaug/augmenter/word/context_word_embs.py b/nlpaug/augmenter/word/context_word_embs.py
index 1a4e9fd..84afd75 100755
--- a/nlpaug/augmenter/word/context_word_embs.py
+++ b/nlpaug/augmenter/word/context_word_embs.py
@@ -32,14 +32,16 @@ def init_context_word_embs_model(model_path, model_type, device, force_reload=Fa
             model = nml.Roberta(model_path, device=device, top_k=top_k, silence=silence, batch_size=batch_size)
         elif model_type == 'bert':
             model = nml.Bert(model_path, device=device, top_k=top_k, silence=silence, batch_size=batch_size)
+        elif model_type == 'electra':
+            model = nml.Electra(model_path, device=device, top_k=top_k, silence=silence, batch_size=batch_size)
         else:
-            raise ValueError('Model type value is unexpected. Only support bert and roberta models.')
+            raise ValueError('Model type value is unexpected. Only support bert, electra, and roberta models.')
     else:
         if model_type in ['distilbert', 'bert', 'roberta', 'bart']:
             model = nml.FmTransformers(model_path, model_type=model_type, device=device, batch_size=batch_size,
                 top_k=top_k, silence=silence)
         else:
-            raise ValueError('Model type value is unexpected. Only support bert and roberta models.')
+            raise ValueError('Model type value is unexpected. Only support bert, electra, and roberta models.')
 
     CONTEXT_WORD_EMBS_MODELS[model_name] = model
     return model
diff --git a/nlpaug/model/lang_models/__init__.py b/nlpaug/model/lang_models/__init__.py
index d283177..9c07662 100755
--- a/nlpaug/model/lang_models/__init__.py
+++ b/nlpaug/model/lang_models/__init__.py
@@ -8,8 +8,9 @@
 from nlpaug.model.lang_models.fairseq import *
 from nlpaug.model.lang_models.t5 import *
 from nlpaug.model.lang_models.bart import *
+from nlpaug.model.lang_models.electra import *
 from nlpaug.model.lang_models.fill_mask_transformers import *
 from nlpaug.model.lang_models.machine_translation_transformers import *
 from nlpaug.model.lang_models.summarization_transformers import *
 from nlpaug.model.lang_models.lambada import *
-from nlpaug.model.lang_models.text_generation_transformers import *
\ No newline at end of file
+from nlpaug.model.lang_models.text_generation_transformers import *
diff --git a/nlpaug/model/lang_models/electra.py b/nlpaug/model/lang_models/electra.py
new file mode 100644
index 0000000..8166e30
--- /dev/null
+++ b/nlpaug/model/lang_models/electra.py
@@ -0,0 +1,121 @@
+
+import logging
+
+try:
+    import torch
+    from transformers import AutoModelForMaskedLM, AutoTokenizer
+except ImportError:
+    # No installation required if not using this function
+    pass
+
+from nlpaug.model.lang_models import LanguageModels
+from nlpaug.util.selection.filtering import *
+
+
+class Electra(LanguageModels):
+    # https://arxiv.org/abs/2003.10555
+    START_TOKEN = '[CLS]'
+    SEPARATOR_TOKEN = '[SEP]'
+    MASK_TOKEN = '[MASK]'
+    PAD_TOKEN = '[PAD]'
+    UNKNOWN_TOKEN = '[UNK]'
+    SUBWORD_PREFIX = '##'
+
+    def __init__(self, model_path='google/electra-small-discriminator', temperature=1.0, top_k=None, top_p=None, batch_size=32,
+                 device='cuda', silence=True):
+        super().__init__(device, temperature=temperature, top_k=top_k, top_p=top_p, batch_size=batch_size, silence=silence)
+        try:
+            from transformers import AutoModelForMaskedLM, AutoTokenizer
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError('Missing transformers library. Install transformers by `pip install transformers`')
+
+        self.model_path = model_path
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.mask_id = self.token2id(self.MASK_TOKEN)
+        self.pad_id = self.token2id(self.PAD_TOKEN)
+        if silence:
+            # Transformers throws a warning regarding weight initialization. It is expected.
+            orig_log_level = logging.getLogger('transformers.' + 'modeling_utils').getEffectiveLevel()
+            logging.getLogger('transformers.' + 'modeling_utils').setLevel(logging.ERROR)
+            self.model = AutoModelForMaskedLM.from_pretrained(model_path)
+            logging.getLogger('transformers.' + 'modeling_utils').setLevel(orig_log_level)
+        else:
+            self.model = AutoModelForMaskedLM.from_pretrained(model_path)
+
+        self.model.to(self.device)
+        self.model.eval()
+
+    def get_max_num_token(self):
+        return self.model.config.max_position_embeddings - 2 * 5
+
+    def is_skip_candidate(self, candidate):
+        return candidate.startswith(self.SUBWORD_PREFIX)
+
+    def token2id(self, token):
+        # Issue 181: TokenizerFast has convert_tokens_to_ids but not _convert_token_to_id
+        if 'TokenizerFast' in self.tokenizer.__class__.__name__:
+            # New transformers API
+            return self.tokenizer.convert_tokens_to_ids(token)
+        else:
+            # Old transformers API
+            return self.tokenizer._convert_token_to_id(token)
+
+    def id2token(self, _id):
+        return self.tokenizer._convert_id_to_token(_id)
+
+    def get_model(self):
+        return self.model
+
+    def get_tokenizer(self):
+        return self.tokenizer
+
+    def get_subword_prefix(self):
+        return self.SUBWORD_PREFIX
+
+    def get_mask_token(self):
+        return self.MASK_TOKEN
+
+    def predict(self, texts, target_words=None, n=1):
+        results = []
+        # Prepare inputs
+        for i in range(0, len(texts), self.batch_size):
+            token_inputs = [self.tokenizer.encode(text) for text in texts[i:i+self.batch_size]]
+            if target_words is None:
+                target_words = [None] * len(token_inputs)
+
+            # Pad token
+            max_token_size = max([len(t) for t in token_inputs])
+            for j, token_input in enumerate(token_inputs):
+                for _ in range(max_token_size - len(token_input)):
+                    token_inputs[j].append(self.pad_id)
+
+            target_poses = []
+            for tokens in token_inputs:
+                target_poses.append(tokens.index(self.mask_id))
+            # segment_inputs = [[0] * len(tokens) for tokens in token_inputs]
+            mask_inputs = [[1] * len(tokens) for tokens in token_inputs]  # 1: real token, 0: padding token
+
+            # Convert to feature
+            token_inputs = torch.tensor(token_inputs).to(self.device)
+            # segment_inputs = torch.tensor(segment_inputs).to(self.device)
+            mask_inputs = torch.tensor(mask_inputs).to(self.device)
+
+            # Prediction
+            with torch.no_grad():
+                outputs = self.model(input_ids=token_inputs, attention_mask=mask_inputs)
+
+            # Selection
+            for output, target_pos, target_token in zip(outputs[0], target_poses, target_words):
+                target_token_logits = output[target_pos]
+
+                seed = {'temperature': self.temperature, 'top_k': self.top_k, 'top_p': self.top_p}
+                target_token_logits = self.control_randomness(target_token_logits, seed)
+                target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed)
+                if len(target_token_idxes) != 0:
+                    new_tokens = self.pick(target_token_logits, target_token_idxes, target_word=target_token, n=10)
+                    results.append([t[0] for t in new_tokens])
+                else:
+                    results.append([''])
+
+        return results
\ No newline at end of file
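
For reference, a minimal usage sketch of the new model type through the public augmenter API. This is an illustrative example, not part of the patch: it assumes `ContextualWordEmbsAug` (the augmenter backed by `context_word_embs.py`) accepts `model_path`, `model_type`, `action`, and `device` and forwards them to `init_context_word_embs_model` as shown in the first hunk.

import nlpaug.augmenter.word as naw

# Illustrative only: with this patch, 'electra' is accepted as a model_type
# and is routed to nml.Electra, whose default checkpoint is
# 'google/electra-small-discriminator'.
aug = naw.ContextualWordEmbsAug(
    model_path='google/electra-small-discriminator',
    model_type='electra',
    action='substitute',
    device='cpu')

print(aug.augment('The quick brown fox jumps over the lazy dog'))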