端到端的OCR:基於CNN的實現


端到端的OCR:基於CNN的實現

OCR是一個古老的問題。這里我們考慮一類特殊的OCR問題,就是驗證碼的識別。傳統做驗證碼的識別,需要經過如下步驟:

1. 二值化 2. 字符分割 3. 字符識別

這里最難的就是分割。如果字符之間有粘連,那分割起來就無比痛苦了。

最近研究深度學習,發現有人做端到端的OCR。於是准備嘗試一下。一般來說目前做基於深度學習的OCR大概有如下套路:

1. OCR的問題當做一個多標簽學習的問題。4個數字組成的驗證碼就相當於有4個標簽的圖片識別問題(這里的標簽還是有序的),用CNN來解決。 2. OCR的問題當做一個語音識別的問題,語音識別是把連續的音頻轉化為文本,驗證碼識別就是把連續的圖片轉化為文本,用CNN+LSTM+CTC來解決。

目前第1種方法可以做到90%多的准確率(4個都猜對了才算對),第二種方法我目前的實驗還只能到20%多,還在研究中。所以這篇文章先介紹第一種方法。

我們以python-captcha驗證碼的識別為例來做驗證碼識別。

下圖是一些這個驗證碼的例子:

python-captcha

可以看到這里面有粘連,也有形變,噪音。所以我們可以看看用CNN識別這個驗證碼的效果。

首先,我們定義一個迭代器來輸入數據,這里我們每次都直接調用python-captcha這個庫來根據隨機生成的label來生成相應的驗證碼圖片。這樣我們的訓練集相當於是無窮大的。

class OCRIter(mx.io.DataIter): def __init__(self, count, batch_size, num_label, height, width): super(OCRIter, self).__init__() self.captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf']) self.batch_size = batch_size self.count = count self.height = height self.width = width self.provide_data = [('data', (batch_size, 3, height, width))] self.provide_label = [('softmax_label', (self.batch_size, num_label))] def __iter__(self): for k in range(self.count / self.batch_size): data = [] label = [] for i in range(self.batch_size): # 生成一個四位數字的隨機字符串 num = gen_rand() # 生成隨機字符串對應的驗證碼圖片 img = self.captcha.generate(num) img = np.fromstring(img.getvalue(), dtype='uint8') img = cv2.imdecode(img, cv2.IMREAD_COLOR) img = cv2.resize(img, (self.width, self.height)) cv2.imwrite("./tmp" + str(i % 10) + ".png", img) img = np.multiply(img, 1/255.0) img = img.transpose(2, 0, 1) data.append(img) label.append(get_label(num)) data_all = [mx.nd.array(data)] label_all = [mx.nd.array(label)] data_names = ['data'] label_names = ['softmax_label'] data_batch = OCRBatch(data_names, data_all, label_names, label_all) yield data_batch def reset(self): pass

然后我們用如下的網絡來訓練這個數據集:

def get_ocrnet(): data = mx.symbol.Variable('data') label = mx.symbol.Variable('softmax_label') conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=32) pool1 = mx.symbol.Pooling(data=conv1, pool_type="max", kernel=(2,2), stride=(1, 1)) relu1 = mx.symbol.Activation(data=pool1, act_type="relu") conv2 = mx.symbol.Convolution(data=relu1, kernel=(5,5), num_filter=32) pool2 = mx.symbol.Pooling(data=conv2, pool_type="avg", kernel=(2,2), stride=(1, 1)) relu2 = mx.symbol.Activation(data=pool2, act_type="relu") conv3 = mx.symbol.Convolution(data=relu2, kernel=(3,3), num_filter=32) pool3 = mx.symbol.Pooling(data=conv3, pool_type="avg", kernel=(2,2), stride=(1, 1)) relu3 = mx.symbol.Activation(data=pool3, act_type="relu") flatten = mx.symbol.Flatten(data = relu3) fc1 = mx.symbol.FullyConnected(data = flatten, num_hidden = 512) fc21 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10) fc22 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10) fc23 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10) fc24 = mx.symbol.FullyConnected(data = fc1, num_hidden = 10) fc2 = mx.symbol.Concat(*[fc21, fc22, fc23, fc24], dim = 0) label = mx.symbol.transpose(data = label) label = mx.symbol.Reshape(data = label, target_shape = (0, )) return mx.symbol.SoftmaxOutput(data = fc2, label = label, name = "softmax")

上面這個網絡要稍微解釋一下。因為這個問題是一個有順序的多label的圖片分類問題。我們在fc1的層上面接了4個Full Connect層(fc21,fc22,fc23,fc24),用來對應不同位置的4個數字label。然后將它們Concat在一起。然后同時學習這4個label。目前用上面的網絡訓練,4位數字全部預測正確的精度可以達到90%左右。

全部的代碼請參考 https://gist.github.com/xlvector/6923ef145e59de44ed06f21228f2f879


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM