Skip to content

crnn模型训练 ,训练集和测试集都要是lmdb格式吗 #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 42 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
2b73c5a
Update README.md
Aurora11111 Aug 10, 2018
7714a9d
Update README.md
Aurora11111 Aug 10, 2018
fcfbc41
update
Aurora11111 Aug 22, 2018
4e4ddcb
Update README.md
Aurora11111 Aug 28, 2018
41fc98a
Update README.md
Aurora11111 Aug 28, 2018
c3edbcb
Update README.md
Aurora11111 Aug 28, 2018
f8f9b00
Update README.md
Aurora11111 Aug 28, 2018
0c271da
add code to convert image data to lmdb data
Aurora11111 Sep 11, 2018
5a92043
Update getLmdb.py
Aurora11111 Sep 11, 2018
6cc617c
Update datasets.py
Aurora11111 Sep 11, 2018
b0426a3
Update utils.py
Aurora11111 Sep 11, 2018
49fae2d
Update crnn_main.py
Aurora11111 Sep 11, 2018
1b3616e
RUN IN PYTHON2.X
Aurora11111 Sep 11, 2018
c8053e5
Update README.md
Aurora11111 Sep 11, 2018
b2c1537
Updata READ.md
Aurora11111 Sep 14, 2018
d6f5493
Update README.md
Aurora11111 Sep 14, 2018
6392908
Update README.md
Aurora11111 Sep 14, 2018
d6c2c92
Update README.md
Aurora11111 Sep 14, 2018
ca448de
Update README.md
Aurora11111 Sep 14, 2018
57c3ccb
Update README.md
Aurora11111 Sep 14, 2018
2eea3cd
Update README.md
Aurora11111 Sep 14, 2018
fe0e49c
Update README.md
Aurora11111 Sep 14, 2018
63bdffd
Update README.md
Aurora11111 Sep 14, 2018
876e1e4
Update README.md
Aurora11111 Sep 14, 2018
2e8f6ea
Update README.md
Aurora11111 Sep 14, 2018
d1d7813
Update README.md
Aurora11111 Sep 14, 2018
720bdfd
Update README.md
Aurora11111 Sep 14, 2018
7f7cc55
Update crnn_main.py
Aurora11111 Sep 18, 2018
8b39c2f
Update demo.py
Aurora11111 Sep 18, 2018
c228de9
Update README.md
Aurora11111 Sep 19, 2018
82bc2df
Update README.md
Aurora11111 Sep 19, 2018
1db6083
Update crnn_main.py
Aurora11111 Oct 11, 2018
fa51bb0
Update utils.py
Aurora11111 Oct 11, 2018
18873d7
Update crnn_main.py
Aurora11111 Oct 11, 2018
16be663
Add files via upload
Aurora11111 Oct 11, 2018
2241606
new nclass is different from old model training
Aurora11111 Oct 11, 2018
ef9ab21
Update README.md
Aurora11111 Oct 11, 2018
342d859
Update README.md
Aurora11111 Oct 11, 2018
4485cd5
Update crnn_main.py
Aurora11111 Oct 17, 2018
d9ad769
Update README.md
Aurora11111 Oct 22, 2018
d249458
Add files via upload
Aurora11111 Oct 22, 2018
b240c76
Update README.md
Aurora11111 Nov 9, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 85 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,92 @@
Convolutional Recurrent Neural Network
CRNN TRAIN
======================================

