CRNN PyTorch: Training and Testing


1. Repositories

https://github.com/meijieru/crnn.pytorch
The original Lua implementation: https://github.com/bgshih/crnn
The required warp_ctc_pytorch binding: https://github.com/SeanNaren/warp-ctc

2. Environment setup

Any common environment should work. I used CUDA 10.0, torch 1.2.0, and Python 3.6; other setups should be fine too.
Then install whatever library is missing with pip install ***

warp-CTC needs to be compiled:

git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding
python setup.py install

These steps completed without any errors for me.
To check that the installation succeeded, start python and run
import warpctc_pytorch
If that raises no error, the binding is installed correctly.
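
As a further sanity check, the snippet below (a minimal sketch adapted from the example in SeanNaren's warp-ctc README) runs one CTCLoss forward/backward pass on dummy data:

import torch
from warpctc_pytorch import CTCLoss

ctc_loss = CTCLoss()
# probs has shape (seq_len, batch, alphabet_size); class 0 is the CTC blank
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1],
                            [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
probs.requires_grad_(True)
labels = torch.IntTensor([1, 2])        # target sequence, flattened over the batch
label_sizes = torch.IntTensor([2])      # length of each target sequence
probs_sizes = torch.IntTensor([2])      # length of each input sequence
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
print(cost)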

3. Data preparation: building the LMDB


Arrange the data so that each image and its label file sit in the same folder: the text file has the same name as the image, and its content is the text shown in the image.
Then run the script https://github.com/wuzuowuyou/crnn_pytorch/blob/master/myfile/create_lmdb.py
Note that it must be run with Python 2. With Python 3 I hit all kinds of encoding errors, while with Python 2 it ran without a single error; Python 2 also needs lmdb installed (pip2 install lmdb). A rough Python 3 sketch of what such a script does is shown below.
If it runs successfully, it automatically generates these two files:
./lmdb/data.mdb
./lmdb/lock.mdb
Put the lmdb folder under the data directory.
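
For reference, here is a minimal Python 3 sketch of what such a script does, using the key convention that dataset.py expects (image-%09d, label-%09d, num-samples). The actual create_lmdb.py in the repo may differ in its details, and the folder layout below is an assumption based on the description above:

import os
import lmdb

def write_lmdb(img_dir, out_path):
    # Collect (image, label) pairs: each label .txt sits next to its image
    # with the same base name (assumed layout).
    samples = []
    for name in sorted(os.listdir(img_dir)):
        base, ext = os.path.splitext(name)
        if ext.lower() not in ('.jpg', '.jpeg', '.png'):
            continue
        with open(os.path.join(img_dir, base + '.txt'), 'r', encoding='utf-8') as f:
            label = f.read().strip()
        samples.append((os.path.join(img_dir, name), label))

    env = lmdb.open(out_path, map_size=1 << 30)  # 1 GB map size
    with env.begin(write=True) as txn:
        for i, (img_path, label) in enumerate(samples, start=1):
            with open(img_path, 'rb') as f:
                img_bin = f.read()  # raw image bytes; dataset.py decodes them at load time
            txn.put(('image-%09d' % i).encode(), img_bin)
            txn.put(('label-%09d' % i).encode(), label.encode())
        txn.put('num-samples'.encode(), str(len(samples)).encode())
    env.close()

write_lmdb('./train_data', './data/lmdb')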

4. Training

python train.py --adadelta --trainRoot ./data/lmdb/ --valRoot ./data/lmdb/ --cuda

Note that if your labels contain both upper and lower case, you need to extend the alphabet.
train.py line 32:
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
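
Keep in mind that the CTC blank adds one extra output class, so the network has to predict len(alphabet) + 1 classes; a quick check:

alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
nclass = len(alphabet) + 1   # +1 for the CTC blank
print(nclass)                # 63, matching out_features=63 in the model printout below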

5. Fixing the errors

There were all kinds of errors.
5.1 The case of trainRoot / valRoot needs to be fixed so that the argument name matches where it is used in train.py.
5.2 TypeError: Won't implicitly convert Unicode to bytes; use .encode()
Follow the error message and add .encode():
txn.get('num-samples'.encode())
label_byte = txn.get(label_key.encode())
imgbuf = txn.get(img_key.encode())
5.3 KeyError when encoding the label:
text, _ = self.encode(text)
File "/home/crnn.pytorch/utils.py", line 45, in encode
for char in text
File "/home/crnn.pytorch/utils.py", line 45, in
for char in text
KeyError: 'b'
The cause is that in Python 3, str() on a bytes object produces a string like "b'text'", so the encoder tries to look up the literal character 'b', which is not in the alphabet.
Fix: in dataset.py line 61, change
label = str(txn.get(label_key))
to
label_byte = txn.get(label_key.encode())
label = label_byte.decode()
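A quick illustration of why the stray 'b' appears:
>>> str(b'hello')       # Python 3: str() on bytes keeps the b'' wrapper
"b'hello'"
>>> b'hello'.decode()   # decode() returns the actual text
'hello'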

5.4 raise ValueError('sampler option is mutually exclusive with '
ValueError: sampler option is mutually exclusive with shuffle
In other words, the sampler and shuffle options of DataLoader are mutually exclusive.
I appended "and 0" so that the sampler is never used:
if not opt.random_sample and 0:
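
A cleaner alternative (a sketch, not the change I actually made; names follow the train.py in this repo and may differ slightly in your copy) is to enable shuffle only when no sampler is passed:

sampler = None
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=(sampler is None),   # shuffle and sampler must not both be set
    sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))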

5.5 Validation still raised an error:
Start val
Traceback (most recent call last):
File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 219, in <module>
val(crnn, test_dataset, criterion)
File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 168, in val
preds = preds.squeeze(2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
I simply skip validation by appending "and 0":
if i % opt.valInterval == 0 and 0:
    val(crnn, test_dataset, criterion)
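
If you want to keep validation: the squeeze(2) fails because in newer PyTorch versions max(2) no longer keeps the reduced dimension, so preds is already 2-D. A possible guard (a sketch I have not verified, adapted from the val() code in train.py) is:

_, preds = preds.max(2)
if preds.dim() == 3:           # older PyTorch kept the reduced dim after max(2)
    preds = preds.squeeze(2)
preds = preds.transpose(1, 0).contiguous().view(-1)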

With these errors resolved, training runs and prints the following:

  (relu6): ReLU(inplace=True)
  )
  (rnn): Sequential(
    (0): BidirectionalLSTM(
      (rnn): LSTM(512, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=256, bias=True)
    )
    (1): BidirectionalLSTM(
      (rnn): LSTM(256, 256, bidirectional=True)
      (embedding): Linear(in_features=512, out_features=63, bias=True)
    )
  )
)
[0/100000000][1/9] Loss: 8.430408
[0/100000000][2/9] Loss: 20.137066
[0/100000000][3/9] Loss: 25.239346
[0/100000000][4/9] Loss: 21.249365
[0/100000000][5/9] Loss: 20.604660
[0/100000000][6/9] Loss: 14.782236

6. Testing with demo.py

This line needs to be changed to match the settings used during training; the constructor arguments in models/crnn.py are (imgH, nc, nclass, nh), so the 37 here is the default 36-character alphabet plus the CTC blank:
model = crnn.CRNN(32, 1, 37, 256)

The error:
File "/data_2/project_2021/crnn/crnn.pytorch-master/demo_show.py", line 42, in <module>
model.load_state_dict(torch.load(model_path))
File "/data_1/Yang/software_install/Anaconda1105/envs/CenterNet_1.0_3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CRNN:
Missing key(s) in state_dict: "cnn.conv0.weight", "cnn.conv0.bias", "cnn.conv1.weight", "cnn.conv1.bias", "cnn.conv2.weight", "cnn.conv2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.bias", "cnn.batchnorm2.running_mean", "cnn.batchnorm2.running_var", "cnn.conv3.weight", "cnn.conv3.bias", "cnn.conv4.weight", "cnn.conv4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.bias", "cnn.batchnorm4.running_mean", "cnn.batchnorm4.running_var", "cnn.conv5.weight", "cnn.conv5.bias", "cnn.conv6.weight", "cnn.conv6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.bias", "cnn.batchnorm6.running_mean", "cnn.batchnorm6.running_var", "rnn.0.rnn.weight_ih_l0", "rnn.0.rnn.weight_hh_l0", "rnn.0.rnn.bias_ih_l0", "rnn.0.rnn.bias_hh_l0", "rnn.0.rnn.weight_ih_l0_reverse", "rnn.0.rnn.weight_hh_l0_reverse", "rnn.0.rnn.bias_ih_l0_reverse", "rnn.0.rnn.bias_hh_l0_reverse", "rnn.0.embedding.weight", "rnn.0.embedding.bias", "rnn.1.rnn.weight_ih_l0", "rnn.1.rnn.weight_hh_l0", "rnn.1.rnn.bias_ih_l0", "rnn.1.rnn.bias_hh_l0", "rnn.1.rnn.weight_ih_l0_reverse", "rnn.1.rnn.weight_hh_l0_reverse", "rnn.1.rnn.bias_ih_l0_reverse", "rnn.1.rnn.bias_hh_l0_reverse", "rnn.1.embedding.weight", "rnn.1.embedding.bias".
Unexpected key(s) in state_dict: "module.cnn.conv0.weight", "module.cnn.conv0.bias", "module.cnn.conv1.weight", "module.cnn.conv1.bias", "module.cnn.conv2.weight", "module.cnn.conv2.bias", "module.cnn.batchnorm2.weight", "module.cnn.batchnorm2.bias", "module.cnn.batchnorm2.running_mean", "module.cnn.batchnorm2.running_var", "module.cnn.batchnorm2.num_batches_tracked", "module.cnn.conv3.weight", "module.cnn.conv3.bias", "module.cnn.conv4.weight", "module.cnn.conv4.bias", "module.cnn.batchnorm4.weight", "module.cnn.batchnorm4.bias", "module.cnn.batchnorm4.running_mean", "module.cnn.batchnorm4.running_var", "module.cnn.batchnorm4.num_batches_tracked", "module.cnn.conv5.weight", "module.cnn.conv5.bias", "module.cnn.conv6.weight", "module.cnn.conv6.bias", "module.cnn.batchnorm6.weight", "module.cnn.batchnorm6.bias", "module.cnn.batchnorm6.running_mean", "module.cnn.batchnorm6.running_var", "module.cnn.batchnorm6.num_batches_tracked", "module.rnn.0.rnn.weight_ih_l0", "module.rnn.0.rnn.weight_hh_l0", "module.rnn.0.rnn.bias_ih_l0", "module.rnn.0.rnn.bias_hh_l0", "module.rnn.0.rnn.weight_ih_l0_reverse", "module.rnn.0.rnn.weight_hh_l0_reverse", "module.rnn.0.rnn.bias_ih_l0_reverse", "module.rnn.0.rnn.bias_hh_l0_reverse", "module.rnn.0.embedding.weight", "module.rnn.0.embedding.bias", "module.rnn.1.rnn.weight_ih_l0", "module.rnn.1.rnn.weight_hh_l0", "module.rnn.1.rnn.bias_ih_l0", "module.rnn.1.rnn.bias_hh_l0", "module.rnn.1.rnn.weight_ih_l0_reverse", "module.rnn.1.rnn.weight_hh_l0_reverse", "module.rnn.1.rnn.bias_ih_l0_reverse", "module.rnn.1.rnn.bias_hh_l0_reverse", "module.rnn.1.embedding.weight", "module.rnn.1.embedding.bias".

Process finished with exit code 1

The cause is that the key names in the saved .pth weights carry an extra "module." prefix (the model was wrapped in DataParallel during training); removing the prefix fixes it.
The loading code needs to be changed as follows:

import collections  # needed for OrderedDict

nclass = len(alphabet) + 1

model = crnn.CRNN(32, 1, nclass, 256)  # was: model = crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()

# for m in model.state_dict().keys():
#     print("==:: ", m)

load_model_ = torch.load(model_path)
# for k, v in load_model_.items():
#     print(k, "  ::shape", v.shape)

state_dict_rename = collections.OrderedDict()
for k, v in load_model_.items():
    name = k[7:]  # strip the leading "module."
    state_dict_rename[name] = v

print('loading pretrained model from %s' % model_path)
model.load_state_dict(state_dict_rename)
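
Alternatively (a sketch, not the change shown above), the prefix can be stripped inline with a dict comprehension, or the model can be wrapped in DataParallel so that the key names match the checkpoint:

# Option A: strip the "module." prefix inline
state = torch.load(model_path)
model.load_state_dict({k.replace('module.', '', 1): v for k, v in state.items()})

# Option B: wrap the model the same way it was wrapped during training
# model = torch.nn.DataParallel(model)
# model.load_state_dict(torch.load(model_path))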

After that, inference works.
I changed quite a lot, so I uploaded the modified code to GitHub for anyone who needs it; it also includes 10 test images with labels, enough to run the LMDB conversion end to end.
https://github.com/wuzuowuyou/crnn_pytorch

