
Commit 057341f

Add OOV token handling to character-level RNN tutorial (#3284)
Improves char-rnn tutorial code quality.

## Description

This PR adds proper handling for Out-Of-Vocabulary (OOV) characters in the character-level RNN tutorial.

Problem:
- The current implementation doesn't properly handle characters outside the allowed set.
- `string.find()` returns -1 for unknown characters, causing them to be treated as apostrophes (the last character in the `allowed_characters` string).
- This creates ambiguity between actual apostrophes in names (like O'Brien) and unknown characters.

Solution:
- Added an underscore character "_" as a dedicated OOV token.
- Modified `letterToIndex()` to explicitly handle unknown characters.
- Added comments explaining the purpose of OOV handling.
- Updated the comment about input nodes (57 → 58) to reflect the added character.

This change follows best practices for NLP systems by explicitly handling unknown characters, improving both the model's accuracy and the tutorial's educational value.
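The ambiguity described above is easy to reproduce. The sketch below is not part of the commit; `'é'` is just an arbitrary out-of-vocabulary character used to contrast the old `str.find()` behavior with the new explicit check:

```python
import string
import torch

# Old vocabulary: 57 characters, with "'" as the last one
allowed_characters = string.ascii_letters + " .,;'"

# str.find() returns -1 for a character outside the set...
print(allowed_characters.find('é'))              # -1

# ...and indexing a one-hot vector with -1 sets the *last* slot,
# so 'é' becomes indistinguishable from an apostrophe
one_hot = torch.zeros(1, len(allowed_characters))
one_hot[0][allowed_characters.find('é')] = 1
print(one_hot[0][allowed_characters.find("'")])  # tensor(1.)

# New vocabulary: a dedicated "_" OOV token and an explicit membership check
allowed_characters = string.ascii_letters + " .,;'" + "_"

def letterToIndex(letter):
    # return our out-of-vocabulary character if we encounter a letter unknown to our model
    if letter not in allowed_characters:
        return allowed_characters.find("_")
    else:
        return allowed_characters.find(letter)

print(letterToIndex('é'))   # 57 -- the "_" slot
print(letterToIndex("'"))   # 56 -- apostrophes keep their own index
```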
1 parent ce291f4 commit 057341f

File tree

1 file changed: +57 -52 lines changed


intermediate_source/char_rnn_classification_tutorial.py (+57 -52)
@@ -25,7 +25,7 @@
 
 Specifically, we'll train on a few thousand surnames from 18 languages
 of origin, and predict which language a name is from based on the
-spelling.
+spelling.
 
 Recommended Preparation
 =======================
@@ -50,13 +50,13 @@
 general
 """
 ######################################################################
-# Preparing Torch
+# Preparing Torch
 # ==========================
 #
-# Set up torch to default to the right device use GPU acceleration depending on your hardware (CPU or CUDA).
+# Set up torch to default to the right device use GPU acceleration depending on your hardware (CPU or CUDA).
 #
 
-import torch
+import torch
 
 # Check if CUDA is available
 device = torch.device('cpu')
@@ -70,24 +70,25 @@
 # Preparing the Data
 # ==================
 #
-# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
+# Download the data from `here <https://download.pytorch.org/tutorial/data.zip>`__
 # and extract it to the current directory.
 #
 # Included in the ``data/names`` directory are 18 text files named as
 # ``[Language].txt``. Each file contains a bunch of names, one name per
 # line, mostly romanized (but we still need to convert from Unicode to
 # ASCII).
 #
-# The first step is to define and clean our data. Initially, we need to convert Unicode to plain ASCII to
-# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing only a small set of allowed characters.
+# The first step is to define and clean our data. Initially, we need to convert Unicode to plain ASCII to
+# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing only a small set of allowed characters.
 
-import string
+import string
 import unicodedata
 
-allowed_characters = string.ascii_letters + " .,;'"
-n_letters = len(allowed_characters)
+# We can use "_" to represent an out-of-vocabulary character, that is, any character we are not handling in our model
+allowed_characters = string.ascii_letters + " .,;'" + "_"
+n_letters = len(allowed_characters)
 
-# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
+# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
 def unicodeToAscii(s):
     return ''.join(
         c for c in unicodedata.normalize('NFD', s)
@@ -120,7 +121,11 @@ def unicodeToAscii(s):
 
 # Find letter index from all_letters, e.g. "a" = 0
 def letterToIndex(letter):
-    return allowed_characters.find(letter)
+    # return our out-of-vocabulary character if we encounter a letter unknown to our model
+    if letter not in allowed_characters:
+        return allowed_characters.find("_")
+    else:
+        return allowed_characters.find(letter)
 
 # Turn a line into a <line_length x 1 x n_letters>,
 # or an array of one-hot letter vectors
@@ -137,16 +142,16 @@ def lineToTensor(line):
 print (f"The name 'Ahn' becomes {lineToTensor('Ahn')}") #notice 'A' sets the 27th index to 1
 
 #########################
-# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
+# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
 # for other RNN tasks with text.
 #
-# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
-# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
+# Next, we need to combine all our examples into a dataset so we can train, test and validate our models. For this,
+# we will use the `Dataset and DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`__ classes
 # to hold our dataset. Each Dataset needs to implement three functions: ``__init__``, ``__len__``, and ``__getitem__``.
 from io import open
 import glob
 import os
-import time
+import time
 
 import torch
 from torch.utils.data import Dataset
@@ -155,26 +160,26 @@ class NamesDataset(Dataset):
 
     def __init__(self, data_dir):
         self.data_dir = data_dir #for provenance of the dataset
-        self.load_time = time.localtime #for provenance of the dataset
+        self.load_time = time.localtime #for provenance of the dataset
         labels_set = set() #set of all classes
 
         self.data = []
         self.data_tensors = []
-        self.labels = []
-        self.labels_tensors = []
+        self.labels = []
+        self.labels_tensors = []
 
         #read all the ``.txt`` files in the specified directory
-        text_files = glob.glob(os.path.join(data_dir, '*.txt'))
+        text_files = glob.glob(os.path.join(data_dir, '*.txt'))
         for filename in text_files:
             label = os.path.splitext(os.path.basename(filename))[0]
             labels_set.add(label)
             lines = open(filename, encoding='utf-8').read().strip().split('\n')
-            for name in lines:
+            for name in lines:
                 self.data.append(name)
                 self.data_tensors.append(lineToTensor(name))
                 self.labels.append(label)
 
-        #Cache the tensor representation of the labels
+        #Cache the tensor representation of the labels
         self.labels_uniq = list(labels_set)
         for idx in range(len(self.labels)):
             temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)
@@ -187,7 +192,7 @@ def __getitem__(self, idx):
         data_item = self.data[idx]
         data_label = self.labels[idx]
         data_tensor = self.data_tensors[idx]
-        label_tensor = self.labels_tensors[idx]
+        label_tensor = self.labels_tensors[idx]
 
         return label_tensor, data_tensor, data_label, data_item
 
@@ -200,17 +205,17 @@ def __getitem__(self, idx):
 print(f"example = {alldata[0]}")
 
 #########################
-#Using the dataset object allows us to easily split the data into train and test sets. Here we create a 80/20
-# split but the ``torch.utils.data`` has more useful utilities. Here we specify a generator since we need to use the
-#same device as PyTorch defaults to above.
+#Using the dataset object allows us to easily split the data into train and test sets. Here we create a 80/20
+# split but the ``torch.utils.data`` has more useful utilities. Here we specify a generator since we need to use the
+#same device as PyTorch defaults to above.
 
 train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))
 
 print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")
 
 #########################
-# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
-#split the dataset into training and testing so we can validate the model that we build.
+# Now we have a basic dataset containing **20074** examples where each example is a pairing of label and name. We have also
+#split the dataset into training and testing so we can validate the model that we build.
 
 
 ######################################################################
@@ -222,11 +227,11 @@ def __getitem__(self, idx):
 # held hidden state and gradients which are now entirely handled by the
 # graph itself. This means you can implement a RNN in a very "pure" way,
 # as regular feed-forward layers.
-#
-# This CharRNN class implements an RNN with three components.
+#
+# This CharRNN class implements an RNN with three components.
 # First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__.
 # Next, we define a layer that maps the RNN hidden layers to our output. And finally, we apply a ``softmax`` function. Using ``nn.RNN``
-# leads to a significant improvement in performance, such as cuDNN-accelerated kernels, versus implementing
+# leads to a significant improvement in performance, such as cuDNN-accelerated kernels, versus implementing
 # each layer as a ``nn.Linear``. It also simplifies the implementation in ``forward()``.
 #
 
@@ -240,7 +245,7 @@ def __init__(self, input_size, hidden_size, output_size):
         self.rnn = nn.RNN(input_size, hidden_size)
         self.h2o = nn.Linear(hidden_size, output_size)
         self.softmax = nn.LogSoftmax(dim=1)
-
+
     def forward(self, line_tensor):
         rnn_out, hidden = self.rnn(line_tensor)
         output = self.h2o(hidden[0])
@@ -250,14 +255,14 @@ def forward(self, line_tensor):
 
 
 ###########################
-# We can then create an RNN with 57 input nodes, 128 hidden nodes, and 18 outputs:
+# We can then create an RNN with 58 input nodes, 128 hidden nodes, and 18 outputs:
 
 n_hidden = 128
 rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq))
-print(rnn)
+print(rnn)
 
 ######################################################################
-# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
+# After that we can pass our Tensor to the RNN to obtain a predicted output. Subsequently,
 # we use a helper function, ``label_from_output``, to derive a text label for the class.
 
 def label_from_output(output, output_labels):
@@ -267,7 +272,7 @@ def label_from_output(output, output_labels):
 
 input = lineToTensor('Albert')
 output = rnn(input) #this is equivalent to ``output = rnn.forward(input)``
-print(output)
+print(output)
 print(label_from_output(output, alldata.labels_uniq))
 
 ######################################################################
@@ -283,13 +288,13 @@ def label_from_output(output, output_labels):
 # Now all it takes to train this network is show it a bunch of examples,
 # have it make guesses, and tell it if it's wrong.
 #
-# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches. RNNs
+# We do this by defining a ``train()`` function which trains the model on a given dataset using minibatches. RNNs
 # RNNs are trained similarly to other networks; therefore, for completeness, we include a batched training method here.
-# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
-# weights. This operation is repeated until the number of epochs is reached.
+# The loop (``for i in batch``) computes the losses for each of the items in the batch before adjusting the
+# weights. This operation is repeated until the number of epochs is reached.
 
-import random
-import numpy as np
+import random
+import numpy as np
 
 def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):
     """
@@ -298,22 +303,22 @@ def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50
     # Keep track of losses for plotting
     current_loss = 0
     all_losses = []
-    rnn.train()
+    rnn.train()
     optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)
 
     start = time.time()
     print(f"training on data set with n = {len(training_data)}")
 
-    for iter in range(1, n_epoch + 1):
-        rnn.zero_grad() # clear the gradients
+    for iter in range(1, n_epoch + 1):
+        rnn.zero_grad() # clear the gradients
 
         # create some minibatches
         # we cannot use dataloaders because each of our names is a different length
         batches = list(range(len(training_data)))
         random.shuffle(batches)
         batches = np.array_split(batches, len(batches) //n_batch_size )
 
-        for idx, batch in enumerate(batches):
+        for idx, batch in enumerate(batches):
             batch_loss = 0
             for i in batch: #for each example in this batch
                 (label_tensor, text_tensor, label, text) = training_data[i]
@@ -328,16 +333,16 @@ def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50
             optimizer.zero_grad()
 
             current_loss += batch_loss.item() / len(batch)
-
+
         all_losses.append(current_loss / len(batches) )
         if iter % report_every == 0:
             print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
         current_loss = 0
-
+
     return all_losses
 
 ##########################################################################
-# We can now train a dataset with minibatches for a specified number of epochs. The number of epochs for this
+# We can now train a dataset with minibatches for a specified number of epochs. The number of epochs for this
 # example is reduced to speed up the build. You can get better results with different parameters.
 
 start = time.time()
@@ -373,12 +378,12 @@ def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50
 
 def evaluate(rnn, testing_data, classes):
     confusion = torch.zeros(len(classes), len(classes))
-
+
     rnn.eval() #set to eval mode
     with torch.no_grad(): # do not record the gradients during eval phase
         for i in range(len(testing_data)):
             (label_tensor, text_tensor, label, text) = testing_data[i]
-            output = rnn(text_tensor)
+            output = rnn(text_tensor)
             guess, guess_i = label_from_output(output, classes)
             label_i = classes.index(label)
             confusion[label_i][guess_i] += 1
@@ -409,7 +414,7 @@ def evaluate(rnn, testing_data, classes):
 
 
 evaluate(rnn, test_set, classes=alldata.labels_uniq)
-
+
 
 ######################################################################
 # You can pick out bright spots off the main axis that show which
@@ -429,7 +434,7 @@ def evaluate(rnn, testing_data, classes):
 # - Try the ``nn.LSTM`` and ``nn.GRU`` layers
 # - Modify the size of the layers, such as increasing or decreasing the number of hidden nodes or adding additional linear layers
 # - Combine multiple of these RNNs as a higher level network
-#
+#
 # - Try with a different dataset of line -> label, for example:
 #
 # - Any word -> language
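As a quick sanity check of the updated "57 → 58 input nodes" comment, this small sketch (not part of the commit) shows the new vocabulary size and the index unknown characters map to:

```python
import string

# With the extra "_" token the vocabulary grows from 57 to 58 characters,
# which is why the tutorial's input-node comment changes from 57 to 58
allowed_characters = string.ascii_letters + " .,;'" + "_"
print(len(allowed_characters))        # 58
print(allowed_characters.find("_"))   # 57 -- the index every unknown character maps to
```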