This software implements the Convolutional Recurrent Neural Network (CRNN) in pytorch.
Origin software could be found in [crnn](https://github.com/bgshih/crnn)


Envrionment
--------
python 3.6
pytorch 4.0
opencv2.4 + pytorch + lmdb +wrap_ctc
* [warp_ctc_pytorch](https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding)

ATTENTION!

getLmdb.py must run in python2.x


Issue when install warp_ctc_pytorch
----------
* [ 11%] Building NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o
sh: cicc: command not found
CMake Error at warpctc_generated_reduce.cu.o.cmake:279 (message):
Error generating file
/home/rice/warp-ctc/build/CMakeFiles/warpctc.dir/src/./warpctc_generated_reduce.cu.o
make[2]: *** [CMakeFiles/warpctc.dir/build.make:256: CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o] Error 1
make[1]: *** [CMakeFiles/Makefile2:104: CMakeFiles/warpctc.dir/all] Error 2
make: *** [Makefile:130: all] Error 2 you should reinstall your cuda, and make sure it install completely
* THCudaMallco error https://github.com/baidu-research/warp-ctc/pull/71/files
* https://github.com/Xtra-Computing/thundersvm/issues/54#issuecomment-416413155
* ![my_error_image](./data/error.png)
ln -s /opt/cuda/include/* /home/rice/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include/THC/

Train a new model
-----------------
Construct dataset following origin guide. For training with variable length, please sort the image according to the text length. reference:https://github.com/Aurora11111/TextRecognitionDataGenerator

1. 数据预处理

运行`/contrib/crnn/tool/getLmdb.py`

# 生成的lmdb输出路径
outputPath = '/run/media/rice/DATA/chinese1/lmdb'
# 图片及对应的label
imgdata = open("/run/media/rice/DATA/chinese1/labels.txt")

2. 训练模型

运行`/contrib/crnn/crnn_main.py`

python crnn_main.py [--param val]
--trainroot 训练集路径
--valroot 验证集路径
--workers CPU工作核数, default=4
--batchSize 设置batchSize大小, default=64
--imgH 图片高度, default=32
--imgW 图片宽度,default =280(所用训练图片均为280*32)
--nh LSTM隐藏层数, default=256
--niter 训练回合数, default=25
--lr 学习率, default=0.00005
--cuda 使用GPU, action='store_true'
--ngpu 使用GPU的个数, default=1
--crnn 选择预训练模型
--alphabet 设置分类
--experiment 模型保存目录
--displayInterval 设置多少次迭代显示一次, default=1000
--n_test_disp 每次验证显示的个数, default=10
--valInterval 设置多少次迭代验证一次, default=1000
--saveInterval 设置多少次迭代保存一次模型, default=1000
--adam 使用adma优化器, default='True'
--adadelta 使用adadelta优化器, action='store_true'
--keep_ratio 设置图片保持横纵比缩放, action='store_true'
--random_sample 是否使用随机采样器对数据集进行采样, action='store_true'

示例:python /contrib/crnn/crnn_main.py --tainroot [训练集路径] --valroot [验证集路径] --nh 128 --cuda --crnn [预训练模型路径]

修改`/contrib/crnn/keys.py`中`alphabet = '012346789'`增加或者减少类别

3. 注意事项
训练和预测采用的类别数和LSTM隐藏层数需保持一致


Train a new model( new nclass is dfferent from old nclass)
-----------------
when you nclass is diferent from old ones, you can use this to finetune:
python finetune.py


Run demo
--------
A demo program can be found in ``src/demo.py``. Before running the demo, download a pretrained model
Expand All @@ -16,18 +99,8 @@ Put the downloaded model file ``crnn.pth`` into directory ``data/``. Then launch
The demo reads an example image and recognizes its text content.

Example image:
![Example Image](./data/demo.png)
![my_example_image](./data/demo.png)

Expected output:
loading pretrained model from ./data/crnn.pth
a-----v--a-i-l-a-bb-l-ee-- => available

Dependence
----------
* [warp_ctc_pytorch](https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding)
* lmdb

Train a new model
-----------------
1. Construct dataset following origin guide. For training with variable length, please sort the image according to the text length.
2. ``python crnn_main.py [--param val]``. Explore ``crnn_main.py`` for details.
135 changes: 92 additions & 43 deletions crnn_main.py

Large diffs are not rendered by default.

Binary file added data/error.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 6 additions & 5 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from PIL import Image
import numpy as np

import io

class lmdbDataset(Dataset):

Expand All @@ -29,7 +29,7 @@ def __init__(self, root=None, transform=None, target_transform=None):
sys.exit(0)

with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'))
nSamples = int(txn.get('num-samples'.encode()))
self.nSamples = nSamples

self.transform = transform
Expand All @@ -43,13 +43,14 @@ def __getitem__(self, index):
index += 1
with self.env.begin(write=False) as txn:
img_key = 'image-%09d' % index
imgbuf = txn.get(img_key)

imgbuf = txn.get(img_key.encode())
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)

try:
img = Image.open(buf).convert('L')
img = Image.open(io.BytesIO(buf.read())).convert('L')
except IOError:
print('Corrupted image for %d' % index)
return self[index + 1]
Expand All @@ -58,7 +59,7 @@ def __getitem__(self, index):
img = self.transform(img)

label_key = 'label-%09d' % index
label = str(txn.get(label_key))
label = str(txn.get(label_key.encode()).decode())

if self.target_transform is not None:
label = self.target_transform(label)
Expand Down
11 changes: 6 additions & 5 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import models.crnn as crnn


model_path = './data/crnn.pth'
img_path = './data/demo.png'
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'

model = crnn.CRNN(32, 1, 37, 256)
model_path = './expr/netCRNN_24_6200.pth'
img_path = './data/732.jpg'
#alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
alphabet = '0123456789'
model = crnn.CRNN(32, 1, 11, 256)
if torch.cuda.is_available():
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=range(1))
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))

Expand Down
295 changes: 295 additions & 0 deletions finetune.py

Large diffs are not rendered by default.

101 changes: 101 additions & 0 deletions getLmdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
import glob

def checkImageIsValid(imageBin):
if imageBin is None:
return False
imageBuf = np.fromstring(imageBin, dtype=np.uint8)
img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
if img is None:
return False
imgH, imgW = img.shape[0], img.shape[1]
if imgH * imgW == 0:
return False
return True


def writeCache(env, cache):
with env.begin(write=True) as txn:
for k, v in cache.iteritems():
txn.put(k, v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
"""
Create LMDB dataset for CRNN training.
ARGS:
outputPath : LMDB output path
imagePathList : list of image path
labelList : list of corresponding groundtruth texts
lexiconList : (optional) list of lexicon lists
checkValid : if true, check the validity of every image
"""
assert (len(imagePathList) == len(labelList))
nSamples = len(imagePathList)
print('...................')
# map_size=1099511627776 定义最大空间是1TB
env = lmdb.open(outputPath, map_size=1099511627776)

cache = {}
cnt = 1
for i in range(nSamples):
imagePath = imagePathList[i]
label = labelList[i]
if not os.path.exists(imagePath):
print('%s does not exist' % imagePath)
continue
with open(imagePath, 'r') as f:
imageBin = f.read()
if checkValid:
if not checkImageIsValid(imageBin):
print('%s is not a valid image' % imagePath)
continue

imageKey = 'image-%09d' % cnt
labelKey = 'label-%09d' % cnt
cache[imageKey] = imageBin
cache[labelKey] = label

if lexiconList:
lexiconKey = 'lexicon-%09d' % cnt
cache[lexiconKey] = ' '.join(lexiconList[i])
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
print('Written %d / %d' % (cnt, nSamples))
cnt += 1
nSamples = cnt - 1
cache['num-samples'] = str(nSamples)
writeCache(env, cache)
print('Created dataset with %d samples' % nSamples)


def read_text(path):
with open(path) as f:
text = f.read()
text = text.strip()

return text


if __name__ == '__main__':

outputPath = '/run/media/rice/DATA/lmdb'
imgdata = open("/run/media/rice/DATA/labellist1.txt")
imagePathList = []
imgLabelLists = []
for line in list(imgdata):
label = line.split()[1]
image = line.split()[0]
imgLabelLists.append(label)
imagePathList.append('/run/media/rice/DATA/datasets2/' + image+".jpg")

print len(imagePathList)
print len(imgLabelLists)
createDataset(outputPath, imagePathList, imgLabelLists, lexiconList=None, checkValid=True)


38 changes: 18 additions & 20 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import torch.nn as nn
from torch.autograd import Variable
import collections
import chardet
import numpy as np
import sys


class strLabelConverter(object):
"""Convert between str and label.

NOTE:
Insert `blank` to the alphabet for CTC.

Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
Expand All @@ -21,7 +22,7 @@ class strLabelConverter(object):
def __init__(self, alphabet, ignore_case=True):
self._ignore_case = ignore_case
if self._ignore_case:
alphabet = alphabet.lower()
alphabet = alphabet
self.alphabet = alphabet + '-' # for `-1` index

self.dict = {}
Expand All @@ -31,42 +32,38 @@ def __init__(self, alphabet, ignore_case=True):

def encode(self, text):
"""Support batch or single str.

Args:
text (str or list of str): texts to convert.

Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
"""
if isinstance(text, str):
text = [
self.dict[char.lower() if self._ignore_case else char]
for char in text
]
length = [len(text)]
elif isinstance(text, collections.Iterable):
length = [len(s) for s in text]
text = ''.join(text)
text, _ = self.encode(text)
length = []
result = []

for item in text:
item = item.decode('utf-8', 'strict')
length.append(len(item))
for char in item:
index = self.dict[char]
result.append(index)
text = result
return (torch.IntTensor(text), torch.IntTensor(length))

def decode(self, t, length, raw=False):
"""Decode encoded texts back into strs.

Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.

Raises:
AssertionError: when the texts and its length does not match.

Returns:
text (str or list of str): texts to convert.
"""
if length.numel() == 1:
length = length[0]
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
length)
if raw:
return ''.join([self.alphabet[i - 1] for i in t])
else:
Expand All @@ -77,7 +74,8 @@ def decode(self, t, length, raw=False):
return ''.join(char_list)
else:
# batch mode
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
t.numel(), length.sum())
texts = []
index = 0
for i in range(length.numel()):
Expand Down