Learning the CTPN Text Detection Network (Part 1)


   The CTPN text detection network was proposed in the 2016 paper Detecting Text in Natural Image with Connectionist Text Proposal Network. Built on top of Faster R-CNN, it adapts that architecture specifically for text detection. It is a seminal paper that shaped the direction of later text detection algorithms. CTPN detects horizontal text very well, and it is still commonly used for text detection in documents, contracts, invoices, and similar material.

  The CTPN text detection method can be understood from five angles: the network structure, anchor positive/negative sample assignment, annotation data preprocessing, the loss function, and the text line construction algorithm.

1. Network Structure

   The structure of CTPN in the original paper is shown below. The final network output has three parts: scores, the confidence that a region is text; vertical coordinates, each box's center-point y coordinate and height; and side-refinement, the x-coordinate offsets of the boxes at the left and right edges of a text line.

   Many current CTPN implementations output only two parts, scores and boxes: scores is the confidence that a region is text, and boxes holds each box's center-point x coordinate, y coordinate, height, and width (the same as generic object detection). Compared with the original formulation this is slightly harder for the network to learn, but every box receives a more accurate offset correction, so the results should be more precise. This is also the variant I mainly use in practice; its structure is as follows:

   The data flow through the network is as follows:

  1. An image of shape (1, 3, 600, 900) passes through vgg_base to extract features of shape (1, 512, 37, 56); one more convolutional layer produces (1, 512*9, 37, 56).

  2. The (1, 512*9, 37, 56) feature map passes through the RNN, giving (1, 256, 37, 56); another convolutional layer brings it to (1, 512, 37, 56).

  3. The (1, 512, 37, 56) feature map passes through the two branch convolutions, loc and score. The loc branch gives (1, 40, 37, 56), where the 40 channels are 10 anchors, each with (center_x, center_y, w, h); the score branch gives (1, 20, 37, 56), where 20 is 10 anchors, each with two classes, text region and background. A shape-level sketch of steps 1 to 3 follows.
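  To make those shapes concrete, here is a minimal PyTorch sketch of steps 1 to 3. It is an illustration under assumptions, not a reference implementation: the backbone is torchvision's VGG-16 with the last pooling layer removed, the RNN is a bidirectional LSTM with hidden size 128, and both heads are 1x1 convolutions.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class CTPNSketch(nn.Module):
    def __init__(self, num_anchors=10):
        super().__init__()
        vgg = torchvision.models.vgg16(weights=None)
        self.vgg_base = vgg.features[:-1]                # conv5_3 output, overall stride 16
        self.rnn = nn.LSTM(512 * 9, 128, bidirectional=True, batch_first=True)
        self.fc = nn.Conv2d(256, 512, 1)
        self.loc = nn.Conv2d(512, num_anchors * 4, 1)    # (center_x, center_y, w, h) per anchor
        self.score = nn.Conv2d(512, num_anchors * 2, 1)  # text / background per anchor

    def forward(self, x):
        feat = self.vgg_base(x)                                       # (1, 512, 37, 56)
        n, c, h, w = feat.shape
        col = F.unfold(feat, 3, padding=1).reshape(n, c * 9, h, w)    # im2col: (1, 512*9, 37, 56)
        seq = col.permute(0, 2, 3, 1).reshape(n * h, w, c * 9)        # one sequence per feature-map row
        out, _ = self.rnn(seq)                                        # (n*h, w, 256)
        out = out.reshape(n, h, w, 256).permute(0, 3, 1, 2)           # (1, 256, 37, 56)
        out = F.relu(self.fc(out))                                    # (1, 512, 37, 56)
        return self.loc(out), self.score(out)                         # (1, 40, 37, 56), (1, 20, 37, 56)

# usage: loc, score = CTPNSketch()(torch.randn(1, 3, 600, 900))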

  Pay attention to the convolution that connects vgg_base and the RNN: the original paper implements it with Caffe's im2col, which works as follows:

   im2col reference code:

#PyTorch implementation of im2col
import torch.nn as nn
import torch.nn.functional as F

class Im2col(nn.Module):
    def __init__(self, kernel_size, stride, padding):
        super(Im2col, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        # F.unfold returns (N, C*k*k, L); reshape back onto the spatial grid
        # (assumes stride/padding keep the output the same size as the input)
        height = x.shape[2]
        x = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
        x = x.reshape((x.shape[0], x.shape[1], height, -1))
        return x
Im2col implementation
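A quick shape check of the module above, matching the data flow in section 1 (a 3x3 kernel with stride 1 and padding 1 keeps the spatial size):

import torch
x = torch.randn(1, 512, 37, 56)
col = Im2col(kernel_size=3, stride=1, padding=1)(x)
print(col.shape)  # torch.Size([1, 4608, 37, 56]), i.e. (1, 512*9, 37, 56)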

2. Anchor Positive/Negative Sample Assignment

Anchor settings

  CTPN uses 10 anchor sizes. All anchors are 16 pixels wide, with heights ranging from 11 up to 283. The width is fixed at 16 because after CTPN extracts features from a 600*900 image, the output feature map is 37*56, a 16x downsampling: each feature-map pixel maps back to a 16*16 region of the original image.

  The anchor layout is shown below: 10 anchors are placed at each feature-map pixel, for 37*56*10 = 20720 anchors in total:

   The anchor generation code is as follows:

#coding:utf-8
import numpy as np
import mxnet as mx
from mxnet import gluon


class AnchorGenerator(gluon.HybridBlock):

    def __init__(self, anchor_height=[11, 16, 23, 33, 48, 68, 97, 139, 198, 283], anchor_width=16,
                 stride=16, img_size=(), alloc_size=(128, 128), clip=False):
        super(AnchorGenerator, self).__init__()
        # anchor_height = [11, 16, 22, 32, 46, 66, 94, 134, 191, 273]  # the original paper uses these (from 11 to 273, divided by 0.7 each time)
        self.anchor_height = anchor_height
        self.anchor_width = anchor_width
        self.stride = stride
        self.alloc_size = alloc_size
        self._im_size = img_size
        self.base_size = stride
        anchors = self.generate_anchor()
        self.anchors = self.params.get_constant('anchor', anchors)
        self._clip=clip

    def generate_anchor(self):

        base_anchors = self.generate_base_anchors()
        # print(base_anchors)
        # propagate to all locations by shifting offsets
        height, width = self.alloc_size
        offset_x = np.arange(0, width * self.stride, self.stride)
        offset_y = np.arange(0, height * self.stride, self.stride)
        offset_x, offset_y = np.meshgrid(offset_x, offset_y)
        offsets = np.stack((offset_x.ravel(), offset_y.ravel(),
                            offset_x.ravel(), offset_y.ravel()), axis=1)
        # broadcast_add (1, N, 4) + (M, 1, 4)
        anchors = (base_anchors.reshape((1, -1, 4)) + offsets.reshape((-1, 1, 4)))  # (37*56)*10*4
        anchors = anchors.reshape((1, 1, height, width, -1)).astype(np.float32)  # (1, 1, 37, 56, 40)
        # print(anchors.shape)
        return mx.nd.array(anchors)

    def generate_base_anchors(self):
        base_anchor = np.array([1, 1, self.base_size, self.base_size], dtype=np.float64) - 1
        anchors = np.zeros((len(self.anchor_height), 4), np.float64)
        for i, h in enumerate(self.anchor_height):
            anchors[i] = self.scale_anchor(base_anchor, h, self.anchor_width)
        return anchors

    def scale_anchor(self, base_anchor, h, w):
        center_x = (base_anchor[0]+base_anchor[2])*0.5
        center_y = (base_anchor[1]+base_anchor[3])*0.5
        scaled_anchor = np.zeros_like(base_anchor, dtype=np.int32)   # note the integer dtype here
        scaled_anchor[0] = center_x - w/2
        scaled_anchor[2] = center_x + w/2
        scaled_anchor[1] = center_y - h/2
        scaled_anchor[3] = center_y + h/2
        return scaled_anchor

    def hybrid_forward(self, F, x, anchors):
        a = F.slice_like(anchors, x * 0, axes=(2, 3))
        a = a.reshape((1, -1, 4))

        if self._clip:
            x1, y1, x2, y2 = a.split(axis=-1, num_outputs=4)
            H, W = self._im_size
            a = F.concat(*[x1.clip(0, W), y1.clip(0, H), x2.clip(0, W), y2.clip(0, H)], dim=-1)
        return a.reshape((1, -1, 4))


if __name__ == "__main__":
    import cv2
    import random
    ag = AnchorGenerator()
    print(ag.anchors.shape)
    x = mx.nd.uniform(shape=(1, 3, 37, 56))
    ag.initialize()
    anchor = ag(x)
    img = np.ones(shape=(600, 900, 3), dtype=np.uint8)*255
    for i in range(0, 2000):  # only draw the first 2000 anchors
        print(anchor[0, i, :])
        box = anchor[0, i,:]
        box = box.asnumpy()
        color = (random.randint(0, 255), random.randint(0, 255),random.randint(0, 255))
        cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, 2)

    cv2.imshow("img", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
Anchor generation code

Positive/negative sample assignment

  CTPN uses the same sample assignment rules as the RPN in Faster R-CNN: based on the IoU between each anchor and the gt_boxes, 256 anchors are selected as training samples for the RPN. Note the sample count: the original Faster R-CNN selects 256 samples, half positive and half negative. For CTPN the original text annotation boxes are split into small boxes 16 pixels wide, so there are many more samples, and you can tune the total number of selected anchors to your own data. Taking 256 anchors as the example, the selection proceeds as follows (a minimal sketch follows the list):

1. Remove anchors whose coordinates fall outside the image boundary (the image is 600*900).
2. Compute the IoU between every anchor and every gt_box. For each gt_box, the anchor with the highest IoU is positive (whether or not it satisfies IoU > 0.7). Among the remaining anchors, those with IoU > 0.7 are positive and those with IoU < 0.3 are negative.
3. Select 256 samples, 128 positive and 128 negative. (If there are fewer than 128 positives, take as many as exist; if there are more than 128, randomly keep 128 and mark the surplus as ignored. Negatives usually exceed 128, so randomly keep 128 and mark the surplus as ignored.)
(Two outcomes are possible: either 128 positives and 128 negatives, 256 samples in total; or fewer than 128 positives (say 50) plus 128 negatives, fewer than 256 samples in total.)
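Here is a minimal NumPy sketch of these rules. It assumes out-of-boundary anchors were already removed and that ious is a precomputed (num_anchors, num_gt) IoU matrix; labels are 1 for positive, 0 for negative, -1 for ignored:

import numpy as np

def assign_anchor_samples(ious, num_samples=256, pos_thresh=0.7, neg_thresh=0.3):
    labels = np.full(ious.shape[0], -1, dtype=np.int64)  # start with everything ignored
    max_iou = ious.max(axis=1)
    labels[max_iou < neg_thresh] = 0        # negatives: IoU < 0.3
    labels[max_iou >= pos_thresh] = 1       # positives: IoU >= 0.7
    labels[ious.argmax(axis=0)] = 1         # the best anchor for each gt_box is always positive
    for cls in (1, 0):                      # cap positives and negatives at 128 each
        idx = np.where(labels == cls)[0]
        if len(idx) > num_samples // 2:
            drop = np.random.choice(idx, len(idx) - num_samples // 2, replace=False)
            labels[drop] = -1               # surplus samples are marked as ignored
    return labels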

3. Annotation Data Preprocessing

  The original annotation labels are single large text boxes, which must be split into small boxes 16 pixels wide before they can be used to train CTPN, so the annotation data needs preprocessing. The rough steps are:

  1. Find the center point of the original annotation box big_box, extend 8 pixels to each side to form the first 16-pixel-wide small_box, then keep extending outward in 16-pixel steps until reaching the two edges of big_box (near the edges, a leftover strip narrower than 16 pixels is not enough for a small_box and is discarded).

  2. The top and bottom boundaries of the split boxes are hard to determine directly. A workable trick is to draw big_box (in white) onto an all-black mask, then scan from top to bottom and from bottom to top; the positions of the first white pixels found become the top and bottom boundaries of that anchor.

  The result of splitting into 16-pixel-wide small_boxes looks like this:

  Reference code follows (adapted from https://www.cnblogs.com/skyfsm/p/10054386.html):

#coding:utf-8
import os
import cv2
import math
import numpy as np

def get_line_func(point1, point2):
    assert point2[0]-point1[0]!=0
    a = (point2[1]-point1[1])/(point2[0]-point1[0])
    b = point2[1]-a*point2[0]
    return a, b

def get_top_bottom1(top_a, top_b, bottom_a, bottom_b, left_x, right_x):
    top_y = math.ceil(max(top_a *left_x + top_b, top_a * right_x + top_b))
    bottom_y = math.floor(min(bottom_a * left_x + bottom_b, bottom_a * right_x + bottom_b))
    return top_y, bottom_y

def get_top_bottom(height, width, points, left_x, right_x):
    # draw the text box label (in white) onto an all-black mask, then scan top-down and
    # bottom-up for the first white pixel to get the anchor's top and bottom boundaries
    mask = np.zeros((height, width), dtype=np.uint8)
    points = np.array([int(i) for i in points])
    min_y = min(points[1::2])
    max_y = max(points[1::2])
    points = points.reshape(4, 2)
    for i in range(4):
        cv2.line(mask, (points[i][0], points[i][1]), (points[(i + 1) % 4][0], points[(i + 1) % 4][1]), 255,  2)
    flag = False
    top_y, bottom_y = 0, 0
    for y in range(min_y, min(max_y+1, height)):
    # for y in range(0, height):
        for x in range(left_x, min(right_x+1, width)):
            if mask[y, x] == 255:
                top_y = y
                flag=True
                break
        if flag:
            break

    flag = False
    for y in range(min(max_y, height-1), min_y-1, -1):
    # for y in range(height-1, -1, -1):
        for x in range(left_x, min(right_x + 1, width)):
            if mask[y, x] == 255:
                bottom_y = y
                flag = True
                break
        if flag:
            break

    # cv2.imshow("mask", mask)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()

    return top_y, bottom_y


def make_ctpn_data(img_file, anno_file, save_dir):
    try:
        img = cv2.imread(img_file)
        height, width = img.shape[:2]
        total_box_list = []
        with open(anno_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
            for line in lines:
                small_box_list = []
                line_list = line.strip().split(",")
                points = [float(i) for i in line_list[:8]]

                validate_clockwise_points(points)  # verify the coordinates are ordered counter-clockwise; raise otherwise

                left_x = min(points[0], points[2])
                right_x = max(points[4], points[6])
                center_x = int((left_x + right_x)/2)
                l_temp, r_temp = center_x-8, center_x+8
                # top_line_a, top_line_b = get_line_func(points[:2], points[6:])  # line equation of the big box's top edge
                # bottom_line_a, bottom_line_b = get_line_func(points[2:4], points[4:6])  # line equation of the big box's bottom edge
                # top_y, bottom_y = get_top_bottom1(top_line_a, top_line_b, bottom_line_a, bottom_line_b, l_temp, r_temp)
                top_y, bottom_y = get_top_bottom(height, width, points, l_temp, r_temp)
                small_box_list.append([center_x-8, top_y, center_x+8, bottom_y, 0])
                while l_temp-16 >= left_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, l_temp-16, l_temp)
                    small_box_list.insert(0, [l_temp-16, top_y, l_temp, bottom_y, 0])  # 0 means a middle box with no offset
                    l_temp = l_temp -16
                if l_temp - 16 >= 0 and l_temp > left_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, l_temp - 16, l_temp)
                    small_box_list.insert(0, [l_temp-16, top_y, l_temp, bottom_y, (left_x-(l_temp-16))])  # box at the left border: compute the offset
                else:
                    # strips narrower than 16 pixels at the border are discarded
                    small_box_list[0][-1] = left_x-(l_temp-16)           # box at the left border: compute the offset
                while r_temp + 16 <= right_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, r_temp, r_temp+16)
                    small_box_list.append([r_temp, top_y, r_temp+16, bottom_y, 0])
                    r_temp += 16
                if r_temp + 16 <= width-1 and r_temp < right_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, r_temp, r_temp+16)
                    small_box_list.append([r_temp, top_y, r_temp+16, bottom_y, (right_x-r_temp)])  # box at the right border: compute the offset
                else:
                    # strips narrower than 16 pixels at the border are discarded
                    small_box_list[-1][-1] = right_x-r_temp  # box at the right border: compute the offset
                # print(small_box_list)
                total_box_list.extend(small_box_list)
    except Exception as e:
        print(e)
        print(anno_file)
        return
    name = os.path.basename(anno_file)
    with open(os.path.join(save_dir, name), "w", encoding="utf-8") as f:
        for box in total_box_list:
            box = [str(i) for i in box]
            f.write(",".join(box)+"\n")



def validate_clockwise_points(points):  # raises when the points are ordered clockwise
    """
    Validates that the 4 points that delimit a polygon are in counter-clockwise order.
    """

    # The Shoelace Theorem computes the area of any polygon from its vertex coordinates:
    # the signed area is negative when the points are ordered clockwise and positive
    # when they are ordered counter-clockwise.

    if len(points) != 8:
        raise Exception("Points list not valid." + str(len(points)))

    point = [
        [int(points[0]), int(points[1])],
        [int(points[2]), int(points[3])],
        [int(points[4]), int(points[5])],
        [int(points[6]), int(points[7])]
    ]
    edge = [
        (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
        (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
        (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
        (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])
    ]

    summatory = edge[0] + edge[1] + edge[2] + edge[3]
    if summatory < 0:
        raise Exception("Points are not counter_clockwise.")

        # convert the points to counter-clockwise order
        # print('points in wrong direction')
        # poly = np.array(points).reshape((4, 2))
        # poly = poly[(0, 3, 2, 1), :]


if __name__ == "__main__":
    img_dir = r"E:\data\image_9000"
    src_label_dir = r"E:\data\txt_9000"
    dst_label_dir = r"E:\data\txt_ctpn"
    for file in os.listdir(img_dir):
        if file.endswith(".jpg"):
            img_file = os.path.join(img_dir, file)
            name, _ = os.path.splitext(file)
            # anno_file = os.path.join(src_label_dir, file.replace(".jpg", ".txt"))
            anno_file = os.path.join(src_label_dir, name+".txt")
            make_ctpn_data(img_file, anno_file, dst_label_dir)
Annotation data preprocessing code

4. Loss Function

     The loss in the original paper has three parts: the text/non-text classification loss cls_loss; the vertical loss vertical_loss on each box's center-point y coordinate and height; and the side-refinement loss side_refinement_loss on the x offsets of the boxes at the two sides of a text line. Classification uses cross-entropy; box regression uses Smooth L1.

  In current CTPN implementations the box branch regresses the center point, height, and width directly, so the loss has only two parts: a cross-entropy classification loss and a Smooth L1 box regression loss. A sketch follows.
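A hedged PyTorch sketch of this two-part loss, assuming cls_pred holds (N, 2) logits for the N sampled anchors, loc_pred and loc_target hold (N, 4) encoded offsets, labels is a LongTensor using the 1/0/-1 convention from section 2, and lam is an assumed balancing weight:

import torch
import torch.nn.functional as F

def ctpn_loss(cls_pred, loc_pred, labels, loc_target, lam=1.0):
    keep = labels >= 0                                    # drop ignored anchors
    cls_loss = F.cross_entropy(cls_pred[keep], labels[keep])
    pos = labels == 1                                     # only positive anchors regress boxes
    if pos.any():
        loc_loss = F.smooth_l1_loss(loc_pred[pos], loc_target[pos])
    else:
        loc_loss = loc_pred.sum() * 0                     # keep the graph valid when there are no positives
    return cls_loss + lam * loc_loss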

5. Text Line Construction Algorithm

  The text line construction algorithm has two parts. First, text box connection: the boxes output by the network are merged into one large box. Second, text box refinement: the top and bottom boundaries of that box are corrected, and the final rectangle is derived from the corrected parallelogram.

  Text box connection

    Read https://zhuanlan.zhihu.com/p/34757009 together with the code and the procedure should be clear. A sketch of the pairing rules is given below:
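As a rough guide, the following NumPy sketch implements the pairing rules that article describes (the commonly used CTPN grouping); the 50-pixel gap and 0.7 vertical-overlap thresholds are the usual defaults and should be treated as assumptions:

import numpy as np

def vertical_overlap(a, b):
    # overlap of the y-ranges of two boxes, normalized by the shorter height
    h = min(a[3], b[3]) - max(a[1], b[1])
    return max(0.0, h) / min(a[3] - a[1], b[3] - b[1])

def build_pair_graph(boxes, scores, max_gap=50, min_overlap=0.7):
    n = len(boxes)
    graph = np.zeros((n, n), dtype=bool)

    def successors(i):   # boxes starting within max_gap to the right, overlapping vertically
        return [j for j in range(n)
                if 0 < boxes[j][0] - boxes[i][0] <= max_gap
                and vertical_overlap(boxes[i], boxes[j]) >= min_overlap]

    def predecessors(j):
        return [k for k in range(n)
                if 0 < boxes[j][0] - boxes[k][0] <= max_gap
                and vertical_overlap(boxes[k], boxes[j]) >= min_overlap]

    for i in range(n):
        succ = successors(i)
        if not succ:
            continue
        j = max(succ, key=lambda k: scores[k])            # best-scoring successor of i
        if max(predecessors(j), key=lambda k: scores[k]) == i:
            graph[i, j] = True                            # mutual best match: connect i -> j
    return graph  # connected components of this graph are the text lines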

  Text box refinement

    Read https://zhuanlan.zhihu.com/p/137540923 together with the code and the procedure should be clear. A sketch of the refinement step is given below:
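As a rough guide, here is a minimal sketch of one common refinement step, assuming boxes holds the (x1, y1, x2, y2) small boxes of one connected text line: fit least-squares lines through the top and bottom edges, then take the bounding rectangle of the corrected parallelogram:

import numpy as np

def refine_text_line(boxes):
    boxes = np.asarray(boxes, dtype=np.float64)
    cx = (boxes[:, 0] + boxes[:, 2]) / 2
    top = np.polyfit(cx, boxes[:, 1], 1)        # least-squares line through the top edges
    bot = np.polyfit(cx, boxes[:, 3], 1)        # least-squares line through the bottom edges
    x1, x2 = boxes[:, 0].min(), boxes[:, 2].max()
    ys = [np.polyval(top, x1), np.polyval(top, x2),
          np.polyval(bot, x1), np.polyval(bot, x2)]
    return [float(x1), float(min(ys)), float(x2), float(max(ys))]  # final rectangle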

 

References:

  https://zhuanlan.zhihu.com/p/34757009

  https://zhuanlan.zhihu.com/p/137540923

  https://www.cnblogs.com/skyfsm/p/9776611.html

