1、The isolated ubuntu-python environment recommended in the article is really handy.
Reference: http://blog.topspeedsnail.com/archives/5618
Activate the virtual environment:
source env/bin/activate
Deactivate the virtual environment:
deactivate
Note: every step below must be run inside this isolated environment.
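Since everything below has to run inside the isolated environment, a small guard at the top of your scripts can catch mistakes early. A minimal sketch, relying on the fact that venv/virtualenv interpreters make sys.prefix differ from the base interpreter's prefix:

```python
import sys

def in_virtualenv():
    """True when running inside a venv/virtualenv interpreter."""
    # classic virtualenv sets sys.real_prefix; venv makes base_prefix differ
    base = getattr(sys, "real_prefix", None) or getattr(sys, "base_prefix", sys.prefix)
    return base != sys.prefix

print("inside virtualenv:", in_virtualenv())
```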
2、Set up the environment
pip install -r (run against the requirements file) should install all the packages the repo depends on, e.g.:
pip install Cython==0.26
sudo apt-get install python3-dev
editdistance==0.3.13 is also pinned there.
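For reference, the pins above use the plain "name==version" format that pip install -r consumes. A toy parser (my own sketch, not part of pip) shows the structure of such a file:

```python
# Sketch: parse "name==version" pins as found in a pip requirements file.
def parse_pins(lines):
    pins = {}
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop comments and blank lines
        if not line:
            continue
        name, _, version = line.partition("==")
        pins[name.strip()] = version.strip() or None  # None = unpinned
    return pins

pins = parse_pins(["Cython==0.26", "editdistance==0.3.13", "# a comment"])
print(pins)  # {'Cython': '0.26', 'editdistance': '0.3.13'}
```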
3、Build Baidu's warp-ctc
Reference: http://blog.csdn.net/amds123/article/details/73433926
git clone https://github.com/baidu-research/warp-ctc.git
cd warp-ctc
mkdir build
cd build
cmake ..
make
sudo make install
Then, as the stn-ocr article instructs, cd into `mxnet/metrics/ctc` and run `python setup.py build_ext --inplace`
4、Build MXNet:
git clone --recursive mxnet
cd mxnet
git tag
git checkout v0.9.3
Building the version given in the paper failed; I could only download and build a newer release.
Build steps for the newer release: https://www.bbsmax.com/A/A7zgqGk54n/
Install the dependencies:
$ sudo apt-get install -y build-essential git
$ sudo apt-get install -y libopenblas-dev
$ sudo apt-get install -y libopencv-dev
git clone --recursive https://github.com/dmlc/mxnet.git
cd mxnet
cp make/config.mk ./ (the build-options file)
vim config.mk (edit the options as needed; the article requires enabling warp-ctc, see
https://mxnet.incubator.apache.org/tutorials/speech_recognition/baidu_warp_ctc.html)
make -j4
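For reference, enabling warp-ctc in config.mk amounts to adding lines roughly like the following (per the linked MXNet tutorial; adjust WARPCTC_PATH to wherever you cloned and installed warp-ctc):

```makefile
# in config.mk -- path is an assumption, point it at your warp-ctc checkout
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/warpctc.mk
```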
5、Build the Python bindings
Reference: http://blog.csdn.net/zziahgf/article/details/72729883
Build MXNet's Python API:
Install the required packages:
sudo apt-get install -y python-dev python-setuptools python-numpy
cd python
sudo python setup.py install
6、Download the stn-ocr repo: https://github.com/Bartzi/stn-ocr
Using this network feels much like using FCN; no extra setup should be needed.
7、Download the model
https://bartzi.de/research/stn-ocr
Under "text recognition" there is a model folder and a test dataset.
The model folder contains two files:
*.params is the weights file, and *.json should be the network-description file.
The test dataset contains an image folder, a ground-truth (gt) file, and one more file whose purpose I don't know.
One more file is needed: matching the data layout in the stn-ocr repo, the "text" dataset should include a char_map file, which is required later.
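The char_map is a JSON dict from class index (as a string) to a unicode codepoint; codepoint 9250 (U+2422, the "blank symbol") serves as the CTC blank. A toy stand-in (the real file ships with the dataset) shows how the test script uses the map and its inverse:

```python
import json

# Toy stand-in for ctc_char_map.json: class id (string) -> codepoint.
# Id 0 is the CTC blank, codepoint 9250 (U+2422).
char_map = json.loads('{"0": 9250, "1": 97, "2": 98, "3": 99}')

# The eval script inverts it to map characters back to class ids.
reverse_char_map = {v: int(k) for k, v in char_map.items()}

word = "cab"
label = [reverse_char_map.get(ord(c), reverse_char_map[9250]) for c in word]
print(label)  # [3, 1, 2]

decoded = ''.join(chr(char_map[str(p)]) for p in label)
print(decoded)  # cab
```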
The prediction code is the eval py script under the stn-ocr directory; the name says it all. However, since I downloaded the newer MXNet, which differs from the version used in the article, that script did not run for me, so I wrote a simple test script modeled on it:

import csv
import json
from collections import namedtuple

from PIL import Image
import editdistance
import mxnet as mx
import numpy as np

from metrics.ctc_metrics import strip_prediction
from networks.text_rec import SVHNMultiLineCTCNetwork
from operations.disable_shearing import *
from utils.datatypes import Size

Batch = namedtuple('Batch', ['data'])

# No file suffix here -- load_checkpoint adds it itself and seems to load
# both files (.params and .json) at once.
sym, arg_params, aux_params = mx.model.load_checkpoint('./testxt/model/model', 2)
# arg_params holds the trained weights
# print(arg_params)

net, loc, transformed_output, size_params = SVHNMultiLineCTCNetwork.get_network(
    (1, 1, 64, 200), Size(50, 50), 46, 2, 23)
# outputs are pre-defined here -- what about the softmax layer, then?
output = mx.sym.Group([loc, transformed_output, net])

mod = mx.mod.Module(
    output,
    context=mx.cpu(),
    data_names=['data', 'softmax_label',
                'l0_forward_init_h_state', 'l0_forward_init_c_state_cell',
                'l1_forward_init_h_state', 'l1_forward_init_c_state_cell'],
    label_names=[])
mod.bind(
    for_training=False,
    grad_req='null',
    data_shapes=[
        ('data', (1, 1, 64, 200)),
        ('softmax_label', (1, 23)),
        ('l0_forward_init_h_state', (1, 1, 256)),
        ('l0_forward_init_c_state_cell', (1, 1, 256)),
        ('l1_forward_init_h_state', (1, 1, 256)),
        ('l1_forward_init_c_state_cell', (1, 1, 256)),
    ])
arg_params['l0_forward_init_h_state'] = mx.nd.zeros((1, 1, 256))
arg_params['l0_forward_init_c_state_cell'] = mx.nd.zeros((1, 1, 256))
arg_params['l1_forward_init_h_state'] = mx.nd.zeros((1, 1, 256))
arg_params['l1_forward_init_c_state_cell'] = mx.nd.zeros((1, 1, 256))
mod.set_params(arg_params, aux_params)

# A mapping file, similar to a label file in caffe; used in the loop below.
with open('/home/lbk/python-env/stn-ocr/mxnet/testxt/ctc_char_map.json') as char_map_file:
    char_map = json.load(char_map_file)
    reverse_char_map = {v: k for k, v in char_map.items()}
print(len(reverse_char_map))

with open('/home/lbk/python-env/stn-ocr/mxnet/testxt/icdar2013_eval/one_gt.txt') as eval_gt:
    reader = csv.reader(eval_gt, delimiter=';')
    for idx, line in enumerate(reader):
        file_name = line[0]
        label = line[1].strip()
        gt_word = label.lower()
        print(gt_word)
        # dict.get(key, default): map each character to its class id,
        # falling back to the blank class (codepoint 9250) if unknown,
        # then pad the label to length 23 with blanks.
        label = [reverse_char_map.get(ord(char.lower()), reverse_char_map[9250])
                 for char in gt_word]
        label += [reverse_char_map[9250]] * (23 - len(label))

        the_image = Image.open(file_name)
        the_image = the_image.convert('L')
        the_image = the_image.resize((200, 64), Image.ANTIALIAS)
        image = np.asarray(the_image, dtype=np.float32)[np.newaxis, np.newaxis, ...]
        image /= 255

        temp = mx.nd.zeros((1, 1, 256))
        label = mx.nd.array([label])
        image = mx.nd.array(image)
        input_batch = Batch(data=[image, label, temp, temp, temp, temp])
        mod.forward(input_batch, is_train=False)

        predictions = mod.get_outputs()[2].asnumpy()
        predicted_classes = np.argmax(predictions, axis=1)
        predicted_classes = strip_prediction(predicted_classes,
                                             int(reverse_char_map[9250]))
        predicted_word = ''.join(
            chr(char_map[str(p)]) for p in predicted_classes).replace(' ', '')
        distance = editdistance.eval(gt_word, predicted_word)
        print("{} - {}\t\t{}: {}".format(idx, gt_word, predicted_word, distance))
        # per-character match flags
        results = [p == g for p, g in zip(predicted_word, gt_word)]
        print(results)
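strip_prediction comes from the repo's metrics/ctc_metrics; in spirit it is the usual greedy CTC post-processing: merge runs of repeated labels, then drop blanks. A standalone sketch (my own re-implementation, not the repo's exact code):

```python
import numpy as np

def ctc_collapse(pred, blank):
    """Greedy CTC post-processing: merge repeated labels, drop blanks."""
    out = []
    prev = None
    for p in pred:
        if p != prev and p != blank:
            out.append(int(p))
        prev = p
    return out

# blank=0: [1 1 0 2 2 0 0 3] collapses to [1, 2, 3]
print(ctc_collapse(np.array([1, 1, 0, 2, 2, 0, 0, 3]), blank=0))  # [1, 2, 3]
```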
Appendix:
Resources for learning MXNet:
http://www.infoq.com/cn/articles/an-introduction-to-the-mxnet-api-part04
http://blog.csdn.net/yiweibian/article/details/72678020
http://ysfalo.github.io/2016/04/01/mxnet%E4%B9%8Bfine-tune/
http://shuokay.com/2016/01/01/mxnet-memo/