diff --git a/examples/sequence-to-sequence/seq2seq_translator/README.md b/examples/sequence-to-sequence/seq2seq_translator/README.md new file mode 100644 index 000000000..f641cfbd8 --- /dev/null +++ b/examples/sequence-to-sequence/seq2seq_translator/README.md @@ -0,0 +1,151 @@ +# Seq2seq Translator Benchmarks + +Here is the comparison between Dynet and PyTorch on the [seq2seq translator example](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html). + +The data we used is a set of many thousands of English to French translation pairs. Download the data from [here](https://download.pytorch.org/tutorial/data.zip) and extract it to the current directory. + +## Usage (DyNet) + +The architecture of the Dynet model `seq2seq_dynet.py` is the same as that in [PyTorch Example](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html). We here implement the attention mechanism in the model. + +Install the GPU version of Dynet according to the instructions on the [official website](http://dynet.readthedocs.io/en/latest/python.html#installing-a-cutting-edge-and-or-gpu-version). + +Then, run the training: + + python seq2seq_dynet.py --dynet_gpus 1 + +## Usage (PyTorch) + +The code of `seq2seq_pytorch.py` follows the same line in [PyTorch Example](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html). + +Install CUDA version of PyTorch according to the instructions on the [official website](http://pytorch.org/). + +Then, run the training: + + python seq2seq_pytorch.py + +## Performance + +We run our codes on a desktop with NVIDIA TITAN X. We here have D stands for Dynet and P stands for PyTorch. + +| Time (D) | Iteration (D) | Loss (D) | Time (P) | Iteration (P) | Loss (P)| +| --- | --- | --- | --- | --- | --- | +| 0m 0s | 0% | 7.9808 | 0m 0s | 0% | 7.9615 | +| 0m 28s | 5000 5% | 3.2687 | 1m 30s | 5000 5% | 2.8794 | +| 0m 56s | 10000 10% | 2.6397 | 2m 55s | 10000 10% | 2.3103 | +| 1m 25s | 15000 15% | 2.3537 | 4m 5s | 15000 15% | 1.9939 | +| 1m 54s | 20000 20% | 2.1538 | 5m 16s | 20000 20% | 1.7537 | +| 2m 22s | 25000 25% | 1.9636 | 6m 27s | 25000 25% | 1.5796 | +| 2m 51s | 30000 30% | 1.8166 | 7m 39s | 30000 30% | 1.3795 | +| 3m 20s | 35000 35% | 1.6305 | 9m 13s | 35000 35% | 1.2712 | +| 3m 49s | 40000 40% | 1.5026 | 10m 31s | 40000 40% | 1.1374 | +| 4m 18s | 45000 45% | 1.4049 | 11m 41s | 45000 45% | 1.0215 | +| 4m 47s | 50000 50% | 1.2827 | 12m 53s | 50000 50% | 0.9307 | +| 5m 17s | 55000 55% | 1.2299 | 14m 5s | 55000 55% | 0.8312 | +| 5m 46s | 60000 60% | 1.1067 | 15m 17s | 60000 60% | 0.7879 | +| 6m 15s | 65000 65% | 1.0442 | 16m 48s | 65000 65% | 0.7188 | +| 6m 44s | 70000 70% | 0.9789 | 18m 6s | 70000 70% | 0.6532 | +| 7m 13s | 75000 75% | 0.8694 | 19m 18s | 75000 75% | 0.6273 | +| 7m 43s | 80000 80% | 0.8219 | 20m 34s | 80000 80% | 0.6021 | +| 8m 12s | 85000 85% | 0.7621 | 21m 44s | 85000 85% | 0.5210 | +| 8m 41s | 90000 90% | 0.7453 | 22m 55s | 90000 90% | 0.5054 | +| 9m 10s | 95000 95% | 0.6795 | 24m 9s | 95000 95% | 0.4417 | +| 9m 39s | 100000 100% | 0.6442 | 25m 24s | 100000 100% | 0.4297 | + +We then show some evaluation results as follows. + +Format: + +
+> input 
+= target 
+< output
+
+ +### Dynet + +``` +> elle est convaincue de mon innocence . += she is convinced of my innocence . +< she is convinced of my innocence . + +> je ne suis pas folle . += i m not crazy . +< i m not mad . + +> je suis ruinee . += i m ruined . +< i m ruined . + +> je ne suis certainement pas ton ami . += i m certainly not your friend . +< i m not your best your friend . + +> c est un pleurnichard comme toujours . += he s a crybaby just like always . +< he s a little nothing . + +> je suis sure qu elle partira tot . += i m sure she ll leave early . +< i m sure she ll leave early . + +> vous etes toujours vivantes . += you re still alive . +< you re still alive . + +> nous n avons pas encore tres faim . += we aren t very hungry yet . +< we re not not desperate . + +> vous n etes pas encore morts . += you re not dead yet . +< you re not dead yet . + +> nous sommes coinces . += we re stuck . +< we re stuck . +``` + +### PyTorch + +``` +> il est deja marie . += he s already married . +< he s already married . + +> on le dit decede . += he is said to have died . +< he are said to have died . + +> il est trop saoul . += he s too drunk . +< he s too drunk . + +> je suis assez heureux . += i m happy enough . +< i m happy happy . + +> je n y suis pas interessee . += i m not interested in that . +< i m not interested in that . + +> il a huit ans . += he s eight years old . +< he is thirty . + +> je ne suis pas differente . += i m no different . +< i m no different . + +> je suis heureux que vous l ayez aime . += i m happy you liked it . +< i m happy you liked it . + +> ils peuvent chanter . += they re able to sing . +< they re able to sing . + +> vous etes tellement belle dans cette robe ! += you re so beautiful in that dress . +< you re so beautiful in that dress . +``` diff --git a/examples/sequence-to-sequence/seq2seq_translator/seq2seq_dynet.py b/examples/sequence-to-sequence/seq2seq_translator/seq2seq_dynet.py new file mode 100644 index 000000000..fb62589ec --- /dev/null +++ b/examples/sequence-to-sequence/seq2seq_translator/seq2seq_dynet.py @@ -0,0 +1,367 @@ +# Requirements + +from __future__ import unicode_literals, print_function, division +import io +import unicodedata +import re +import random +import dynet as dy +import time +import math +r = random.SystemRandom() + +# Data Preparation + +SOS_token = 0 +EOS_token = 1 + + +class Lang(object): + + def __init__(self, name): + self.name = name + self.word2index = {} + self.word2count = {} + self.index2word = {0: "SOS", 1: "EOS"} + self.n_words = 2 + + def addSentence(self, sentence): + for word in sentence.split(" "): + self.addWord(word) + + def addWord(self, word): + if word not in self.word2index: + self.word2index[word] = self.n_words + self.word2count[word] = 1 + self.index2word[self.n_words] = word + self.n_words += 1 + else: + self.word2count[word] += 1 + + +def unicodeToAscii(s): + + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) + + +def normalizeString(s): + + s = unicodeToAscii(s.lower().strip()) + s = re.sub(r"([.!?])", r" \1", s) + s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) + return s + + +def readLangs(lang1, lang2, reverse=False): + + print("Reading lines...") + lines = io.open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\ + read().strip().split('\n') + pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines] + if reverse: + pairs = [list(reversed(p)) for p in pairs] + input_lang = Lang(lang2) + output_lang = Lang(lang1) + else: + input_lang = Lang(lang1) + output_lang = Lang(lang2) + return input_lang, output_lang, pairs + + +MAX_LENGTH = 10 +eng_prefixes = ("i am ", "i m ", "he is", "he s ", "she is", "she s", + "you are", "you re ", "we are", "we re ", "they are", + "they re ") + + +def filterPair(p): + + return len(p[0].split(' ')) < MAX_LENGTH and \ + len(p[1].split(' ')) < MAX_LENGTH and \ + p[1].startswith(eng_prefixes) + + +def filterPairs(pairs): + + return [pair for pair in pairs if filterPair(pair)] + + +def prepareData(lang1, lang2, reverse=False): + + input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse) + print("Read %s sentence pairs" % len(pairs)) + pairs = filterPairs(pairs) + print("Trimmed to %s sentence pairs" % len(pairs)) + print("Counting words...") + for pair in pairs: + input_lang.addSentence(pair[0]) + output_lang.addSentence(pair[1]) + print("Counted words:") + print(input_lang.name, input_lang.n_words) + print(output_lang.name, output_lang.n_words) + return input_lang, output_lang, pairs + + +input_lang, output_lang, pairs = prepareData('eng', 'fra', True) +print(r.choice(pairs)) + +# Model + + +class EncoderRNN(object): + + def __init__(self, in_vocab, hidden_dim, model): + self.in_vocab = in_vocab + self.hidden_dim = hidden_dim + self.embedding_enc = model.add_lookup_parameters((self.in_vocab, + self.hidden_dim)) + self.rnn_enc = dy.GRUBuilder(1, self.hidden_dim, self.hidden_dim, + model) + + def __call__(self, inputs, hidden): + input_embed = dy.lookup(self.embedding_enc, inputs) + state_enc = self.rnn_enc.initial_state(vecs=hidden) + state_enc = state_enc.add_input(input_embed) + return state_enc.output(), state_enc.h() + + def initHidden(self): + return [dy.zeros(self.hidden_dim)] + + +DROPOUT_RATE = 0.1 + + +class AttnDecoderRNN(object): + + def __init__(self, hidden_dim, out_vocab, model, max_length=MAX_LENGTH): + self.hidden_dim = hidden_dim + self.out_vocab = out_vocab + self.max_length = max_length + self.embedding_dec = model.add_lookup_parameters((self.out_vocab, + self.hidden_dim)) + self.w_attn = model.add_parameters((self.max_length, + self.hidden_dim * 2)) + self.b_attn = model.add_parameters((self.max_length,)) + self.w_attn_combine = model.add_parameters((self.hidden_dim, + self.hidden_dim * 2)) + self.b_attn_combine = model.add_parameters((self.hidden_dim,)) + self.rnn_dec = dy.GRUBuilder(1, self.hidden_dim, self.hidden_dim, + model) + self.w_dec = model.add_parameters((self.out_vocab, self.hidden_dim)) + self.b_dec = model.add_parameters((self.out_vocab,)) + + def __call__(self, inputs, hidden, encoder_outptus, dropout=False): + input_embed = dy.lookup(self.embedding_dec, inputs) + if dropout: + input_embed = dy.dropout(input_embed, DROPOUT_RATE) + input_cat = dy.concatenate([input_embed, hidden[0]]) + w_attn = dy.parameter(self.w_attn) + b_attn = dy.parameter(self.b_attn) + attn_weights = dy.softmax(w_attn * input_cat + b_attn) + attn_applied = encoder_outptus * attn_weights + output = dy.concatenate([input_embed, attn_applied]) + w_attn_combine = dy.parameter(self.w_attn_combine) + b_attn_combine = dy.parameter(self.b_attn_combine) + output = w_attn_combine * output + b_attn_combine + output = dy.rectify(output) + state_dec = self.rnn_dec.initial_state(vecs=hidden) + state_dec = state_dec.add_input(output) + w_dec = dy.parameter(self.w_dec) + b_dec = dy.parameter(self.b_dec) + output = state_dec.output() + output = dy.softmax(w_dec * output + b_dec) + + return output, state_dec.h(), attn_weights + + def initHidden(self): + return [dy.zeros(self.hidden_dim)] + + +def indexesFromSentence(lang, sentence): + + return [lang.word2index[word] for word in sentence.split(" ")] + \ + [EOS_token] + + +def indexesFromPair(pair): + + input_indexes = indexesFromSentence(input_lang, pair[0]) + target_indexes = indexesFromSentence(output_lang, pair[1]) + return (input_indexes, target_indexes) + +# Training the Model + + +teacher_forcing_ratio = 0.5 + + +def train(inputs, targets, encoder, decoder, trainer, max_length=MAX_LENGTH): + + dy.renew_cg() + + encoder_hidden = encoder.initHidden() + + input_length = len(inputs) + target_length = len(targets) + + encoder_outputs = [dy.zeros(hidden_dim) for _ in range(max_length)] + + losses = [] + + for i in range(input_length): + encoder_output, encoder_hidden = encoder(inputs[i], encoder_hidden) + encoder_outputs[i] = encoder_output + + encoder_outputs = dy.concatenate(encoder_outputs, 1) + + decoder_input = SOS_token + decoder_hidden = encoder_hidden + + if r.random() < teacher_forcing_ratio: + use_teacher_forcing = True + else: + use_teacher_forcing = False + + if use_teacher_forcing: + for i in range(target_length): + decoder_output, decoder_hidden, _ = decoder( + decoder_input, decoder_hidden, encoder_outputs, dropout=True) + losses.append(-dy.log(dy.pick(decoder_output, targets[i]))) + decoder_input = targets[i] + else: + for i in range(target_length): + decoder_output, decoder_hidden, _ = decoder( + decoder_input, decoder_hidden, encoder_outputs, dropout=True) + losses.append(-dy.log(dy.pick(decoder_output, targets[i]))) + probs = decoder_output.vec_value() + decoder_input = probs.index(max(probs)) + if decoder_input == EOS_token: + break + + loss = dy.esum(losses)/len(losses) + loss.backward() + trainer.update() + + return loss.value() + +# Helper Function to Print Time + + +def asMinutes(s): + m = math.floor(s/60) + s -= m*60 + return "%dm %ds" % (m, s) + + +def timeSince(since, percent): + now = time.time() + s = now - since + es = s / (percent) + rs = es - s + return "%s (- %s)" % (asMinutes(s), asMinutes(rs)) + +# Whole Training Process + + +def trainIters(encoder, decoder, trainer, n_iters, print_every=1000, + plot_every=100): + + start = time.time() + plot_losses = [] + print_loss_total = 0 + plot_loss_total = 0 + + training_pairs = [indexesFromPair(r.choice(pairs)) + for _ in range(n_iters)] + + for iteration in range(1, n_iters+1): + + training_pair = training_pairs[iteration-1] + inputs = training_pair[0] + targets = training_pair[1] + + loss = train(inputs, targets, encoder, decoder, trainer) + + print_loss_total += loss + plot_loss_total += loss + + if iteration % print_every == 0: + print_loss_avg = print_loss_total/print_every + print_loss_total = 0 + print("%s (%d %d%%) %.4f" % (timeSince(start, iteration/n_iters), + iteration, iteration/n_iters*100, + print_loss_avg)) + + if iteration % plot_every == 0: + plot_loss_avg = plot_loss_total/plot_every + plot_losses.append(plot_loss_avg) + plot_loss_total = 0 + +# Evaluation + + +def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH): + + dy.renew_cg() + + encoder_hidden = encoder.initHidden() + + inputs = indexesFromSentence(input_lang, sentence) + input_length = len(inputs) + + encoder_outputs = [dy.zeros(hidden_dim) for _ in range(max_length)] + + for i in range(input_length): + encoder_output, encoder_hidden = encoder(inputs[i], encoder_hidden) + encoder_outputs[i] = encoder_output + + encoder_outputs = dy.concatenate(encoder_outputs, 1) + + decoder_input = SOS_token + decoder_hidden = encoder_hidden + + decoder_words = [] + decoder_attentions = [dy.zeros(max_length) for _ in range(max_length)] + + for i in range(max_length): + decoder_output, decoder_hidden, decoder_attention = decoder( + decoder_input, decoder_hidden, encoder_outputs, dropout=False) + decoder_attentions[i] = decoder_attention + probs = decoder_output.vec_value() + pred = probs.index(max(probs)) + if pred == EOS_token: + decoder_words.append("") + break + else: + decoder_words.append(output_lang.index2word[pred]) + decoder_input = pred + + return decoder_words + + +def evaluationRandomly(encoder, decoder, n=10): + + for _ in range(n): + pair = r.choice(pairs) + print(">", pair[0]) + print("=", pair[1]) + output_words = evaluate(encoder, decoder, pair[0]) + output_sentence = " ".join(output_words) + print("<", output_sentence) + print("") + +# Start Training and Evaluating + + +model = dy.ParameterCollection() +hidden_dim = 256 +encoder = EncoderRNN(input_lang.n_words, hidden_dim, model) +decoder = AttnDecoderRNN(hidden_dim, output_lang.n_words, model) +trainer = dy.SimpleSGDTrainer(model, learning_rate=0.2) + +trainIters(encoder, decoder, trainer, 100000, print_every=5000) + +evaluationRandomly(encoder, decoder) diff --git a/examples/sequence-to-sequence/seq2seq_translator/seq2seq_pytorch.py b/examples/sequence-to-sequence/seq2seq_translator/seq2seq_pytorch.py new file mode 100644 index 000000000..95aab4669 --- /dev/null +++ b/examples/sequence-to-sequence/seq2seq_translator/seq2seq_pytorch.py @@ -0,0 +1,378 @@ +# Requirements + +from __future__ import unicode_literals, print_function, division +import io +import unicodedata +import re +import random +import time +import math +import torch +import torch.nn as nn +from torch import optim +import torch.nn.functional as F +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +r = random.SystemRandom() + +# Data Preparation + +SOS_token = 0 +EOS_token = 1 + + +class Lang: + + def __init__(self, name): + self.name = name + self.word2index = {} + self.word2count = {} + self.index2word = {0: "SOS", 1: "EOS"} + self.n_words = 2 # Count SOS and EOS + + def addSentence(self, sentence): + for word in sentence.split(' '): + self.addWord(word) + + def addWord(self, word): + if word not in self.word2index: + self.word2index[word] = self.n_words + self.word2count[word] = 1 + self.index2word[self.n_words] = word + self.n_words += 1 + else: + self.word2count[word] += 1 + + +def unicodeToAscii(s): + + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) + + +def normalizeString(s): + + s = unicodeToAscii(s.lower().strip()) + s = re.sub(r"([.!?])", r" \1", s) + s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) + return s + + +def readLangs(lang1, lang2, reverse=False): + + print("Reading lines...") + + lines = io.open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\ + read().strip().split('\n') + + pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines] + + if reverse: + pairs = [list(reversed(p)) for p in pairs] + input_lang = Lang(lang2) + output_lang = Lang(lang1) + else: + input_lang = Lang(lang1) + output_lang = Lang(lang2) + + return input_lang, output_lang, pairs + + +MAX_LENGTH = 10 + +eng_prefixes = ( + "i am ", "i m ", + "he is", "he s ", + "she is", "she s", + "you are", "you re ", + "we are", "we re ", + "they are", "they re " +) + + +def filterPair(p): + + return len(p[0].split(' ')) < MAX_LENGTH and \ + len(p[1].split(' ')) < MAX_LENGTH and \ + p[1].startswith(eng_prefixes) + + +def filterPairs(pairs): + + return [pair for pair in pairs if filterPair(pair)] + + +def prepareData(lang1, lang2, reverse=False): + + input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse) + print("Read %s sentence pairs" % len(pairs)) + pairs = filterPairs(pairs) + print("Trimmed to %s sentence pairs" % len(pairs)) + print("Counting words...") + for pair in pairs: + input_lang.addSentence(pair[0]) + output_lang.addSentence(pair[1]) + print("Counted words:") + print(input_lang.name, input_lang.n_words) + print(output_lang.name, output_lang.n_words) + return input_lang, output_lang, pairs + + +input_lang, output_lang, pairs = prepareData('eng', 'fra', True) +print(r.choice(pairs)) + +# Model + + +class EncoderRNN(nn.Module): + + def __init__(self, input_size, hidden_size): + super(EncoderRNN, self).__init__() + self.hidden_size = hidden_size + + self.embedding = nn.Embedding(input_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size) + + def forward(self, inputs, hidden): + embedded = self.embedding(inputs).view(1, 1, -1) + output = embedded + output, hidden = self.gru(output, hidden) + return output, hidden + + def initHidden(self): + return torch.zeros(1, 1, self.hidden_size, device=device) + + +class AttnDecoderRNN(nn.Module): + + def __init__(self, hidden_size, output_size, dropout_p=0.1, + max_length=MAX_LENGTH): + super(AttnDecoderRNN, self).__init__() + self.hidden_size = hidden_size + self.output_size = output_size + self.dropout_p = dropout_p + self.max_length = max_length + + self.embedding = nn.Embedding(self.output_size, self.hidden_size) + self.attn = nn.Linear(self.hidden_size * 2, self.max_length) + self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size) + self.dropout = nn.Dropout(self.dropout_p) + self.gru = nn.GRU(self.hidden_size, self.hidden_size) + self.out = nn.Linear(self.hidden_size, self.output_size) + + def forward(self, inputs, hidden, encoder_outputs): + embedded = self.embedding(inputs).view(1, 1, -1) + embedded = self.dropout(embedded) + + attn_weights = F.softmax( + self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1) + attn_applied = torch.bmm(attn_weights.unsqueeze(0), + encoder_outputs.unsqueeze(0)) + + output = torch.cat((embedded[0], attn_applied[0]), 1) + output = self.attn_combine(output).unsqueeze(0) + + output = F.relu(output) + output, hidden = self.gru(output, hidden) + + output = F.log_softmax(self.out(output[0]), dim=1) + return output, hidden, attn_weights + + def initHidden(self): + return torch.zeros(1, 1, self.hidden_size, device=device) + + +def indexesFromSentence(lang, sentence): + + return [lang.word2index[word] for word in sentence.split(' ')] + + +def tensorFromSentence(lang, sentence): + + indexes = indexesFromSentence(lang, sentence) + indexes.append(EOS_token) + return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1) + + +def tensorsFromPair(pair): + + input_tensor = tensorFromSentence(input_lang, pair[0]) + target_tensor = tensorFromSentence(output_lang, pair[1]) + return (input_tensor, target_tensor) + +# Training the Model + + +teacher_forcing_ratio = 0.5 + + +def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, + decoder_optimizer, criterion, max_length=MAX_LENGTH): + encoder_hidden = encoder.initHidden() + + encoder_optimizer.zero_grad() + decoder_optimizer.zero_grad() + + input_length = input_tensor.size(0) + target_length = target_tensor.size(0) + + encoder_outputs = torch.zeros(max_length, encoder.hidden_size, + device=device) + + loss = 0 + + for ei in range(input_length): + encoder_output, encoder_hidden = encoder( + input_tensor[ei], encoder_hidden) + encoder_outputs[ei] = encoder_output[0, 0] + + decoder_input = torch.tensor([[SOS_token]], device=device) + + decoder_hidden = encoder_hidden + + use_teacher_forcing = True if r.random() < teacher_forcing_ratio \ + else False + + if use_teacher_forcing: + + for di in range(target_length): + decoder_output, decoder_hidden, _ = decoder( + decoder_input, decoder_hidden, encoder_outputs) + loss += criterion(decoder_output, target_tensor[di]) + decoder_input = target_tensor[di] + + else: + + for di in range(target_length): + decoder_output, decoder_hidden, _ = decoder( + decoder_input, decoder_hidden, encoder_outputs) + _, topi = decoder_output.topk(1) + decoder_input = topi.squeeze().detach() + + loss += criterion(decoder_output, target_tensor[di]) + if decoder_input.item() == EOS_token: + break + + loss.backward() + + encoder_optimizer.step() + decoder_optimizer.step() + + return loss.item() / target_length + +# Helper Function to Print Time + + +def asMinutes(s): + m = math.floor(s / 60) + s -= m * 60 + return '%dm %ds' % (m, s) + + +def timeSince(since, percent): + now = time.time() + s = now - since + es = s / (percent) + rs = es - s + return '%s (- %s)' % (asMinutes(s), asMinutes(rs)) + +# Whole Training Process + + +def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, + learning_rate=0.01): + + start = time.time() + plot_losses = [] + print_loss_total = 0 + plot_loss_total = 0 + + encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate) + decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate) + training_pairs = [tensorsFromPair(r.choice(pairs)) + for _ in range(n_iters)] + criterion = nn.NLLLoss() + + for iteration in range(1, n_iters + 1): + training_pair = training_pairs[iteration - 1] + input_tensor = training_pair[0] + target_tensor = training_pair[1] + + loss = train(input_tensor, target_tensor, encoder, + decoder, encoder_optimizer, decoder_optimizer, criterion) + print_loss_total += loss + plot_loss_total += loss + + if iteration % print_every == 0: + print_loss_avg = print_loss_total / print_every + print_loss_total = 0 + print('%s (%d %d%%) %.4f' % (timeSince(start, iteration / n_iters), + iteration, iteration / n_iters * 100, + print_loss_avg)) + + if iteration % plot_every == 0: + plot_loss_avg = plot_loss_total / plot_every + plot_losses.append(plot_loss_avg) + plot_loss_total = 0 + + +def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH): + with torch.no_grad(): + input_tensor = tensorFromSentence(input_lang, sentence) + input_length = input_tensor.size()[0] + encoder_hidden = encoder.initHidden() + + encoder_outputs = torch.zeros(max_length, encoder.hidden_size, + device=device) + + for ei in range(input_length): + encoder_output, encoder_hidden = encoder(input_tensor[ei], + encoder_hidden) + encoder_outputs[ei] += encoder_output[0, 0] + + decoder_input = torch.tensor([[SOS_token]], device=device) # SOS + + decoder_hidden = encoder_hidden + + decoded_words = [] + decoder_attentions = torch.zeros(max_length, max_length) + + for di in range(max_length): + decoder_output, decoder_hidden, decoder_attention = decoder( + decoder_input, decoder_hidden, encoder_outputs) + decoder_attentions[di] = decoder_attention.data + _, topi = decoder_output.data.topk(1) + if topi.item() == EOS_token: + decoded_words.append('') + break + else: + decoded_words.append(output_lang.index2word[topi.item()]) + + decoder_input = topi.squeeze().detach() + + return decoded_words, decoder_attentions[:di + 1] + + +def evaluateRandomly(encoder, decoder, n=10): + + for _ in range(n): + pair = r.choice(pairs) + print('>', pair[0]) + print('=', pair[1]) + output_words, _ = evaluate(encoder, decoder, pair[0]) + output_sentence = ' '.join(output_words) + print('<', output_sentence) + print('') + +# Training and Evaluating + + +hidden_size = 256 +encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device) +attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, + dropout_p=0.1).to(device) + +trainIters(encoder1, attn_decoder1, 100000, print_every=5000) + +evaluateRandomly(encoder1, attn_decoder1)