diff --git a/.gitignore b/.gitignore index 47e4ebc..1173b16 100644 --- a/.gitignore +++ b/.gitignore @@ -7,11 +7,9 @@ data/* */data/* config/* */runs/* -*/EMNISTNet/data/* -*/EMNISTNet/custom_dataset/* -*/EMNISTNet/mini_custom_dataset/* runs/ debug/* +ssigalpr_samples/ #*.jpg #*.JPG #*.jpeg diff --git a/README.md b/README.md index 08cbd83..338149e 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,9 @@ # What's this repo about? -This is a simple approach for vehicle registration plate detection and recognition. It is not an end-to-end system, instead, three different methods were stacked together to complete this task. [*YOLO*](https://github.com/pjreddie/darknet) object detection algorithm was used to detect license plate regions, then a marker-based segmentation method using watershed algorithm was applied to extract the character digits. After that, a Convolutional Neural Network (CNN) - *EMNISTNet* - and the "vanilla" [*Tesseract-OCR*](https://github.com/tesseract-ocr/tesseract) Optical Character Recognition (OCR) were used to recognize the extracted digits. +This is a simple approach for vehicle registration plate detection and recognition. It is not an end-to-end system; instead, two deep learning methods are stacked together to complete the task. The [*YOLO*](https://github.com/AlexeyAB/darknet) object detection algorithm is used to detect license plate regions, and an attention-based OCR model, [*Attention-OCR*](https://github.com/wptoux/attention-ocr), is then applied to recognize the characters. -![Output](docs/result.jpg "Output")*Output: vehicle license plate and recognized digits were blurred for an obvious reason.* - -Note that it is far from being a perfect solution to this problem. Although YOLO does a great job of finding the license plate regions and character recognition is pretty straight forward nowadays, further improvements could be made. For instance, the character segmentation method used here gives poor results for noisy images, and thus, decreasing OCR accuracy. One could address this issue by applying other image processing algorithms, such as image equalization, morphological operations, among others, to improve image quality and remove as much as possible of the undesired image parts. +![Output](docs/result.jpg "Output")*Results (vehicle license plate and recognized characters were intentionally blurred).* # Install and Requirements @@ -16,30 +14,13 @@ Note that it is far from being a perfect solution to this problem. Although YOLO pip install -r requirements.txt ```` -## Tesseract-OCR (optional) - -If you also want to use *Tesseract-OCR* for the character recognition task, follow the instructions below: - -* Tesseract-OCR binaries: -```` -sudo apt update sudo apt install tesseract-ocr -```` - -* Tesseract-OCR Python API: -```` -pip install pytesseract==0.3.3 -```` - ## Pre-trained Weights -Download the pre-trained weights for the YOLO and EMNISTNet and put it in the `config` directory. +Download the pre-trained weights for YOLO and Attention-OCR and put them in the `config` directory. -* *YOLO* was trained on the Brazilian [SSIG-ALPR](http://smartsenselab.dcc.ufmg.br/en/dataset/banco-de-dados-sense-alpr/) dataset. +* *YOLO* and *Attention-OCR* were trained on the Brazilian [SSIG-ALPR](http://smartsenselab.dcc.ufmg.br/en/dataset/banco-de-dados-sense-alpr/) dataset. * `TODO:` upload weights and other config files somewhere. 
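For reference, the sketch below shows roughly how these weights are consumed at inference time; it mirrors the detection/recognition loop in `src/app.py`. The `Yolo` construction is not part of this diff, so an already-configured instance is passed in as an assumption, and the weight filename and `threshold` value are simply the ones `app.py` references:

````
import cv2 as cv

from ocr import OCR  # Attention-OCR wrapper used by src/app.py


def recognize_plates(yolo, image_path):
    """Run the two-stage pipeline on one image.

    `yolo` is assumed to be an already-configured instance of the Yolo
    wrapper from src/yolo.py (its construction is not shown in this diff);
    it exposes detect(), bounding_boxes and confidences as used in src/app.py.
    """
    img = cv.imread(image_path)
    roi_imgs = yolo.detect(img)  # cropped license plate regions

    # Same weights file and threshold as referenced in src/app.py
    ocr = OCR(model_filename="../config/attention_ocr_model.pth",
              use_cuda=False, threshold=0.7)

    results = []
    for index, roi_img in enumerate(roi_imgs):
        box = yolo.bounding_boxes[index]   # [x, y, w, h]
        score = yolo.confidences[index]    # detection confidence
        results.append((box, score, ocr.predict(roi_img)))  # e.g. "ABC1234-"
    return results
````

The Flask endpoint in `src/app.py` performs essentially this loop, then draws each prediction on the image and returns it base64-encoded in the API response.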
-* *EMNISTNet* was trained on the [EMNIST](https://www.nist.gov/itl/products-and-services/emnist-dataset) `bymerge` dataset until it reaches around 89% of accuracy, then training was continued with a custom dataset for fine-tuning. (`TODO:` link the custom dataset). - * `TODO:` upload weights - # Running Run the application API: ```` curl --location --request POST 'localhost:5000/' \ ### API Output: -Although multiple detections and recognitions are possible in the same image, the API will output the prediction for the detection with the highest confidence. +The API outputs every detection with its bounding box and confidence score, as well as the OCR prediction for each bounding box. All of this information is also drawn onto the input image, which is returned as a base64-encoded image. -`json object` response: +The `json object` response will look like the following: ```` { - "bounding_box": { - "h": 51, - "w": 127, - "x": 1474, - "y": 520 - }, - "classId": "0", - "confidence": 1.0, - "emnist_net_preds": "ABC1234", - "tesseract_preds": "ABC1234" + "detections": [ + { + "bb_confidence": 0.973590612411499, + "bounding_box": [ + 1509, + 877, + 82, + 39 + ], + "ocr_pred": "ABC1234-" + }, + { + "bb_confidence": 0.9556514024734497, + "bounding_box": [ + 161, + 866, + 100, + 40 + ], + "ocr_pred": "ABC1234-" + } + ], + "output_image": "/9j/4AAQS..." } ```` *Note: If the `DEBUG` flag is set to `True` in `app.py`, intermediate images will be written to the `debug` directory to make debugging a bit easier.* - -# How To Train - -If you want to train the models by yourself, or just want to use your custom datasets, just follow the instructions below: - -## YOLO - -* You can find [here](https://github.com/AlexeyAB/darknet) very clear instructions on how to train YOLO on your dataset. - -## EMNISTNet - -Go the EMNISTNet directory and simply type: ```` -python train_model.py --e=5 --cuda --v -```` - -* Params: - * --e=number_of_epochs: the number of epochs you want to train your model - * --cuda: if you want to train on GPU that supports CUDA - * --v: verbose mode - -### Fine-tuning on a custom dataset - -As we know the EMNIST is a handwritten character digits dataset and the extracted digits of license plates are not handwritten, so EMNISTNet may not give the desired accuracy on these particular images. To circumvent this issue, training was carried out on a custom dataset where digits are more like to our problem domain. `Data Augmentation` methods, such as `rotation` and `shear`, was also applied. -
-*Custom dataset: examples of character digits.*
- - -```` -python train_model.py --m=emnist_model.pt --d=custom_dataset/ --e=10 --cuda --v -```` - -* Params: - * --m=previous_model.pt: start weights from a pre-trained model and continue training from there - * --d=path_to_the_custom_dataset: path to our custom dataset - -*Note: Since pytorch DataLoader keeps its own internal class indexes for the target labels based on the alphabetical order, and image subdirectories are used as class labels, in order to keep track of the `idx` like this:* - -`idx = ['0','1','2','3','4','5','6','7','8','9','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']` - -*I managed to put images in the custom dataset as shown below:* - -```` -root_image_dir/a/0_image1.png -root_image_dir/a/0_image2.png -root_image_dir/a/0_imageN.png -root_image_dir/ab/1_imageN.png -root_image_dir/abc/2_imageN.png -root_image_dir/.../..._imageN.png -root_image_dir/abcdefghijklmnopqrstuvwxyzabcdefghi/Y_imageN.png -root_image_dir/abcdefghijklmnopqrstuvwxyzabcdefghij/Z_imageN.png -```` diff --git a/requirements.txt b/requirements.txt index ad40796..5548ed9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,5 +15,7 @@ Flask==1.1.1 imutils==0.5.3 scikit-image==0.16.2 tensorboard==1.14.0 -torch==1.4.0+cpu -torchvision==0.5.0+cpu \ No newline at end of file +torch==1.4.0 +torchvision==0.5.0 +tqdm==4.46.1 +Pillow==7.1.1 \ No newline at end of file diff --git a/src/EMNISTNet/exploring_custom_dataset.py b/src/EMNISTNet/exploring_custom_dataset.py deleted file mode 100644 index a5b6218..0000000 --- a/src/EMNISTNet/exploring_custom_dataset.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -import torchvision -import torchvision.transforms as transforms -import numpy as np -import matplotlib.pyplot as plt - -def plot_images(data, rows, cols, cmap='gray'): - if(len(data) > 0): - i = 0 - for title, image in data.items(): - #logging.debug(title) - plt.subplot(rows,cols,i+1),plt.imshow(image,cmap) - plt.title(title) - plt.xticks([]),plt.yticks([]) - i += 1 - plt.show() - -def display_images(img_list, row, col): - if(len(img_list) > 0): - images = {} - n = 0 - for img in img_list: - n += 1 - images[str(n)] = img - plot_images(images, row, col, cmap='gray') - -train_data = torchvision.datasets.ImageFolder(root='custom_dataset/', - transform=transforms.Compose([ - transforms.Grayscale(num_output_channels=1), - transforms.RandomApply([transforms.RandomAffine(degrees=(-30, 30), shear=(-30, 30))], p=1.0), - transforms.ToTensor() - ]) - ) -print(f'dataset size: {len(train_data)}') - -NUM_IMAGES = 36 - -groundtruth = ['0','1','2','3','4','5','6','7','8','9', - 'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z'] - -data_loader = torch.utils.data.DataLoader(train_data, batch_size=NUM_IMAGES, shuffle=True) -batch = next(iter(data_loader)) -print(f'batch len: {len(batch)}') -print(f'type: {type(batch)}') -images, labels = batch -print(f'batch size: {len(images)}') -print(f'images shape: {images.shape}') -print(f'labels shape: {labels.shape}') -print(f'labels: {labels}') -print(f'pixels type:\n {type(images[0][0][0][0])}') -print(f'pixels max and min values:\n {torch.max(images[0][0])} and {torch.min(images[0][0])}') -print(f'pixels max and min values:\n {torch.max(images)} and {torch.min(images)}') - -groundtruth_labels_indexes = list(np.array(labels.squeeze(0)).astype(int)) -groundtruth_classes_name = [groundtruth[idx] for idx in groundtruth_labels_indexes] -print(f'groundtruth classes: 
{groundtruth_classes_name}') - -# plotting images -images = [ images[idx][0].numpy() for idx in range(NUM_IMAGES)] -display_images(images, 1, 36) \ No newline at end of file diff --git a/src/EMNISTNet/exploring_emnist.py b/src/EMNISTNet/exploring_emnist.py deleted file mode 100644 index f08abc8..0000000 --- a/src/EMNISTNet/exploring_emnist.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import torchvision -import torchvision.transforms as transforms -import numpy as np -import matplotlib.pyplot as plt - -def plot_images(data, rows, cols, cmap='gray'): - if(len(data) > 0): - i = 0 - for title, image in data.items(): - #logging.debug(title) - plt.subplot(rows,cols,i+1),plt.imshow(image,cmap) - plt.title(title) - plt.xticks([]),plt.yticks([]) - i += 1 - plt.show() - -def display_images(img_list, row, col): - if(len(img_list) > 0): - images = {} - n = 0 - for img in img_list: - n += 1 - images[str(n)] = img - plot_images(images, row, col, cmap='gray') - -train_data = torchvision.datasets.EMNIST( - root = 'data/', - split='bymerge', - train = True, - download = True, - transform=transforms.Compose([ - #transforms.RandomApply([transforms.RandomAffine(degrees=(-30, 30), shear=(-30, 30)), - #transforms.Pad(padding=1, fill=0, padding_mode='constant')], p=1.0), - #transforms.RandomHorizontalFlip(p=1.0), - #transforms.RandomVerticalFlip(p=1.0), - #transforms.RandomPerspective(p=1.0), - transforms.ToTensor() - ]) - ) - -NUM_IMAGES = 20 - -groundtruth = ['0','1','2','3','4','5','6','7','8','9', - 'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z', - 'a','b','d','e','f','g','h','n','q','r','t'] - -data_loader = torch.utils.data.DataLoader(train_data, batch_size=NUM_IMAGES, shuffle=True) -batch = next(iter(data_loader)) -print(f'batch len: {len(batch)}') -print(f'type: {type(batch)}') -images, labels = batch -print(f'images shape: {images.shape}') -print(f'labels shape: {labels.shape}') -print(f'labels: {labels}') -print(f'pixels type:\n {type(images[0][0][0][0])}') -print(f'pixels max and min values:\n {torch.max(images[0][0])} and {torch.min(images[0][0])}') -#print(f'pixels:\n {images[0][0]}') -# plotting images -#grid = torchvision.utils.make_grid(images, nrow=NUM_IMAGES) -#plt.figure(figsize=(15,15)) -#plt.imshow(np.transpose(grid, (1,2,0))) -#plt.show() - -groundtruth_labels_indexes = list(np.array(labels.squeeze(0)).astype(int)) -groundtruth_classes_name = [groundtruth[idx] for idx in groundtruth_labels_indexes] -print(f'groundtruth classes: {groundtruth_classes_name}') - -# plotting images -images = [ images[idx][0].numpy() for idx in range(NUM_IMAGES)] -display_images(images, 1, NUM_IMAGES) \ No newline at end of file diff --git a/src/EMNISTNet/models.py b/src/EMNISTNet/models.py deleted file mode 100644 index efaafec..0000000 --- a/src/EMNISTNet/models.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class EMNISTNet(torch.nn.Module): - def __init__(self, num_classes): - super(EMNISTNet, self).__init__() - - # conv output size = ((inputSize + 2*pad - filterSize) / stride) + 1 - # max pool with filterSize = 2 and stride = 2 shrinks down by half - # - # input: 28x28x1 - # output: (((28 + 2*0 - 3) / 1) + 1) = (28 - 3) + 1 = 26 --> 26x26x16 - self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3) - # input: 26x26x16 - # output: (26 - 3) + 1 = 24 - # max pool: floor(24/2) = 12 --> 12x12x64 - self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3) - # input: 
12x12x64 - # output: (12 - 3) + 1 = 10 - # max pool: floor(10/2) = 5 --> 5x5x64 - self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3) - - # fc - # input: 5x5x64 --> flatten = 1600 - self.fc1 = nn.Linear(in_features=64*5*5, out_features=64*5*5) - self.fc2 = nn.Linear(in_features=64*5*5, out_features=64*4*2) - self.out = nn.Linear(in_features=64*4*2, out_features=num_classes) - - def forward(self, x): - # conv - # in: 28x28x1 - x = F.relu(self.conv1(x)) - - # in: 26x26x16 - x = F.relu(self.conv2(x)) - x = F.max_pool2d(x, kernel_size=2, stride=2) - - # in: 12x12x64 - x = F.relu(self.conv3(x)) - x = F.max_pool2d(x, kernel_size=2, stride=2) - - # fc - # in: 5x5x64 - x = x.reshape(-1, 64*5*5) - - x = F.relu(self.fc1(x)) - - x = F.relu(self.fc2(x)) - - x = self.out(x) - - return x - -class EMNISTNet_v2(torch.nn.Module): - def __init__(self, num_classes): - super(EMNISTNet_v2, self).__init__() - - # conv - self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3) - self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3) - self.conv2_drop = nn.Dropout2d(p=0.05) - - self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3) - self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3) - self.conv4_drop = nn.Dropout2d(p=0.05) - - # fc - self.fc1 = nn.Linear(in_features=64*4*4, out_features=64*4*2) - self.fc2 = nn.Linear(in_features=64*4*2, out_features=64*4) - self.fc2_drop = nn.Dropout(p=0.1) - self.out = nn.Linear(in_features=64*4, out_features=num_classes) - - def forward(self, x): - # output size = ((inputSize + 2*pad - filterSize) / stride) + 1 - # conv - # in: 28x28x1 - x = F.relu(self.conv1(x)) - - # in: 26x26x32 - x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2) - - x = self.conv2_drop(x) - - # in: 12x12x64 - x = F.relu(self.conv3(x)) - - # in: 10x10x64 - x = F.max_pool2d(F.relu(self.conv4(x)), kernel_size=2, stride=2) - - x = self.conv4_drop(x) - - # fc - # in: 4x4x64 - x = x.reshape(-1, 64*4*4) - x = F.relu(self.fc1(x)) - - x = F.relu(self.fc2(x)) - - x = self.fc2_drop(x) - - x = self.out(x) - - return x - -class EMNISTNet_v3(torch.nn.Module): - def __init__(self, num_classes): - super(EMNISTNet_v3, self).__init__() - - # conv - self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3) - self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3) - self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3) - self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3) - self.conv5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2) - self.conv6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2) - - # fc - self.fc1 = nn.Linear(in_features=64*1*1, out_features=64*4) - self.fc2 = nn.Linear(in_features=64*4, out_features=64*4) - self.out = nn.Linear(in_features=64*4, out_features=num_classes) - - def forward(self, x): - # output size = ((inputSize + 2*pad - filterSize) / stride) + 1 - # conv - # in: 28x28x1 - x = F.relu(self.conv1(x)) - - # in: 26x26x32 - x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2) - - # in: 12x12x64 - x = F.relu(self.conv3(x)) - - # in: 10x10x64 - x = F.max_pool2d(F.relu(self.conv4(x)), kernel_size=2, stride=2) - - # in: 4x4x64 - x = F.relu(self.conv5(x)) - - # in: 3x3x64 - x = F.max_pool2d(F.relu(self.conv6(x)), kernel_size=2, stride=2) - - # fc - # in: 1x1x64 - x = x.reshape(-1, 64*1*1) - x = F.relu(self.fc1(x)) - - x = F.relu(self.fc2(x)) - - x = self.out(x) - - return x - -class 
EMNISTNet_v4(torch.nn.Module): - def __init__(self, num_classes): - super(EMNISTNet_v4, self).__init__() - - # conv - self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=2) - self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, padding=2) - - # fc - self.fc1 = nn.Linear(in_features=128*7*7, out_features=1024) - self.fc2 = nn.Linear(in_features=1024, out_features=128) - self.fc2_drop = nn.Dropout(p=0.2) - self.out = nn.Linear(in_features=128, out_features=num_classes) - - def forward(self, x): - # output size = ((inputSize + 2*pad - filterSize) / stride) + 1 - # conv - # in: 28x28x1 - x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2) - - # in: 14x14x64 - x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2) - - # fc - # in: 7x7x128 - x = x.reshape(-1, 128*7*7) - x = F.relu(self.fc1(x)) - - x = F.relu(self.fc2(x)) - - x = self.fc2_drop(x) - - x = self.out(x) - - return x diff --git a/src/EMNISTNet/pytorch_utils.py b/src/EMNISTNet/pytorch_utils.py deleted file mode 100644 index df37f22..0000000 --- a/src/EMNISTNet/pytorch_utils.py +++ /dev/null @@ -1,140 +0,0 @@ -import torch -import torchvision - -from itertools import product -from collections import namedtuple -from collections import OrderedDict - -from torch.utils.tensorboard import SummaryWriter - -import pandas as pd -import time -import json -import logging - -def info(): - logging.info(torch.__version__) - logging.info(torch.cuda.is_available()) - logging.info(torch.version.cuda) - logging.info(torch.cuda.current_device()) - logging.info(torch.cuda.get_device_name(0)) - -def get_runs_params(params): - Run = namedtuple('Run', params.keys()) - runs = [] - for v in product(*params.values()): - runs.append(Run(*v)) - return runs - -class TrainingManager(): - def __init__(self): - self.epoch_count = 0 - self.epoch_loss = 0 - self.epoch_num_correct = 0 - self.epoch_start_time = None - - self.run_params = None - self.run_count = 0 - self.run_data = [] - self.run_start_time = None - - self.model = None - self.loader = None - self.tb = None - - @torch.no_grad() - def _get_num_correct(self, preds, labels): - return preds.argmax(dim=1).eq(labels).sum().item() - - def begin_run(self, run, model, loader): - self.run_start_time = time.time() - - self.run_params = run - self.run_count += 1 - - self.model = model - self.loader = loader - self.tb = SummaryWriter(comment=f'-{run}') - - self.cached_images, labels = next(iter(self.loader)) # images will be cached to use in tracing when exporting for mobile - grid = torchvision.utils.make_grid(self.cached_images) - self.tb.add_image('images', grid) - - # FIXME(andrey): adding model to tensorboard is crashing when model is running on GPU - #self.tb.add_graph(self.model, images) - - logging.info(f'Start training for run #{self.run_count}:\nlr: {run.lr}\nbatch_size: {run.batch_size}\nshuffle: {run.shuffle}') - - def end_run(self): - self.tb.close() - self.epoch_count = 0 - - def begin_epoch(self): - self.epoch_start_time = time.time() - - self.epoch_count += 1 - self.epoch_loss = 0 - self.epoch_num_correct = 0 - - logging.info(f'Starting epoch: {self.epoch_count}') - - def end_epoch(self, save_model=False): - epoch_duration = time.time() - self.epoch_start_time - run_duration = time.time() - self.run_start_time - - loss = self.epoch_loss / len(self.loader.dataset) - accuracy = self.epoch_num_correct / len(self.loader.dataset) - - logging.info(f'\nFinished epoch {self.epoch_count} in {epoch_duration}s - Accuracy: {accuracy} - Loss: 
{loss}') - - self.tb.add_scalar('Loss', loss, self.epoch_count) - self.tb.add_scalar('Accuracy', accuracy, self.epoch_count) - - for name, param in self.model.named_parameters(): - self.tb.add_histogram(name, param, self.epoch_count) - self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count) - - results = OrderedDict() - results["run"] = self.run_count - results["epoch"] = self.epoch_count - results['loss'] = loss - results["accuracy"] = accuracy - results['epoch duration'] = epoch_duration - results['run duration'] = run_duration - for k,v in self.run_params._asdict().items(): results[k] = v - self.run_data.append(results) - - df = pd.DataFrame.from_dict(self.run_data, orient='columns') - - if save_model: - torch.save(self.model.state_dict(), f'{self.run_count}_{self.epoch_count}_model.pt') - - def track_loss(self, loss): - self.epoch_loss += loss.item() * self.loader.batch_size - - def track_num_corret(self, preds, labels): - self.epoch_num_correct += self._get_num_correct(preds, labels) - - def save(self, filename, save_training_report=False, export_for_mobile_loading=False): - """ - Save results and model. - """ - if save_training_report: - # save csv - pd.DataFrame.from_dict( - self.run_data, - orient='columns' - ).to_csv(f'{filename}.csv') - - # save json - with open(f'{filename}.json', 'w', encoding='utf-8') as f: - json.dump(self.run_data, f, ensure_ascii=False, indent=4) - - - if export_for_mobile_loading: - self.model.to("cpu") - self.cached_images.to("cpu") - traced_cpu = torch.jit.trace(self.model, self.cached_images) - torch.jit.save(traced_cpu, f'{filename}_mobile.pth') - - torch.save(self.model.state_dict(), f'{filename}.pt') \ No newline at end of file diff --git a/src/EMNISTNet/test_model.py b/src/EMNISTNet/test_model.py deleted file mode 100644 index 889f113..0000000 --- a/src/EMNISTNet/test_model.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -import torchvision -import torch.nn.functional as F -import torchvision.transforms as transforms -from torch.utils.data import DataLoader - -import matplotlib.pyplot as plt -import numpy as np - -import argparse -import logging - -from models import EMNISTNet - -def plot_images(data, rows, cols, cmap='gray'): - if(len(data) > 0): - i = 0 - for title, image in data.items(): - #logging.debug(title) - plt.subplot(rows,cols,i+1),plt.imshow(image,cmap) - plt.title(title) - plt.xticks([]),plt.yticks([]) - i += 1 - plt.show() - -def display_images(img_list, row, col): - if(len(img_list) > 0): - images = {} - n = 0 - for img in img_list: - n += 1 - images[str(n)] = img - plot_images(images, row, col, cmap='gray') - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - -def load_dataset(path_filename): - if path_filename is not None: - dataset = torchvision.datasets.ImageFolder(root=path_filename, - transform=transforms.Compose([ - transforms.Grayscale(num_output_channels=1), - transforms.ToTensor() - ]) - ) - logging.debug(f'Training on {path_filename} dataset, size: {len(dataset)}') - else : - dataset = torchvision.datasets.EMNIST( - root = 'data/', - split='bymerge', - train = True, - download = True, - transform = transforms.Compose([ - transforms.ToTensor() - ]) - ) - logging.debug(f'Training on EMNIST (bymerge) dataset, size: {len(dataset)}') - - data_loader = DataLoader(dataset, batch_size=1000, shuffle=True) - batch = 
next(iter(data_loader)) - images, labels = batch - num_classes = int(torch.max(labels)) + 1 - logging.debug(f'Number of classes: {num_classes}') - - return dataset, num_classes - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='Usage: python test.py --m=result_model.pt --v') - parser.add_argument('--m', help='Path to model') - parser.add_argument('--v', type=str2bool, nargs='?', const=True, default=False, help='verbose and debug msgs') - parser.add_argument('--d', help='dataset to test on') - args = parser.parse_args() - - if args.v: - DEBUG = True - logging.getLogger().setLevel(logging.DEBUG) - logging.debug('Verbose mode is activated.') - - test_set, num_classes = load_dataset(args.d) - - classes = ['0','1','2','3','4','5','6','7','8','9', - 'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z', - 'a','b','d','e','f','g','h','n','q','r','t'] - - NUM_IMAGES = 20 - - loader = DataLoader(test_set, batch_size=NUM_IMAGES, shuffle=True) - - batch = next(iter(loader)) - logging.debug(f'batch len: {len(batch)}') - logging.debug(f'type: {type(batch)}') - images, labels = batch - logging.debug(f'images shape: {images.shape}') - logging.debug(f'labels shape: {labels.shape}') - - logging.debug(f'labels: {labels}') - groundtruth_labels_indexes = list(np.array(labels.squeeze(0)).astype(int)) - logging.debug(f'labels_indexes: {groundtruth_labels_indexes}') - groundtruth_classes_name = [classes[idx] for idx in groundtruth_labels_indexes] - logging.debug(f'groundtruth classes: {groundtruth_classes_name}') - - net = EMNISTNet(num_classes=num_classes) - net.load_state_dict(torch.load(args.m)) - - preds = net(images) - preds = preds.argmax(dim=1) - logging.info(f'preds: {preds.shape}\n{preds}') - preds_indexes = list(np.array(preds.squeeze(0)).astype(int)) - preds_classes_name = [classes[idx] for idx in preds_indexes] - logging.debug(f'groundtruth classes: {groundtruth_classes_name}') - logging.debug(f'preds_classes : {preds_classes_name}') - - images = [ images[idx][0].numpy() for idx in range(NUM_IMAGES)] - display_images(images, 1, NUM_IMAGES) diff --git a/src/EMNISTNet/train_model.py b/src/EMNISTNet/train_model.py deleted file mode 100644 index e233e76..0000000 --- a/src/EMNISTNet/train_model.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -import torchvision -import torchvision.transforms as transforms -from torch.utils.data import DataLoader - -from collections import OrderedDict - -import argparse -import logging - -from pytorch_utils import TrainingManager, get_runs_params - -from models import EMNISTNet, EMNISTNet_v2, EMNISTNet_v3, EMNISTNet_v4 - -torch.set_printoptions(linewidth=120) -torch.set_grad_enabled(True) - -logging.getLogger().setLevel(logging.INFO) -DEBUG = False - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - -def load_dataset(path_filename): - if path_filename is not None: - dataset = torchvision.datasets.ImageFolder(root=path_filename, - transform=transforms.Compose([ - transforms.Grayscale(num_output_channels=1), - transforms.RandomApply([transforms.RandomAffine(degrees=(-20, 20), shear=(-30, 30))], p=0.5), - transforms.ToTensor() - ]) - ) - logging.debug(f'Training on {path_filename} dataset, size: 
{len(dataset)}') - else : - dataset = torchvision.datasets.EMNIST( - root = 'data/', - split='bymerge', - train = True, - download = True, - transform = transforms.Compose([ - transforms.ToTensor() - ]) - ) - logging.debug(f'Training on EMNIST (bymerge) dataset, size: {len(dataset)}') - - data_loader = DataLoader(dataset, batch_size=1000, shuffle=True) - batch = next(iter(data_loader)) - images, labels = batch - num_classes = int(torch.max(labels)) + 1 - logging.debug(f'Number of classes: {num_classes}') - - return dataset, num_classes - -def load_net(net_version, pretrained_weights): - if net_version is not None: - if net_version == 1: - net = EMNISTNet(num_classes=num_classes) - elif net_version == 2: - net = EMNISTNet_v2(num_classes=num_classes) - elif net_version == 3: - net = EMNISTNet_v3(num_classes=num_classes) - elif net_version == 4: - net = EMNISTNet_v4(num_classes=num_classes) - else: - net = EMNISTNet(num_classes=num_classes) - else: - net = EMNISTNet(num_classes=num_classes) - - if pretrained_weights is not None: - logging.debug(f'Loading pre-trained model: {pretrained_weights}') - net.load_state_dict(torch.load(pretrained_weights)) - - return net - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Usage: python train_model.py --m=emnist_model.pt --d=custom_dataset/ --e=1 --cuda --v --o=emnist_model') - parser.add_argument('--m', help='Path to a previous model to start with') - parser.add_argument('--e', type=int, nargs='?', const=1, default=1, help='Number of epochs to train the model') - parser.add_argument('--cuda', type=str2bool, nargs='?', const=False, default=False, help='use CUDA if available') - parser.add_argument('--v', type=str2bool, nargs='?', const=False, default=False, help='verbose and debug msgs') - parser.add_argument('--d', help='Path to the custom dataset') - parser.add_argument('--o', help='Output model filename') - parser.add_argument('--mobile', type=str2bool, nargs='?', const=False, default=False, help='export model for mobile loading') - parser.add_argument('--n', type=int, nargs='?', const=1, default=1, help='net model to use') - parser.add_argument('--b', type=int, nargs='?', const=500, default=500, help='batch size') - args = parser.parse_args() - - if args.v: - DEBUG = True - logging.getLogger().setLevel(logging.DEBUG) - logging.debug('Verbose mode is activated.') - - NUM_EPOCHS = args.e - logging.debug(f'Number of epochs: {NUM_EPOCHS}') - - train_set, num_classes = load_dataset(args.d) - - # TODO: put run params in a config file or just remove multiple runs support? 
- params = OrderedDict( - lr = [0.001], - batch_size = [args.b], - shuffle = [True] - ) - - device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu") - logging.info(f'Using device {device}') - - train_manager = TrainingManager() - for run in get_runs_params(params): - net = load_net(args.n, args.m) - net.to(device) - - data_loader = DataLoader(train_set, batch_size=run.batch_size, shuffle=run.shuffle) - optimizer = optim.Adam(net.parameters(), lr=run.lr) - - train_manager.begin_run(run, net, data_loader) - - for epoch in range(NUM_EPOCHS): - train_manager.begin_epoch() - - for batch in data_loader: - images, labels = batch - images, labels = images.to(device), labels.to(device) - - preds = net(images) - loss = F.cross_entropy(preds, labels) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - train_manager.track_loss(loss) - train_manager.track_num_corret(preds, labels) - - logging.info(f'Loss: {loss}') - - train_manager.end_epoch() - - train_manager.end_run() - - if args.o is not None: - train_manager.save(filename=args.o, export_for_mobile_loading=True) diff --git a/src/app.py b/src/app.py index 1561561..f538993 100644 --- a/src/app.py +++ b/src/app.py @@ -5,16 +5,9 @@ import logging -from image_processing import extract_chars from yolo import Yolo from ocr import OCR -import importlib -tesseract_spec = importlib.util.find_spec("pytesseract") -tesseract_found = tesseract_spec is not None -if tesseract_found: - import pytesseract - app = Flask(__name__) DEBUG = True @@ -46,73 +39,50 @@ def run_lpr(): roi_imgs = yolo.detect(inputImage) - ocr = OCR(model_filename="../config/emnist_net_custom.pt", num_classes=36, use_cuda=False, debug=DEBUG) + ocr = OCR(model_filename="../config/attention_ocr_model.pth", use_cuda=False, threshold=0.7) index = 0 + api_output = [] for roi_img in roi_imgs: logging.info(f'\n\nProcessing ROI {index}') box = [yolo.bounding_boxes[index][0], yolo.bounding_boxes[index][1], yolo.bounding_boxes[index][2], yolo.bounding_boxes[index][3]] - predict(yolo.img, roi_img, box, str(index), (0,255,0), ocr) + score = yolo.confidences[index] + pred = ocr.predict(roi_img) + + draw_bounding_box(input_image=yolo.img, bounding_box=box, label=pred, background_color=(0,255,0), ocr=ocr) + logging.info(f'\nOCR output: {pred}') + + output = {'bounding_box' : box, 'bb_confidence' : score, 'ocr_pred' : pred} + api_output.append(output) + if(DEBUG): cv.imwrite("../debug/roi_"+str(index)+".jpg", roi_img.astype(np.uint8)) - + index += 1 - - # API response: the highest confidence one - logging.info(f'\n\n---Processing the Highest Confidence ROI---\n') - bounding_box = None - emnist_net_preds = None - tesseract_preds = None - if(yolo.highest_object_confidence > 0 and yolo.roi_img is not None): - bounding_box = { - 'x': yolo.box_x, - 'y': yolo.box_y, - 'w': yolo.box_w, - 'h': yolo.box_h - } - _, emnist_net_preds, tesseract_preds = predict(yolo.img, yolo.roi_img, [yolo.box_x, yolo.box_y, yolo.box_w, yolo.box_h], "", (255,255,0), ocr) - if(DEBUG): - cv.imwrite("../debug/result.jpg", yolo.img.astype(np.uint8)) - - data = { - 'bounding_box': bounding_box, - 'confidence': yolo.highest_object_confidence, - 'classId': str(yolo.classId_highest_object), - 'emnist_net_preds': emnist_net_preds, - 'tesseract_preds': tesseract_preds + + if(DEBUG): + cv.imwrite("../debug/result.jpg", yolo.img.astype(np.uint8)) + + success, output_image = cv.imencode('.jpg', yolo.img) + api_response = { + 'output_image' : base64.b64encode(output_image).decode('utf-8'), + 'detections' : 
api_output } - response = jsonify(data) + response = jsonify(api_response) response.status_code = 200 return response -def predict(input_image, roi_img, bounding_box, prefix_label, background_color, emnist_net): - characteres, img, mask = extract_chars(roi_img) - - emnist_net_preds = emnist_net.predict(characteres) - - tesseract_preds = None - if tesseract_found: - tesseract_preds = pytesseract.image_to_string(img, lang='eng', config='--oem 3 --psm 13 -c tessedit_char_whitelist=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ') - - logging.debug(f'\nTesseract output: {tesseract_preds}\nEMNISTNet output: {emnist_net_preds}') - - text = tesseract_preds if tesseract_preds is not None else emnist_net_preds - labelSize, baseLine = cv.getTextSize(text, cv.FONT_HERSHEY_SIMPLEX, 0.6, 2) +def draw_bounding_box(input_image, bounding_box, label, background_color, ocr): + labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.6, 2) x = bounding_box[0] y = bounding_box[1] + bounding_box[3] w = bounding_box[0] + round(1.1*labelSize[0]) h = (bounding_box[1] + bounding_box[3]) + 25 cv.rectangle(input_image, (x, y), (w, h), background_color, cv.FILLED) - cv.putText(input_image, text, (x+5, y+20), cv.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2) - - if(DEBUG): - cv.imwrite("../debug/roi_masked_"+prefix_label+".jpg", img.astype(np.uint8)) - cv.imwrite("../debug/roi_mask_"+prefix_label+".jpg", mask.astype(np.uint8)) - - return characteres, emnist_net_preds, tesseract_preds + cv.putText(input_image, label, (x+5, y+20), cv.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2) if __name__ == '__main__': diff --git a/src/attention_ocr/LICENSE b/src/attention_ocr/LICENSE new file mode 100644 index 0000000..0435578 --- /dev/null +++ b/src/attention_ocr/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Zhen Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/attention_ocr/README.md b/src/attention_ocr/README.md new file mode 100644 index 0000000..971c940 --- /dev/null +++ b/src/attention_ocr/README.md @@ -0,0 +1,29 @@ +# attention-ocr +A pytorch implementation of attention based ocr + +This repo is still under development. + +Inspired by the tensorflow attention ocr created by google. [link](https://github.com/tensorflow/models/tree/master/research/attention_ocr) + +More details can also be found in this paper: + +["Attention-based Extraction of Structured Information from Street View Imagery"](https://arxiv.org/abs/1704.03549) + +# Install and Requirements + +### pycrypto for Python 3.6, Windows 10, Visual Studio 2017: + +1. 
open "x86_x64 Cross-Tools Command Prompt for VS 2017" with administrator privilege in start menu. +2. go to C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Tools\MSVC and check your MSVC version (mine was 14.16.27023) +3. type set CL=-FI"C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Tools\MSVC\14.16.27023\include\stdint.h" with the version you've just found +(also typed it in the vscode env terminal and in the x86_x64 Cross-Tools Command Prompt for VS 2017 as well...) +4. simply pip install pycrypto +5. No module named 'winrandom' when using pycrypto: +Problem is solved by editing string in crypto\Random\OSRNG\nt.py: +```` +import winrandom +```` +to +```` +from . import winrandom +```` \ No newline at end of file diff --git a/src/attention_ocr/__init__.py b/src/attention_ocr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/attention_ocr/explore_ssigalpr_dataset.py b/src/attention_ocr/explore_ssigalpr_dataset.py new file mode 100644 index 0000000..eeadfa9 --- /dev/null +++ b/src/attention_ocr/explore_ssigalpr_dataset.py @@ -0,0 +1,56 @@ +import numpy as np +import os +import pandas as pd + +from utils.dataset import SSIGALPRDataset +from utils.img_util import display_images + +from torchvision import transforms + +from PIL import Image + +ANNOTADED_FILE = 'ssigalpr_samples/test_train.csv' +IMG_DIR = 'ssigalpr_samples/train/' + +IMG_WIDTH = 160 +IMG_HEIGHT = 60 +N_CHARS = 7 +CHARS = list('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') + + +img_trans = transforms.Compose([ + transforms.Resize((IMG_HEIGHT, IMG_WIDTH)) + ,transforms.Grayscale(num_output_channels=3) + ,transforms.ToTensor() + ,transforms.Normalize(mean=[0.5, 0.5, 0.5], std=(0.5, 0.5, 0.5)) +]) + +if __name__ == '__main__': + df = pd.read_csv(ANNOTADED_FILE, dtype={'img_id': str}) + print(f'dataframe shape: {df.shape}') + print(f'total items: {df.shape[0]}') + + df = df.loc[df['text'] != 'no_one'] + print(f'total items after cleaning: {df.shape[0]}') + + annotaded_data = df.iloc[0] + + img_id = annotaded_data.iloc[0] + print(f'image: {img_id}.png') + img = Image.open(IMG_DIR+img_id+'.png') + img.show() + + width, height = img.size + x0 = annotaded_data.iloc[1] * width + y0 = annotaded_data.iloc[2] * height + x1 = annotaded_data.iloc[3] * width + y1 = annotaded_data.iloc[4] * height + + label = annotaded_data.iloc[5] + print(f'label: {label}') + + roi = img.crop((x0, y0, x1, y1)) + roi.show() + + t = img_trans(roi) + display_images(t.numpy(), 1, 3) \ No newline at end of file diff --git a/src/attention_ocr/export_to_mobile.py b/src/attention_ocr/export_to_mobile.py new file mode 100644 index 0000000..17dd61f --- /dev/null +++ b/src/attention_ocr/export_to_mobile.py @@ -0,0 +1,46 @@ +import torch +import argparse + +from torchvision import transforms + +from utils.dataset import SSIGALPRDataset +from utils.tokenizer import Tokenizer + +from model.attention_ocr import AttentionOCR + +ROOT_IMG_PATH = 'ssigalpr_samples/val/' +ANNOTADED_FILE = 'ssigalpr_samples/val.csv' + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Usage: python export_to_mobile.py --in=chkpoint/time_2020-06-12_19-31-05_epoch_10.pth --out=exported_model.pth') + parser.add_argument('--m', help='input model filename.') + parser.add_argument('--out', help='output model filename.') + parser.add_argument('--w', type=int, nargs='?', const=160, default=160, help='image width that the model was trained on.') + parser.add_argument('--h', type=int, nargs='?', const=60, 
default=60, help='image height that the model was trained on.') + args = parser.parse_args() + + img_width = args.w if args.w is not None else 160 + img_height = args.h if args.h is not None else 60 + nh = 512 + chars = list('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') + n_chars = 7 + device = 'cpu' + + tokenizer = Tokenizer(chars) + model = AttentionOCR(img_width, img_height, nh, tokenizer.n_token, + n_chars + 1, tokenizer.SOS_token, tokenizer.EOS_token).to(device=device) + + model.load_state_dict(torch.load(args.m)) + + dataset = SSIGALPRDataset(img_width, img_height, n_chars=n_chars, labels_path=ANNOTADED_FILE, root_img_dir=ROOT_IMG_PATH) + + img, label = dataset[0] + + input_img = img.unsqueeze(0) + input_img.to(device) + + model = AttentionOCR(img_width, img_height, nh, tokenizer.n_token, + n_chars + 1, tokenizer.SOS_token, tokenizer.EOS_token).to(device=device) + + traced_cpu_model = torch.jit.trace(model, input_img) + torch.jit.save(traced_cpu_model, f'{args.out}') diff --git a/src/attention_ocr/model/attention_ocr.py b/src/attention_ocr/model/attention_ocr.py new file mode 100644 index 0000000..04063e3 --- /dev/null +++ b/src/attention_ocr/model/attention_ocr.py @@ -0,0 +1,227 @@ +import math +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchvision.models.inception import BasicConv2d, InceptionA + + +class MyIncept(nn.Module): + def __init__(self): + super(MyIncept, self).__init__() + self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + import scipy.stats as stats + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + X = stats.truncnorm(-2, 2, scale=stddev) + values = torch.Tensor(X.rvs(m.weight.numel())) + values = values.view(m.weight.size()) + m.weight.data.copy_(values) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + # 299 x 299 x 3 + x = self.Conv2d_1a_3x3(x) + # 149 x 149 x 32 + x = self.Conv2d_2a_3x3(x) + # 147 x 147 x 32 + x = self.Conv2d_2b_3x3(x) + # 147 x 147 x 64 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # 73 x 73 x 64 + x = self.Conv2d_3b_1x1(x) + # 73 x 73 x 80 + x = self.Conv2d_4a_3x3(x) + # 71 x 71 x 192 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # 35 x 35 x 192 + x = self.Mixed_5b(x) + # 35 x 35 x 256 + x = self.Mixed_5c(x) + # 35 x 35 x 288 + x = self.Mixed_5d(x) + + return x + + +class OneHot(nn.Module): + def __init__(self, depth): + super(OneHot, self).__init__() + emb = nn.Embedding(depth, depth) + emb.weight.data = torch.eye(depth) + emb.weight.requires_grad = False + self.emb = emb + + def forward(self, input_): + return self.emb(input_) + + +class Attention(nn.Module): + def __init__(self, hidden_size): + super(Attention, self).__init__() + self.hidden_size = hidden_size + + self.attn = nn.Linear(hidden_size * 2, hidden_size) + self.v = nn.Parameter(torch.rand(hidden_size), requires_grad=True) + stdv = 1. 
/ math.sqrt(self.v.size(0)) + self.v.data.uniform_(-stdv, stdv) + + def forward(self, hidden, encoder_outputs): + timestep = encoder_outputs.size(1) + h = hidden.expand(timestep, -1, -1).transpose(0, 1) + attn_energies = self.score(h, encoder_outputs) + return attn_energies.softmax(2) + + def score(self, hidden, encoder_outputs): + energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2))) + energy = energy.transpose(1, 2) + v = self.v.expand(encoder_outputs.size(0), -1).unsqueeze(1) + energy = torch.bmm(v, energy) + return energy + + +class Decoder(nn.Module): + def __init__(self, vocab_size, max_len, hidden_size, sos_id, eos_id, n_layers=1): + super(Decoder, self).__init__() + + self.vocab_size = vocab_size + self.max_len = max_len + self.hidden_size = hidden_size + self.sos_id = sos_id + self.eos_id = eos_id + self.n_layers = n_layers + + self.emb = nn.Embedding(vocab_size, hidden_size) + self.attention = Attention(hidden_size) + self.rnn = nn.GRU(hidden_size * 2, hidden_size, n_layers) + + self.out = nn.Linear(hidden_size, vocab_size) + + def forward_step(self, input_, last_hidden, encoder_outputs): + emb = self.emb(input_.transpose(0, 1)) + attn = self.attention(last_hidden, encoder_outputs) + context = attn.bmm(encoder_outputs).transpose(0, 1) + rnn_input = torch.cat((emb, context), dim=2) + + outputs, hidden = self.rnn(rnn_input, last_hidden) + + if outputs.requires_grad: + outputs.register_hook(lambda x: x.clamp(min=-10, max=10)) + + outputs = self.out(outputs.contiguous().squeeze(0)).log_softmax(1) + + return outputs, hidden + + def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None, + teacher_forcing_ratio=0): + inputs, batch_size, max_length = self._validate_args( + inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio) + + use_teacher_forcing = True if torch.rand(1).item() < teacher_forcing_ratio else False + + outputs = [] + + self.rnn.flatten_parameters() + + decoder_hidden = torch.zeros(1, batch_size, self.hidden_size, device=encoder_outputs.device) + + def decode(step_output): + symbols = step_output.topk(1)[1] + return symbols + + if use_teacher_forcing: + for di in range(max_length): + decoder_input = inputs[:, di].unsqueeze(1) + + decoder_output, decoder_hidden = self.forward_step( + decoder_input, decoder_hidden, encoder_outputs) + + step_output = decoder_output.squeeze(1) + outputs.append(step_output) + else: + decoder_input = inputs[:, 0].unsqueeze(1) + for di in range(max_length): + decoder_output, decoder_hidden = self.forward_step( + decoder_input, decoder_hidden, encoder_outputs + ) + + step_output = decoder_output.squeeze(1) + outputs.append(step_output) + + symbols = decode(step_output) + decoder_input = symbols + + outputs = torch.stack(outputs).permute(1, 0, 2) + + return outputs, decoder_hidden + + def _validate_args(self, inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio): + batch_size = encoder_outputs.size(0) + + if inputs is None: + assert teacher_forcing_ratio == 0 + + inputs = torch.full((batch_size, 1), self.sos_id, dtype=torch.long, device=encoder_outputs.device) + + max_length = self.max_len + else: + max_length = inputs.size(1) - 1 + + return inputs, batch_size, max_length + + +class AttentionOCR(nn.Module): + def __init__(self, img_width, img_height, nh, n_classes, max_len, SOS_token, EOS_token): + super(AttentionOCR, self).__init__() + + self.incept = MyIncept() + + f = self.incept(torch.rand(1, 3, img_height, img_width)) + + self._fh = f.size(2) + self._fw = f.size(3) + logging.info(f'Inception 
model feature size: fh: {self._fh}, fw: {self._fw}') + + self.onehot_x = OneHot(self._fh) + self.onehot_y = OneHot(self._fw) + self.encode_emb = nn.Linear(288 + self._fh + self._fw, nh) + self.decoder = Decoder(n_classes, max_len, nh, SOS_token, EOS_token) + + self._device = 'cpu' + + def forward(self, input_, target_seq=None, teacher_forcing_ratio=0): + device = input_.device + b, c, h, w = input_.size() + encoder_outputs = self.incept(input_) + + b, fc, fh, fw = encoder_outputs.size() + + x, y = torch.meshgrid(torch.arange(fh, device=device), torch.arange(fw, device=device)) + + h_loc = self.onehot_x(x) + w_loc = self.onehot_y(y) + + loc = torch.cat([h_loc, w_loc], dim=2).unsqueeze(0).expand(b, -1, -1, -1) + + encoder_outputs = torch.cat([encoder_outputs.permute(0, 2, 3, 1), loc], dim=3) + encoder_outputs = encoder_outputs.contiguous().view(b, -1, 288 + self._fh + self._fw) + + encoder_outputs = self.encode_emb(encoder_outputs) + + decoder_outputs, decoder_hidden = self.decoder(target_seq, encoder_outputs=encoder_outputs, + teacher_forcing_ratio=teacher_forcing_ratio) + + return decoder_outputs diff --git a/src/attention_ocr/predict.py b/src/attention_ocr/predict.py new file mode 100644 index 0000000..438bc55 --- /dev/null +++ b/src/attention_ocr/predict.py @@ -0,0 +1,67 @@ +import cv2 as cv +import torch +import numpy as np +import argparse +import sys +import os.path +import matplotlib.pyplot as plt + +from torchvision import transforms + +from model.attention_ocr import AttentionOCR +from utils.tokenizer import Tokenizer +from utils.img_util import display_images + +MODEL_PATH_FILE = './chkpoint/time_2020-06-19_18-24-50_epoch_12.pth' + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Usage: python predict.py --image=path/to/the/image/file.jpg') + parser.add_argument('--image', help='Path to image file.') + args = parser.parse_args() + + # Open the image file + if not os.path.isfile(args.image): + print("Input image file ", args.image, " doesn't exist") + sys.exit(1) + cap = cv.VideoCapture(args.image) + + hasFrame, frame = cap.read() + + + chars = list('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') + img_width = 160 + img_height = 60 + nh = 512 + n_chars = 7 + device = "cpu" + + tokenizer = Tokenizer(chars) + model = AttentionOCR(img_width, img_height, nh, tokenizer.n_token, + n_chars + 1, tokenizer.SOS_token, tokenizer.EOS_token).to(device=device) + + model.load_state_dict(torch.load(MODEL_PATH_FILE)) + + img_trans = transforms.Compose([ + transforms.ToPILImage() + ,transforms.Resize((img_height, img_width)) + ,transforms.Grayscale(num_output_channels=3) + ,transforms.ToTensor() + ,transforms.Normalize(mean=[0.5, 0.5, 0.5], std=(0.5, 0.5, 0.5)) + ]) + + if hasFrame: + print(f'Frame shape: {frame.shape}') + img = img_trans(frame) + print(f'tensor shape: {img.shape}') + print(f'unsqueezed tensor shape: {img.unsqueeze(0).shape}') + + model.eval() + with torch.no_grad(): + pred = model(img.unsqueeze(0)) + + pred = tokenizer.translate(pred.squeeze(0).argmax(1)) + print(f'prediction: {pred}') + + display_images(img.numpy(), 1, 3) + else: + print("Frame not found!") diff --git a/src/attention_ocr/test.py b/src/attention_ocr/test.py new file mode 100644 index 0000000..0dd8955 --- /dev/null +++ b/src/attention_ocr/test.py @@ -0,0 +1,38 @@ +import torch + +from model.attention_ocr import AttentionOCR +from utils.tokenizer import Tokenizer +from utils.dataset import SSIGALPRDataset +from utils.img_util import display_images + +MODEL_PATH_FILE = 
'./chkpoint/time_2020-06-19_18-24-50_epoch_12.pth' +ROOT_IMG_PATH = 'ssigalpr_samples/val/' +ANNOTADED_FILE = 'ssigalpr_samples/val.csv' + +img_width = 160 +img_height = 60 +nh = 512 +chars = list('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') +n_chars = 7 +device = 'cpu' + +tokenizer = Tokenizer(chars) +model = AttentionOCR(img_width, img_height, nh, tokenizer.n_token, + n_chars + 1, tokenizer.SOS_token, tokenizer.EOS_token).to(device=device) + +model.load_state_dict(torch.load(MODEL_PATH_FILE)) + +dataset = SSIGALPRDataset(img_width, img_height, n_chars=n_chars, labels_path=ANNOTADED_FILE, root_img_dir=ROOT_IMG_PATH) + +img, label = dataset[0] +print(f'tensor shape: {img.shape}') +print(f'unsqueezed tensor shape: {img.unsqueeze(0).shape}') + +model.eval() +with torch.no_grad(): + pred = model(img.unsqueeze(0)) + +pred = tokenizer.translate(pred.squeeze(0).argmax(1)) +print(f'groundtruth: {tokenizer.translate(label)}\nprediction: {pred}') + +display_images(img.numpy(), 1, 3) diff --git a/src/attention_ocr/train.py b/src/attention_ocr/train.py new file mode 100644 index 0000000..848aa16 --- /dev/null +++ b/src/attention_ocr/train.py @@ -0,0 +1,146 @@ +import argparse +import random +import time +import pickle + +from tqdm import tqdm + +from PIL import Image + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader + +from torchvision import transforms + +from model.attention_ocr import AttentionOCR +from utils.dataset import SSIGALPRDataset +from utils.train_util import train_batch, eval_batch + +ROOT_TRAIN_IMG_DIR = 'F:\\dev\\ssigalpr_dataset\\test_train' +ANNOTADED_TRAIN_FILE = 'ssigalpr_samples/test_train.csv' +ROOT_VAL_IMG_DIR = 'F:\\dev\\ssigalpr_dataset\\val' +ANNOTADED_VAL_FILE = 'ssigalpr_samples/val.csv' + + +def main(inception_model='./inception_v3_google-1a9a5a14.pth', n_epoch=100, max_len=4, batch_size=32, n_works=4, + save_checkpoint_every=5, device='cuda', train_labels_path=ANNOTADED_TRAIN_FILE, train_root_img_dir=ROOT_TRAIN_IMG_DIR, + test_labels_path=ANNOTADED_VAL_FILE, test_root_img_dir=ROOT_VAL_IMG_DIR): + img_width = 160 + img_height = 60 + nh = 512 + + teacher_forcing_ratio = 0.5 + lr = 3e-4 + + ds_train = SSIGALPRDataset(img_width, img_height, n_chars=7, labels_path=train_labels_path, root_img_dir=train_root_img_dir) + ds_test = SSIGALPRDataset(img_width, img_height, n_chars=7, labels_path=test_labels_path, root_img_dir=test_root_img_dir) + + tokenizer = ds_train.tokenizer + + train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=n_works) + test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=n_works) + + model = AttentionOCR(img_width, img_height, nh, tokenizer.n_token, + max_len + 1, tokenizer.SOS_token, tokenizer.EOS_token).to(device=DEVICE) + + load_weights = torch.load(inception_model) + + names = set() + for k, w in model.incept.named_children(): + names.add(k) + + weights = {} + for k, w in load_weights.items(): + if k.split('.')[0] in names: + weights[k] = w + + model.incept.load_state_dict(weights) + + optimizer = optim.Adam(model.parameters(), lr=lr) + crit = nn.NLLLoss().cuda() + + def train_epoch(): + sum_loss_train = 0 + n_train = 0 + sum_acc = 0 + sum_sentence_acc = 0 + + for bi, batch in enumerate(tqdm(train_loader)): + x, y = batch + x = x.to(device=device) + y = y.to(device=device) + + loss, acc, sentence_acc = train_batch(x, y, model, optimizer, + crit, teacher_forcing_ratio, max_len, + tokenizer) + + sum_loss_train += 
loss + sum_acc += acc + sum_sentence_acc += sentence_acc + + n_train += 1 + + return sum_loss_train / n_train, sum_acc / n_train, sum_sentence_acc / n_train + + def eval_epoch(): + sum_loss_eval = 0 + n_eval = 0 + sum_acc = 0 + sum_sentence_acc = 0 + + for bi, batch in enumerate(tqdm(test_loader)): + x, y = batch + x = x.to(device=device) + y = y.to(device=device) + + loss, acc, sentence_acc = eval_batch(x, y, model, crit, max_len, tokenizer) + + sum_loss_eval += loss + sum_acc += acc + sum_sentence_acc += sentence_acc + + n_eval += 1 + + return sum_loss_eval / n_eval, sum_acc / n_eval, sum_sentence_acc / n_eval + + for epoch in range(n_epoch): + train_loss, train_acc, train_sentence_acc = train_epoch() + eval_loss, eval_acc, eval_sentence_acc = eval_epoch() + + print("Epoch %d" % epoch) + print('train_loss: %.4f, train_acc: %.4f, train_sentence: %.4f' % (train_loss, train_acc, train_sentence_acc)) + print('eval_loss: %.4f, eval_acc: %.4f, eval_sentence: %.4f' % (eval_loss, eval_acc, eval_sentence_acc)) + + if epoch % save_checkpoint_every == 0 and epoch > 0: + print('saving checkpoint...') + torch.save(model.state_dict(), './chkpoint/time_%s_epoch_%s.pth' % (time.strftime('%Y-%m-%d_%H-%M-%S'), epoch)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Usage: python train_.py --inception=\'./inception_v3_google-1a9a5a14.pth\' --e=1 --cuda') + parser.add_argument('--e', type=int, nargs='?', const=100, default=100, help='Number of epochs to train the model') + parser.add_argument('--l', type=int, nargs='?', const=7, default=7, help='Max number of characters in the image') + parser.add_argument('--c', type=int, nargs='?', const=5, default=5, help='Save model every given number of epochs (checkpoint)') + parser.add_argument('--w', type=int, nargs='?', const=4, default=4, help='Number of workers') + parser.add_argument('--cuda', action='store_true', default=False, help='Use CUDA') + parser.add_argument('--inception', help='Path to the inception model') + args = parser.parse_args() + + # TODO: 1) put all this params in a config file + # 2) load previous model and continue from there + NUM_EPOCHS = args.e if args.e is not None else 100 + MAX_LEN = args.l if args.l is not None else 7 + CHECKPOINT = args.c if args.c is not None else 5 + DEVICE = 'cuda' if args.cuda else 'cpu' + N_WORKERS = args.w if args.w is not None else 6 + INCEPTION_MODEL = args.inception if args.inception is not None else './inception_v3_google-1a9a5a14.pth' + + print(f'Device: {DEVICE} {args.cuda}\nEpochs: {NUM_EPOCHS}\nChar length: {MAX_LEN}\nCheckpoint every: {CHECKPOINT} epochs\nNumber of workers: {N_WORKERS}') + print(f'Inception model: {INCEPTION_MODEL}') + + main(inception_model=INCEPTION_MODEL, n_epoch=NUM_EPOCHS, max_len=MAX_LEN, n_works=N_WORKERS, + save_checkpoint_every=CHECKPOINT, device=DEVICE, + train_labels_path=ANNOTADED_TRAIN_FILE, train_root_img_dir=ROOT_TRAIN_IMG_DIR, + test_labels_path=ANNOTADED_VAL_FILE, test_root_img_dir=ROOT_VAL_IMG_DIR) diff --git a/src/attention_ocr/utils/__init__.py b/src/attention_ocr/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/attention_ocr/utils/dataset.py b/src/attention_ocr/utils/dataset.py new file mode 100644 index 0000000..1206e0f --- /dev/null +++ b/src/attention_ocr/utils/dataset.py @@ -0,0 +1,58 @@ +import numpy as np +import os +import pandas as pd + +from PIL import Image + +import torch +from torchvision import transforms +from torch.utils.data import Dataset + +from utils.tokenizer import Tokenizer + +class 
SSIGALPRDataset(Dataset): + def __init__(self, img_width, img_height, n_chars=7, chars=None, labels_path='/path/to/the/annotated/file', root_img_dir='/path/to/img/dir'): + self.n_chars = n_chars + + if chars is None: + self.chars = list('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') + else: + self.chars = list(chars) + + self.tokenizer = Tokenizer(self.chars) + + df = pd.read_csv(labels_path, dtype={'img_id': str}) + self.annotaded_data = df.loc[df['text'] != 'no_one'] + self.root_img_dir = root_img_dir + + self.img_trans = transforms.Compose([ + transforms.Resize((img_height, img_width)) + ,transforms.Grayscale(num_output_channels=3) + ,transforms.ToTensor() + ,transforms.Normalize(mean=[0.5, 0.5, 0.5], std=(0.5, 0.5, 0.5)) + ]) + + def __len__(self): + return self.annotaded_data.shape[0] + + def __getitem__(self, item): + annotaded_item = self.annotaded_data.iloc[item] + + img_id = annotaded_item[0] + img_path = self.root_img_dir + '/' + img_id + '.png' + img = Image.open(img_path) + + width, height = img.size + x0 = annotaded_item[1] * width + y0 = annotaded_item[2] * height + x1 = annotaded_item[3] * width + y1 = annotaded_item[4] * height + + roi = img.crop((x0, y0, x1, y1)) + + groundtruth = annotaded_item[5] + groundtruth_label = torch.full((self.n_chars + 2, ), self.tokenizer.EOS_token, dtype=torch.long) + ts = self.tokenizer.tokenize(groundtruth) + groundtruth_label[:ts.shape[0]] = torch.tensor(ts) + + return self.img_trans(roi), groundtruth_label diff --git a/src/attention_ocr/utils/img_util.py b/src/attention_ocr/utils/img_util.py new file mode 100644 index 0000000..e999c5e --- /dev/null +++ b/src/attention_ocr/utils/img_util.py @@ -0,0 +1,20 @@ +import matplotlib.pyplot as plt + +def plot_images(data, rows, cols, cmap='gray'): + if(len(data) > 0): + i = 0 + for title, image in data.items(): + plt.subplot(rows,cols,i+1),plt.imshow(image,cmap) + plt.title(title) + plt.xticks([]),plt.yticks([]) + i += 1 + plt.show() + +def display_images(img_list, row, col): + if(len(img_list) > 0): + images = {} + n = 0 + for img in img_list: + n += 1 + images[str(n)] = img + plot_images(images, row, col, cmap='gray') \ No newline at end of file diff --git a/src/attention_ocr/utils/tokenizer.py b/src/attention_ocr/utils/tokenizer.py new file mode 100644 index 0000000..9230570 --- /dev/null +++ b/src/attention_ocr/utils/tokenizer.py @@ -0,0 +1,46 @@ +import numpy as np + + +class Tokenizer: + def __init__(self, tokens): + self.tokens = tokens + self.SOS_token = 0 + self.EOS_token = 1 + self.UNKNOWN_token = len(tokens) + 2 + self.n_token = len(tokens) + 3 + + char_idx = {} + + for i, c in enumerate(tokens): + char_idx[c] = i + 2 + + self.char_idx = char_idx + + def tokenize(self, s): + label = np.zeros((len(s) + 1,), dtype=np.long) + label[0] = self.SOS_token + + for i, c in enumerate(s): + label[i + 1] = self.char_idx.get(c, self.UNKNOWN_token) + + return label + + def translate(self, ts, n=None): + ret = [] + + if n is None: + n = len(ts) + + for i in range(n): + t = ts[i] + + if t == self.SOS_token: + pass + elif t == self.EOS_token: + ret.append('-') + elif t == self.UNKNOWN_token: + ret.append('?') + else: + ret.append(self.tokens[t - 2]) + + return ''.join(ret) \ No newline at end of file diff --git a/src/attention_ocr/utils/train_util.py b/src/attention_ocr/utils/train_util.py new file mode 100644 index 0000000..a524fa2 --- /dev/null +++ b/src/attention_ocr/utils/train_util.py @@ -0,0 +1,76 @@ +import torch + + +def train_batch(input_tensor, target_tensor, model, 
optimizer, + criterion, teacher_forcing_ratio, max_len, + tokenizer): + model.train() + + decoder_output = model(input_tensor, target_tensor, teacher_forcing_ratio) + + loss = 0 + + optimizer.zero_grad() + + for i in range(decoder_output.size(1)): + loss += criterion(decoder_output[:, i, :].squeeze(), target_tensor[:, i + 1]) + + loss.backward() + optimizer.step() + + target_tensor = target_tensor.cpu() + decoder_output = decoder_output.cpu() + + prediction = torch.zeros_like(target_tensor) + prediction[:, 0] = tokenizer.SOS_token + for i in range(decoder_output.size(1)): + prediction[:, i + 1] = decoder_output[:, i, :].squeeze().argmax(1) + + n_right = 0 + n_right_sentence = 0 + + for i in range(prediction.size(0)): + eq = prediction[i, 1:] == target_tensor[i, 1:] + n_right += eq.sum().item() + n_right_sentence += eq.all().item() + + return loss.item() / len(decoder_output), \ + n_right / prediction.size(0) / prediction.size(1), \ + n_right_sentence / prediction.size(0) + + +def predict_batch(input_tensor, model): + model.eval() + decoder_output = model(input_tensor) + + return decoder_output + + +def eval_batch(input_tensor, target_tensor, model, criterion, max_len, tokenizer): + loss = 0 + + decoder_output = predict_batch(input_tensor, model) + + for i in range(decoder_output.size(1)): + loss += criterion(decoder_output[:, i, :].squeeze(), target_tensor[:, i + 1]) + + target_tensor = target_tensor.cpu() + decoder_output = decoder_output.cpu() + + prediction = torch.zeros_like(target_tensor) + prediction[:, 0] = tokenizer.SOS_token + + for i in range(decoder_output.size(1)): + prediction[:, i + 1] = decoder_output[:, i, :].squeeze().argmax(1) + + n_right = 0 + n_right_sentence = 0 + + for i in range(prediction.size(0)): + eq = prediction[i, 1:] == target_tensor[i, 1:] + n_right += eq.sum().item() + n_right_sentence += eq.all().item() + + return loss.item() / len(decoder_output), \ + n_right / prediction.size(0) / prediction.size(1), \ + n_right_sentence / prediction.size(0) \ No newline at end of file diff --git a/src/image_processing.py b/src/image_processing.py deleted file mode 100644 index 7eeb4d3..0000000 --- a/src/image_processing.py +++ /dev/null @@ -1,235 +0,0 @@ -import cv2 as cv -import numpy as np -import logging -import matplotlib.pyplot as plt -from skimage.color import rgb2gray -from skimage.morphology import erosion, dilation, opening, closing, black_tophat -from skimage.morphology import reconstruction -from skimage.morphology import disk, square, rectangle -from skimage.filters import threshold_li, threshold_mean, threshold_multiotsu, threshold_niblack, threshold_yen, threshold_otsu, threshold_local, rank -from skimage import exposure -from skimage import util -from skimage import data -from imutils import contours - -def equalize_histogram(image): - images = {} - img_gray = image - - if(len(img_gray.shape) > 2): - img_gray = rgb2gray(image) - images['grayscale'] = img_gray - - global_eq = exposure.equalize_hist(img_gray) - images['global'] = global_eq - - local_eq = rank.equalize(img_gray, selem=disk(30)) - images['local'] = local_eq - - plot_images(images, 2, 3, cmap='gray') - - return local_eq - -def opening_by_reconstruction(image, se): - eroded = erosion(image, se) - reconstructed = reconstruction(eroded, image) - return reconstructed - -def closing_by_reconstruction(image, se, iterations=1): - obr = opening_by_reconstruction(image, se) - - obr_inverted = util.invert(obr) - obr_inverted_eroded = erosion(obr_inverted, se) - obr_inverted_eroded_rec = 
reconstruction(obr_inverted_eroded, obr_inverted) - obr_inverted_eroded_rec_inverted = util.invert(obr_inverted_eroded_rec) - return obr_inverted_eroded_rec_inverted - -def square_resize(img): - """ - This function resize non square image to square one (height == width) - :param img: input image as numpy array - :return: numpy array - """ - # image after making height equal to width - squared_image = img - # Get image height and width - h = img.shape[0] - w = img.shape[1] - - # In case height superior than width - if h > w: - diff = h-w - if diff % 2 == 0: - x1 = np.zeros(shape=(h, diff//2)) - x2 = x1 - else: - x1 = np.zeros(shape=(h, diff//2)) - x2 = np.zeros(shape=(h, (diff//2)+1)) - squared_image = np.concatenate((x1, img, x2), axis=1) - - # In case height inferior than width - if h < w: - diff = w-h - if diff % 2 == 0: - x1 = np.zeros(shape=(diff//2, w)) - x2 = x1 - else: - x1 = np.zeros(shape=(diff//2, w)) - x2 = np.zeros(shape=((diff//2)+1, w)) - squared_image = np.concatenate((x1, img, x2), axis=0) - - return squared_image - -def plot_images(data, rows, cols, cmap='gray'): - if(len(data) > 0): - i = 0 - for title, image in data.items(): - #logging.debug(title) - plt.subplot(rows,cols,i+1),plt.imshow(image,cmap) - plt.title(title) - plt.xticks([]),plt.yticks([]) - i += 1 - plt.show() - -def display_images(img_list, row, col): - if(len(img_list) > 0): - images = {} - n = 0 - for img in img_list: - n += 1 - images[str(n)] = img - plot_images(images, row, col, cmap='gray') - -def draw_bounding_box(image, text_label, startPoint_x, startPoint_y, endPoint_x, endPoint_y, color=(0, 255, 0), thickness=2): - # draw rectangle - cv.rectangle(image, (startPoint_x, startPoint_y), (endPoint_x, endPoint_y), color, thickness) - - # draw the label at the top of the bounding box - labelSize, baseLine = cv.getTextSize(text_label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1) - top = max(startPoint_y, labelSize[1]) - cv.rectangle(image, (startPoint_x, top - round(1.5*labelSize[1])), (startPoint_x + round(1.5*labelSize[0]), top + baseLine), (0, 0, 255), cv.FILLED) - cv.putText(image, text_label, (startPoint_x, top), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0,0,0), 2) - -def cv_skeletonize(img): - """ - Steps: - 1 - Starting off with an empty skeleton. - 2 - Computing the opening of the original image. Let’s call this open. - 3 - Substracting open from the original image. Let’s call this temp. - 4 - Eroding the original image and refining the skeleton by computing the union of the current skeleton and temp. - 5 - Repeat Steps 2–4 till the original image is completely eroded. - """ - element = cv.getStructuringElement(cv.MORPH_CROSS, (3,3)) - # Step 1: Create an empty skeleton - skel = np.zeros(img.shape, np.uint8) - while True: - #Step 2: Open the image - open_img = cv.morphologyEx(img, cv.MORPH_OPEN, element) - #Step 3: Substract open from the original image - temp = cv.subtract(img, open_img) - #Step 4: Erode the original image and refine the skeleton - eroded = cv.erode(img, element) - skel = cv.bitwise_or(skel, temp) - img = eroded.copy() - # Step 5: If there are no white pixels left ie.. 
the image has been completely eroded, quit the loop - if cv.countNonZero(img)==0: - break - - return skel - -def extract_contours(image, min_contours_area_ratio=0.02, max_contours_area_ratio=0.2): - mask = np.zeros(image.shape, dtype=np.uint8) - cnts = cv.findContours(image, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE) - cnts = cnts[0] if len(cnts) == 2 else cnts[1] - (cnts, _) = contours.sort_contours(cnts, method="left-to-right") - logging.debug(f'Found {len(cnts)} contours!') - roi_index = 0 - total_area = image.shape[0] * image.shape[1] - logging.debug(f'Total area ({image.shape[0]}, {image.shape[1]}): {total_area}') - contours_used_for_masking = 0 - rois = [] - for c in cnts: - x,y,w,h = cv.boundingRect(c) - roi_area = w * h - roi_area_ratio = roi_area / total_area - logging.debug(f'ROI {roi_index} area: {roi_area} - ratio: {roi_area_ratio}') - - if roi_area_ratio >= min_contours_area_ratio and roi_area_ratio <= max_contours_area_ratio: - contours_used_for_masking += 1 - roi = image[y:y+h, x:x+w].copy() - - mask[y:y+h, x:x+w] = 255 - - aux_roi = np.array(roi) - aux_roi = cv.resize(aux_roi,(28,28), interpolation = cv.INTER_AREA) - aux_roi[aux_roi != 0] = 255 - rois.append(aux_roi) - - roi_index += 1 - - logging.debug(f'Contours used for masking: {contours_used_for_masking}') - - return rois, mask - -def generate_pre_marker_image(image): - img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY) - - structuringElement = cv.getStructuringElement(cv.MORPH_RECT, (5, 5)) - pre_marker_img = cv.morphologyEx(img_gray, cv.MORPH_BLACKHAT, structuringElement) - - ret, pre_marker_img = cv.threshold(pre_marker_img, 0, 255, cv.THRESH_BINARY+cv.THRESH_OTSU) - - return pre_marker_img - -def skeleton_marker_based_watershed_segmentation(image): - pre_marker_img = generate_pre_marker_image(image) - - skeleton = cv_skeletonize(pre_marker_img) - ret, markers = cv.connectedComponents(skeleton) - watershed_result = cv.watershed(image, markers) - - watershed_result[watershed_result == -1] = 255 - watershed_result[watershed_result != 255] = 0 - watershed_result = np.uint8(watershed_result) - - return watershed_result - -def intersection_lines_marker_based_watershed_segmentation(image): - pre_marker_img = generate_pre_marker_image(image) - - intersection_line_img = np.zeros(pre_marker_img.shape, np.uint8) - height, width = pre_marker_img.shape - - cv.line(intersection_line_img, pt1=(0, int(height/2)), pt2=(width, int(height/2)), color=(255), thickness=5) - cv.line(intersection_line_img, pt1=(0, int(height/2+height/4)), pt2=(width, int(height/2+height/4)), color=(255), thickness=5) - intersection_img = cv.bitwise_and(intersection_line_img, pre_marker_img) - - ret, markers = cv.connectedComponents(intersection_img) - watershed_result = cv.watershed(image, markers) - - watershed_result[watershed_result == -1] = 255 - watershed_result[watershed_result != 255] = 0 - watershed_result = np.uint8(watershed_result) - - return watershed_result - -def extract_chars(image): - watershed_result = skeleton_marker_based_watershed_segmentation(image) - #watershed_result = intersection_lines_marker_based_watershed_segmentation(image) - - _, mask = extract_contours(image=watershed_result, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - - img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY) - ret, thresh = cv.threshold(img_gray, 0, 255, cv.THRESH_BINARY_INV+cv.THRESH_OTSU) - output_img = thresh.copy() - - thresh[mask == 0] = 0 - - # we can run extract_contours again but this time on the threshold masked to get the char contours more 
accurate -    char_contours, refined_mask = extract_contours(image=thresh, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) -     -    output_img[refined_mask == 0] = 0 -    # now make the image properly for tesseract (white background) -    output_img = util.invert(output_img) -     -    return char_contours, output_img, refined_mask \ No newline at end of file diff --git a/src/ocr.py b/src/ocr.py index 6e3da38..1529595 100644 --- a/src/ocr.py +++ b/src/ocr.py @@ -2,50 +2,49 @@  import logging   import numpy as np  -from EMNISTNet.models import EMNISTNet +from torchvision import transforms + +from attention_ocr.utils.tokenizer import Tokenizer +from attention_ocr.model.attention_ocr import AttentionOCR   class OCR():     """     Optical Character Recognition (OCR) aims to recognize characters from images.     """ -    def __init__(self, model_filename="config/emnist_model.pt", num_classes=47, use_cuda=False, debug=False): +    def __init__(self, model_filename="config/attention_ocr_model.pth", use_cuda=False, n_chars=7, threshold=0.7): +        self.img_width = 160 +        self.img_height = 60 +        self.nh = 512 + +        self.chars = list('1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')  -        if(debug): -            logging.getLogger().setLevel(logging.DEBUG) -        else: -            logging.getLogger().setLevel(logging.INFO) +        self.img_trans = transforms.Compose([ +            transforms.ToPILImage() +            ,transforms.Resize((self.img_height, self.img_width)) +            ,transforms.Grayscale(num_output_channels=3) +            ,transforms.ToTensor() +            ,transforms.Normalize(mean=[0.5, 0.5, 0.5], std=(0.5, 0.5, 0.5)) +        ])  -        self._debug = debug -        # indexed classes labels -        self._groundtruth = ['0','1','2','3','4','5','6','7','8','9', # 10 classes (MNIST) -                    'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z', # 36 classes (custom dataset) -                    'a','b','d','e','f','g','h','n','q','r','t'] # 47 classes (EMNIST bymerge) +        self.device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu') +        logging.info(f'Using {self.device} device.')  -        self._device = torch.device("cuda:0" if torch.cuda.is_available() and use_cuda else "cpu") -        logging.info(f'Using {self._device} device for predictions.') +        self.tokenizer = Tokenizer(self.chars) +        self.model = AttentionOCR(self.img_width, self.img_height, self.nh, self.tokenizer.n_token, +                            n_chars + 1, self.tokenizer.SOS_token, self.tokenizer.EOS_token).to(device=self.device)  -        self._model = EMNISTNet(num_classes=num_classes) -        self._model.load_state_dict(torch.load(model_filename, map_location=self._device)) +        self.model.load_state_dict(torch.load(model_filename, map_location=self.device)) +        self.model.eval()  -    def predict(self, inputs): +    def predict(self, input_img):         """ -        inputs: list of 2d numpy array containing N images where pixels lies beetwen 0 and 255 +        input_img: 3 channels (h,w,c) rgb image         """ -        inputs = [img / 255 for img in inputs] # normalize -         -        t = torch.tensor(inputs, dtype=torch.float32) -        t.unsqueeze_(0) -        t = t.permute(1,0,2,3) -        logging.debug(f'Tensor for prediction: {t.shape}') -         -        t.to(self._device) -        preds = self._model(t) -        preds = preds.argmax(dim=1) -        logging.debug(f'Preds shape: {preds.shape}') -         -        preds_indexes = list(preds.numpy().astype(int)) -        preds_classes = [self._groundtruth[idx] for idx in preds_indexes] -        pred = '' -        return pred.join([str(s) for s in preds_classes]) -     \ No newline at end of file +        t = self.img_trans(input_img) +        with torch.no_grad(): +            pred = self.model(t.unsqueeze(0).to(self.device))  # move the input to the same device as the model + +        result = self.tokenizer.translate(pred.squeeze(0).argmax(1)) +        return 
result + \ No newline at end of file diff --git a/src/test_image_processing.py b/src/test_image_processing.py deleted file mode 100644 index 13c1cdc..0000000 --- a/src/test_image_processing.py +++ /dev/null @@ -1,299 +0,0 @@ -import cv2 as cv -import numpy as np -import argparse -import sys -import os.path -import logging -import matplotlib.pyplot as plt -from skimage.color import rgb2gray -from skimage.morphology import erosion, dilation, opening, closing, black_tophat, white_tophat -from skimage.morphology import reconstruction -from skimage.morphology import disk, square, rectangle -from skimage.filters import threshold_li, threshold_mean, threshold_multiotsu, threshold_niblack, threshold_yen, threshold_otsu, threshold_local, threshold_sauvola, threshold_niblack, rank -from skimage import exposure -from skimage import util -from skimage import data -from skimage import feature -from imutils import contours -from image_processing import opening_by_reconstruction, closing_by_reconstruction, display_images, plot_images, cv_skeletonize, extract_contours, skeleton_marker_based_watershed_segmentation, intersection_lines_marker_based_watershed_segmentation - - -def test_threshold_methods(img): - images = {} - img_gray = rgb2gray(img) - images['grayscale'] = img_gray - images['threshold_local_5'] = threshold_local(img_gray, block_size=5) - images['threshold_local_11'] = threshold_local(img_gray, block_size=11) - th = threshold_multiotsu(img_gray) - images['threshold_multiotsu'] = np.digitize(img_gray, bins=th) - th = threshold_otsu(img_gray) - images['threshold_otsu'] = img_gray >= th - th = threshold_li(img_gray) - images['threshold_li'] = img_gray >= th - th = threshold_yen(img_gray) - images['threshold_yen'] = img_gray >= th - th = threshold_mean(img_gray) - images['threshold_mean'] = img_gray > th - th = threshold_niblack(img_gray, window_size=25, k=0.8) - images['thresh_niblack'] = img_gray > th - th = threshold_sauvola(img_gray, window_size=25) - images['threshold_sauvola'] = img_gray > th - - plot_images(images, 4, 4, cmap='gray') - -def test_morphological_methods(image): - images = {} - img_gray = rgb2gray(image) - images['grayscale'] = img_gray - images['grayscale 1'] = img_gray - images['grayscale 2'] = img_gray - images['grayscale 3'] = img_gray - images['grayscale 4'] = img_gray - # openings - images['opening 3x3'] = opening(img_gray, square(3)) - images['opening 5x5'] = opening(img_gray, square(5)) - images['opening 7x7'] = opening(img_gray, square(7)) - images['opening 9x9'] = opening(img_gray, square(9)) - images['opening 11x11'] = opening(img_gray, square(11)) - # closings - images['closing 3x3'] = closing(img_gray, square(3)) - images['closing 5x5'] = closing(img_gray, square(5)) - images['closing 7x7'] = closing(img_gray, square(7)) - images['closing 9x9'] = closing(img_gray, square(9)) - images['closing 11x11'] = closing(img_gray, square(11)) - # openings by reconstruction - images['opening_by_reconstruction 3x3'] = opening_by_reconstruction(img_gray, square(3)) - images['opening_by_reconstruction 5x5'] = opening_by_reconstruction(img_gray, square(5)) - images['opening_by_reconstruction 7x7'] = opening_by_reconstruction(img_gray, square(7)) - images['opening_by_reconstruction 9x9'] = opening_by_reconstruction(img_gray, square(9)) - images['opening_by_reconstruction 11x11'] = opening_by_reconstruction(img_gray, square(11)) - # closings by reconstruction - images['closing_by_reconstruction 3x3'] = closing_by_reconstruction(img_gray, square(3)) - 
images['closing_by_reconstruction 5x5'] = closing_by_reconstruction(img_gray, square(5)) - images['closing_by_reconstruction 7x7'] = closing_by_reconstruction(img_gray, square(7)) - images['closing_by_reconstruction 9x9'] = closing_by_reconstruction(img_gray, square(9)) - images['closing_by_reconstruction 11x11'] = closing_by_reconstruction(img_gray, square(11)) - # erosion - images['erosion 3x3'] = erosion(img_gray, square(3)) - images['erosion 5x5'] = erosion(img_gray, square(5)) - images['erosion 7x7'] = erosion(img_gray, square(7)) - images['erosion 9x9'] = erosion(img_gray, square(9)) - images['erosion 11x11'] = erosion(img_gray, square(11)) - # dilation - images['dilation 3x3'] = dilation(img_gray, square(3)) - images['dilation 5x5'] = dilation(img_gray, square(5)) - images['dilation 7x7'] = dilation(img_gray, square(7)) - images['dilation 9x9'] = dilation(img_gray, square(9)) - images['dilation 11x11'] = dilation(img_gray, square(11)) - # black tophat - images['black_tophat 3x3'] = black_tophat(img_gray, square(3)) - images['black_tophat 5x5'] = black_tophat(img_gray, square(5)) - images['black_tophat 7x7'] = black_tophat(img_gray, square(7)) - images['black_tophat 9x9'] = black_tophat(img_gray, square(9)) - images['black_tophat 11x11'] = black_tophat(img_gray, square(11)) - # white tophat - images['white_tophat 3x3'] = white_tophat(img_gray, square(3)) - images['white_tophat 5x5'] = white_tophat(img_gray, square(5)) - images['white_tophat 7x7'] = white_tophat(img_gray, square(7)) - images['white_tophat 9x9'] = white_tophat(img_gray, square(9)) - images['white_tophat 11x11'] = white_tophat(img_gray, square(11)) - plot_images(images, 9, 5, cmap="gray") - -def test_histogram_equalization_methods(image): - images = {} - - #img_gray = rgb2gray(image) - img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY) - images['grayscale'] = img_gray - - global_eq = exposure.equalize_hist(img_gray) - images['global'] = global_eq - - selem = disk(30) - local_eq = rank.equalize(img_gray, selem=selem) - images['local'] = local_eq - - # Adaptive Equalization - img_adapteq = exposure.equalize_adapthist(img_gray, clip_limit=0.03) - images['adaptative'] = img_adapteq - - clahe = cv.createCLAHE(clipLimit=4.0, tileGridSize=(8,8)) - clahe_img = clahe.apply(img_gray) - images['clahe'] = clahe_img - - plot_images(images, 3, 3, cmap='gray') - -def test_skeleton_marker_based_watershed_segmentation(image): - images = {} - img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY) - images['gray'] = img_gray.copy() - ret, thresh = cv.threshold(img_gray, 0, 255, cv.THRESH_BINARY_INV+cv.THRESH_OTSU) - images['threshold inv'] = thresh.copy() - output_img = thresh.copy() - - structuringElement = cv.getStructuringElement(cv.MORPH_RECT, (5, 5)) - pre_marker_img = cv.morphologyEx(img_gray, cv.MORPH_BLACKHAT, structuringElement) - images['black hat'] = pre_marker_img.copy() - - ret, pre_marker_img = cv.threshold(pre_marker_img, 0, 255, cv.THRESH_BINARY+cv.THRESH_OTSU) - images['pre marker threshold'] = pre_marker_img.copy() - - skeleton = cv_skeletonize(pre_marker_img) - images['skeleton'] = skeleton.copy() - ret, markers = cv.connectedComponents(skeleton) - images['markers'] = markers.copy() - watershed_result = cv.watershed(image, markers) - images['watershed_result'] = watershed_result.copy() - - watershed_result[watershed_result == -1] = 255 - watershed_result[watershed_result != 255] = 0 - watershed_result = np.uint8(watershed_result) - images['final watershed_result'] = watershed_result.copy() - - _, mask = 
extract_contours(image=watershed_result, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - images['mask'] = mask.copy() - - thresh[mask == 0] = 0 - images['thresh masked'] = thresh.copy() - - # we can run extract_contours again but this time on the threshold masked to get the char contours more accurate - char_contours, refined_mask = extract_contours(image=thresh, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - images['refined_mask'] = refined_mask.copy() - - output_img[refined_mask == 0] = 0 - images['thresh refine masked'] = output_img.copy() - output_img = util.invert(output_img) - images['final result'] = output_img.copy() - - plot_images(images, 4, 4, cmap='gray') - -def test_intersection_lines_marker_based_watershed_segmentation(image): - images = {} - img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY) - images['gray'] = img_gray.copy() - ret, thresh = cv.threshold(img_gray, 0, 255, cv.THRESH_BINARY_INV+cv.THRESH_OTSU) - images['threshold inv'] = thresh.copy() - output_img = thresh.copy() - - structuringElement = cv.getStructuringElement(cv.MORPH_RECT, (5, 5)) - pre_marker_img = cv.morphologyEx(img_gray, cv.MORPH_BLACKHAT, structuringElement) - images['black hat'] = pre_marker_img.copy() - - ret, pre_marker_img = cv.threshold(pre_marker_img, 0, 255, cv.THRESH_BINARY+cv.THRESH_OTSU) - images['pre marker threshold'] = pre_marker_img.copy() - - intersection_line_img = np.zeros(pre_marker_img.shape, np.uint8) - height, width = pre_marker_img.shape - - cv.line(intersection_line_img, pt1=(0, int(height/2)), pt2=(width, int(height/2)), color=(255), thickness=5) - cv.line(intersection_line_img, pt1=(0, int(height/2+height/4)), pt2=(width, int(height/2+height/4)), color=(255), thickness=5) - images['lines'] = intersection_line_img.copy() - intersection_img = cv.bitwise_and(intersection_line_img, pre_marker_img) - images['lines and pre marker intersection'] = intersection_img.copy() - - ret, markers = cv.connectedComponents(intersection_img) - images['markers'] = markers.copy() - watershed_result = cv.watershed(image, markers) - images['watershed_result'] = watershed_result.copy() - - watershed_result[watershed_result == -1] = 255 - watershed_result[watershed_result != 255] = 0 - watershed_result = np.uint8(watershed_result) - images['final watershed_result'] = watershed_result.copy() - - _, mask = extract_contours(image=watershed_result, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - images['mask'] = mask.copy() - - thresh[mask == 0] = 0 - images['thresh masked'] = thresh.copy() - - # we can run extract_contours again but this time on the threshold masked to get the char contours more accurate - char_contours, refined_mask = extract_contours(image=thresh, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - images['refined_mask'] = refined_mask.copy() - - output_img[refined_mask == 0] = 0 - images['thresh refine masked'] = output_img.copy() - output_img = util.invert(output_img) - images['final result'] = output_img.copy() - - plot_images(images, 4, 4, cmap='gray') - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Testing Image Processing Algorithms.') - parser.add_argument('--image', help='Path to image file.') - args = parser.parse_args() - - logging.getLogger().setLevel(logging.DEBUG) - - # Open the image file - if not os.path.isfile(args.image): - logging.debug("Input image file ", args.image, " doesn't exist") - sys.exit(1) - cap = cv.VideoCapture(args.image) - - hasFrame, frame = cap.read() - - if hasFrame: - 
#test_threshold_methods(image) - #test_histogram_equalization_methods(image) - #test_morphological_methods(frame) - test_skeleton_marker_based_watershed_segmentation(frame) - test_intersection_lines_marker_based_watershed_segmentation(frame) - - - # comparing both methods - images = {} - image = frame - img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY) - images['skel gray'] = img_gray.copy() - ret, thresh = cv.threshold(img_gray, 0, 255, cv.THRESH_BINARY_INV+cv.THRESH_OTSU) - images['skel threshold'] = thresh.copy() - output_img = thresh.copy() - - # skeleton method - watershed_result = skeleton_marker_based_watershed_segmentation(image) - images['skel watershed'] = watershed_result.copy() - - char_contours, mask = extract_contours(image=watershed_result, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - thresh[mask == 0] = 0 - images['skel mask 1'] = mask.copy() - images['skel threshold masked 1'] = thresh.copy() - - # we can run extract_contours again but this time on the threshold masked to get the char contours more accurate - char_contours, mask2 = extract_contours(image=thresh, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - images['skel mask 2'] = mask2.copy() - output_img[mask2 == 0] = 0 - images['skel threshold masked 2'] = output_img.copy() - - output_img = util.invert(output_img) - images['skel output'] = output_img.copy() - - # intersection lines method - img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY) - images['intersec gray'] = img_gray.copy() - ret, thresh = cv.threshold(img_gray, 0, 255, cv.THRESH_BINARY_INV+cv.THRESH_OTSU) - images['intersec threshold'] = thresh.copy() - output_img = thresh.copy() - - watershed_result = intersection_lines_marker_based_watershed_segmentation(image) - images['intersec watershed'] = watershed_result.copy() - - char_contours, mask = extract_contours(image=watershed_result, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - thresh[mask == 0] = 0 - images['intersec mask 1'] = mask.copy() - images['intersec threshold masked 1'] = thresh.copy() - - # we can run extract_contours again but this time on the threshold masked to get the char contours more accurate - char_contours, mask2 = extract_contours(image=thresh, min_contours_area_ratio=0.01, max_contours_area_ratio=0.2) - images['intersec mask 2'] = mask2.copy() - output_img[mask2 == 0] = 0 - images['intersec threshold masked 2'] = output_img.copy() - - output_img = util.invert(output_img) - images['intersec output'] = output_img.copy() - - plot_images(images, 2, 8, cmap='gray') - - - else: - logging.debug("Frame not found!") \ No newline at end of file diff --git a/src/test_ocr.py b/src/test_ocr.py index 9b21d09..b738abb 100644 --- a/src/test_ocr.py +++ b/src/test_ocr.py @@ -6,33 +6,52 @@ import logging import matplotlib.pyplot as plt -from image_preprocessing import display_images, extract_chars from ocr import OCR +def plot_images(data, rows, cols, cmap='gray'): + if(len(data) > 0): + i = 0 + for title, image in data.items(): + #logging.debug(title) + plt.subplot(rows,cols,i+1),plt.imshow(image,cmap) + plt.title(title) + plt.xticks([]),plt.yticks([]) + i += 1 + plt.show() + +def display_images(img_list, row, col): + if(len(img_list) > 0): + images = {} + n = 0 + for img in img_list: + n += 1 + images[str(n)] = img + plot_images(images, row, col, cmap='gray') + if __name__ == '__main__': parser = argparse.ArgumentParser(description='Testing OCR.') parser.add_argument('--image', help='Path to image file.') args = parser.parse_args() - 
logging.getLogger().setLevel(logging.DEBUG) +    logging.getLogger().setLevel(logging.INFO)      # Open the image file     if not os.path.isfile(args.image): -        logging.debug("Input image file ", args.image, " doesn't exist") +        logging.error(f"Input image file {args.image} doesn't exist")         sys.exit(1)     cap = cv.VideoCapture(args.image)      hasFrame, frame = cap.read()      if hasFrame: -        img, characteres = extract_chars(frame, prefix_label='test_ocr', min_countours_area_ration=0.01, debug=True) +        images = {} +        images['frame'] = frame          -        #characteres = [img / 255 for img in characteres] -        #display_images(characteres, 3, 3) +        ocr = OCR(model_filename="../config/attention_ocr_model.pth", use_cuda=False, threshold=0.7) +        pred = ocr.predict(cv.cvtColor(frame, cv.COLOR_BGR2RGB))  # predict() expects an RGB image; OpenCV frames are BGR +        logging.info(f'Prediction: {pred}')  -        ocr = OCR(model_filename="../config/emnist_model.pt", use_cuda=False, debug=True) -        pred = ocr.predict(characteres) -        logging.info(f'\nPrediction: {pred}') +        plot_images(images, 1, 3, cmap='gray')      else:         logging.debug("Frame not found!") \ No newline at end of file
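As a quick sanity check of the new Attention-OCR path, the sketch below runs the `OCR` wrapper from `src/ocr.py` on a single pre-cropped plate image. It is only a sketch: it assumes the snippet is run from the `src` directory (so `from ocr import OCR` resolves), that a trained checkpoint exists at `../config/attention_ocr_model.pth`, and that `plate_crop.png` is a placeholder file name rather than something shipped with this patch.

````
# Hedged usage sketch for the new OCR wrapper.
# Assumptions: run from src/, checkpoint at ../config/attention_ocr_model.pth,
# and plate_crop.png is a placeholder for any pre-cropped license plate image.
import cv2 as cv

from ocr import OCR

plate_bgr = cv.imread('plate_crop.png')                # pre-cropped plate region
plate_rgb = cv.cvtColor(plate_bgr, cv.COLOR_BGR2RGB)   # predict() expects an (h, w, c) RGB array

ocr = OCR(model_filename='../config/attention_ocr_model.pth', use_cuda=False, n_chars=7)
print(ocr.predict(plate_rgb))  # prints the predicted plate string; '-' marks the tokenizer's EOS token
````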