OCR(Optical Character Recognition)任務主要是識別出圖片中的文字,目前深度學習的方法采用兩步來解決這個問題,一是文字檢測網絡定位文字位置,二是文字識別網絡識別出文字。
關於OCR的綜述參考:http://xiaofengshi.com/2019/01/05/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0-OCR_Overview/
CRNN+CTC的文字識別網絡是在2015年的論文An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition(2015) 中提出,主要用於序列文本的識別。CRNN的整體流程如下圖所示,圖片依次經過CNN卷積層,RNN循環層,最后經解碼翻譯處理得到最后的識別文本。
對於CRNN文字識別網絡的理解主要在於三方面:網絡結構,CTC損失函數,數據預處理。CRNN參考代碼地址:https://github.com/bgshih/crnn, https://github.com/meijieru/crnn.pytorch
1. 網絡結構
CRNN的網絡結構比較簡單,包括VGG11和RNN兩部分。采用VGG11進行特征提取,隨后采用雙層的BiLSTM提取序列信息,其網絡結構如圖所示:
訓練時模型的計算流程如下:
1. 經過灰度化和resize后圖片的尺寸為(B, 1, 32, 160),圖片經過VGG11卷積層得到feature尺寸為(B, 512, 1, 41)
2. feature經過RNN循環層后網絡輸出尺寸為(41, B, 4039)。(4039表示字典里共有4038個字符,還有一個字符"_"表示空格)
3. 尺寸為(41, B, 4039)的輸出經過log_softmax后,通過CTC計算loss
2.CTC損失函數
CTC(Connectionist Temporal Classification)是在2006年的論文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks中提出,引入了空白符號,解決了損失計算時,文字標注和網絡輸出之間的對齊問題。其原理比較復雜,需要專研下,參考:CTC原理理解(轉載)
3.數據預處理
由於CRNN對於序列行文本效果比較好,所以對於輸入的圖片除了進行resize和灰度化外,還有兩點要注意。
一是旋轉角度過大的序列文本,需要進行一定的旋轉,參考下面代碼

from scipy.ndimage import filters, interpolation from numpy import amin, amax def estimate_skew_angle(raw): """ 估計圖像文字角度 因為文字是水平排版,那么此位置圖像的行與行之間像素值的方差應該是最大。 原理大概是這樣,先對圖像進行二值化處理,然后計算圖像每行的均值向量,得到該向量的方差。如果圖像文字不存在文字傾斜(假設所有文字朝向一致),那么對應的方差應該是最大,找到方差最大對應的角度,就是文字的傾斜角度。 本項目中,只取了-15到15度,主要是計算速度的影響,如果不考慮性能,可以計算得更准備。 """ def resize_im(im, scale, max_scale=None): f = float(scale) / min(im.shape[0], im.shape[1]) if max_scale != None and f * max(im.shape[0], im.shape[1]) > max_scale: f = float(max_scale) / max(im.shape[0], im.shape[1]) return cv2.resize(im, (0, 0), fx=f, fy=f) raw = resize_im(raw, scale=600, max_scale=900) image = raw - amin(raw) image = image / amax(image) m = interpolation.zoom(image, 0.5) m = filters.percentile_filter(m, 80, size=(20, 2)) m = filters.percentile_filter(m, 80, size=(2, 20)) m = interpolation.zoom(m, 1.0 / 0.5) w, h = min(image.shape[1], m.shape[1]), min(image.shape[0], m.shape[0]) flat = np.clip(image[:h, :w] - m[:h, :w] + 1, 0, 1) d0, d1 = flat.shape o0, o1 = int(0.1 * d0), int(0.1 * d1) flat = amax(flat) - flat flat -= amin(flat) est = flat[o0:d0 - o0, o1:d1 - o1] angles = range(-15, 15) estimates = [] for a in angles: roest = interpolation.rotate(est, a, order=0, mode='constant') v = np.mean(roest, axis=1) v = np.var(v) estimates.append((v, a)) _, a = max(estimates) return a if __name__ == "__main__": import os src = r"F:\temp" for file in os.listdir(src): if file.endswith(".jpg"): img_path = os.path.join(src, file) # img = cv2.imread(img_path, 0) img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), flags=0) h, w = img.shape angle = estimate_skew_angle(img) print(file, angle) m = cv2.getRotationMatrix2D((int(w / 2), int(h / 2)), angle, 1) d = int(np.sqrt(h * h + w * w)) img2 = cv2.warpAffine(img, m, (w, h)) cv2.imshow("img2", img2) cv2.waitKey(0) cv2.destroyAllWindows()
二是對於長文本需要進行切割成小段,識別后再拼接,如上面網絡輸出序列為41*B*4039,表示支持的最長文本為41個字符
參考:https://zhuanlan.zhihu.com/p/43534801