文字識別網絡學習—CRNN+CTC


  OCR(Optical Character Recognition)任務主要是識別出圖片中的文字,目前深度學習的方法采用兩步來解決這個問題,一是文字檢測網絡定位文字位置,二是文字識別網絡識別出文字。

  關於OCR的綜述參考:http://xiaofengshi.com/2019/01/05/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0-OCR_Overview/

  CRNN+CTC的文字識別網絡是在2015年的論文An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition(2015) 中提出,主要用於序列文本的識別。CRNN的整體流程如下圖所示,圖片依次經過CNN卷積層,RNN循環層,最后經解碼翻譯處理得到最后的識別文本。

  對於CRNN文字識別網絡的理解主要在於三方面:網絡結構,CTC損失函數,數據預處理。CRNN參考代碼地址:https://github.com/bgshih/crnn, https://github.com/meijieru/crnn.pytorch  

1. 網絡結構

   CRNN的網絡結構比較簡單,包括VGG11和RNN兩部分。采用VGG11進行特征提取,隨后采用雙層的BiLSTM提取序列信息,其網絡結構如圖所示:

  訓練時模型的計算流程如下:

    1. 經過灰度化和resize后圖片的尺寸為(B, 1, 32, 160),圖片經過VGG11卷積層得到feature尺寸為(B, 512, 1, 41)

    2. feature經過RNN循環層后網絡輸出尺寸為(41, B, 4039)。(4039表示字典里共有4038個字符,還有一個字符"_"表示空格)

    3. 尺寸為(41, B, 4039)的輸出經過log_softmax后,通過CTC計算loss 

2.CTC損失函數

   CTC(Connectionist Temporal Classification)是在2006年的論文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks中提出,引入了空白符號,解決了損失計算時,文字標注和網絡輸出之間的對齊問題。其原理比較復雜,需要專研下,參考:CTC原理理解(轉載)  

3.數據預處理

   由於CRNN對於序列行文本效果比較好,所以對於輸入的圖片除了進行resize和灰度化外,還有兩點要注意。

  一是旋轉角度過大的序列文本,需要進行一定的旋轉,參考下面代碼

from scipy.ndimage import filters, interpolation
from numpy import amin, amax


def estimate_skew_angle(raw):
    """
    估計圖像文字角度
    因為文字是水平排版,那么此位置圖像的行與行之間像素值的方差應該是最大。
    原理大概是這樣,先對圖像進行二值化處理,然后計算圖像每行的均值向量,得到該向量的方差。如果圖像文字不存在文字傾斜(假設所有文字朝向一致),那么對應的方差應該是最大,找到方差最大對應的角度,就是文字的傾斜角度。
    本項目中,只取了-15到15度,主要是計算速度的影響,如果不考慮性能,可以計算得更准備。
    """

    def resize_im(im, scale, max_scale=None):
        f = float(scale) / min(im.shape[0], im.shape[1])
        if max_scale != None and f * max(im.shape[0], im.shape[1]) > max_scale:
            f = float(max_scale) / max(im.shape[0], im.shape[1])
        return cv2.resize(im, (0, 0), fx=f, fy=f)

    raw = resize_im(raw, scale=600, max_scale=900)
    image = raw - amin(raw)
    image = image / amax(image)
    m = interpolation.zoom(image, 0.5)
    m = filters.percentile_filter(m, 80, size=(20, 2))
    m = filters.percentile_filter(m, 80, size=(2, 20))
    m = interpolation.zoom(m, 1.0 / 0.5)

    w, h = min(image.shape[1], m.shape[1]), min(image.shape[0], m.shape[0])
    flat = np.clip(image[:h, :w] - m[:h, :w] + 1, 0, 1)
    d0, d1 = flat.shape
    o0, o1 = int(0.1 * d0), int(0.1 * d1)
    flat = amax(flat) - flat
    flat -= amin(flat)
    est = flat[o0:d0 - o0, o1:d1 - o1]
    angles = range(-15, 15)
    estimates = []
    for a in angles:
        roest = interpolation.rotate(est, a, order=0, mode='constant')
        v = np.mean(roest, axis=1)
        v = np.var(v)
        estimates.append((v, a))

    _, a = max(estimates)
    return a

if __name__ == "__main__":
    import os

    src = r"F:\temp"
    for file in os.listdir(src):
        if file.endswith(".jpg"):
            img_path = os.path.join(src, file)
            # img = cv2.imread(img_path, 0)
            img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), flags=0)
            h, w = img.shape
            angle = estimate_skew_angle(img)
            print(file, angle)
            m = cv2.getRotationMatrix2D((int(w / 2), int(h / 2)), angle, 1)
            d = int(np.sqrt(h * h + w * w))
            img2 = cv2.warpAffine(img, m, (w, h))
            cv2.imshow("img2", img2)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
小角度估計

  二是對於長文本需要進行切割成小段,識別后再拼接,如上面網絡輸出序列為41*B*4039,表示支持的最長文本為41個字符

 

參考:https://zhuanlan.zhihu.com/p/43534801


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM