The CTPN text detection network was proposed in the 2016 paper Detecting Text in Natural Image with Connectionist Text Proposal Network. It builds on Faster R-CNN and proposes a network tailored to text detection. It is widely regarded as a seminal paper that shaped the direction of later text detection algorithms. CTPN detects horizontal text very well and is still commonly used for text detection on documents, contracts, invoices, and similar material.
The CTPN text detection method can be understood from five angles: network structure, positive/negative anchor sample assignment, annotation data preprocessing, the loss function, and the text-line construction algorithm.
1. Network structure
The structure of CTPN in the original paper is shown below. The final network output has three parts: scores is the confidence that a region is text, vertical coordinates is the center y-coordinate and height of each box, and side-refinement is the x-coordinate offset for boxes at the left and right boundaries of a text line.
Many current CTPN implementations output only two parts, scores and boxes: scores is the confidence that a region is text, and boxes is each box's center x-coordinate, center y-coordinate, height, and width (the same as in generic object detection). Compared with the original paper's formulation this is somewhat harder for the network to learn, but because every box receives a more accurate offset correction, the results should be more precise. This is also the approach I mainly use in practice; its structure is as follows:
The data flow through the network is as follows:
1. An image of size (1, 3, 600, 900) goes through vgg_base to extract features, giving a (1, 512, 37, 56) feature map; one more convolution turns it into (1, 512*9, 37, 56).
2. The (1, 512*9, 37, 56) feature map goes through the RNN, producing (1, 256, 37, 56); one more convolution turns it into (1, 512, 37, 56).
3. The (1, 512, 37, 56) feature map goes through two convolution branches, loc and score. The loc branch outputs (1, 40, 37, 56), where the 40 channels cover 10 anchors with (center_x, center_y, w, h) each; the score branch outputs (1, 20, 37, 56), where the 20 channels cover 10 anchors with two classes each, text and background.
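To make these shapes concrete, here is a minimal PyTorch sketch of the data flow above. It only illustrates the shapes described in steps 1-3, not any particular implementation: the module and layer names (CTPN, vgg_base, fc, loc, score) are my own, the backbone is torchvision's VGG16 with the last max-pool removed, and the RNN is assumed to be a bidirectional LSTM run over each feature-map row.

import torch
import torch.nn as nn
from torchvision import models

class CTPN(nn.Module):
    # Minimal sketch of the data flow above; all names here are illustrative.
    def __init__(self, num_anchors=10):
        super().__init__()
        # VGG16 backbone up to conv5_3 (last max-pool removed): stride 16, 512 channels
        self.vgg_base = models.vgg16(weights=None).features[:-1]
        # bidirectional LSTM over each feature-map row: 2 * 128 = 256 output channels
        self.rnn = nn.LSTM(512 * 9, 128, bidirectional=True, batch_first=True)
        self.fc = nn.Conv2d(256, 512, 1)
        self.loc = nn.Conv2d(512, num_anchors * 4, 1)    # (center_x, center_y, w, h) per anchor
        self.score = nn.Conv2d(512, num_anchors * 2, 1)  # text / background per anchor

    def forward(self, x):                                    # (1, 3, 600, 900)
        x = self.vgg_base(x)                                 # (1, 512, 37, 56)
        n, c, h, w = x.shape
        x = nn.functional.unfold(x, 3, padding=1)            # im2col: (1, 512*9, 37*56)
        x = x.reshape(n, c * 9, h, w)                        # (1, 512*9, 37, 56)
        x = x.permute(0, 2, 3, 1).reshape(n * h, w, c * 9)   # one sequence per feature-map row
        x, _ = self.rnn(x)                                   # (1*37, 56, 256)
        x = x.reshape(n, h, w, 256).permute(0, 3, 1, 2)      # (1, 256, 37, 56)
        x = self.fc(x)                                       # (1, 512, 37, 56)
        return self.loc(x), self.score(x)                    # (1, 40, 37, 56), (1, 20, 37, 56)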
Note the convolution connecting vgg_base and the RNN: the original paper uses caffe's im2col, whose process is as follows:
Reference im2col code:

# PyTorch implementation of im2col
import torch
import torch.nn as nn
import torch.nn.functional as F

class Im2col(nn.Module):
    def __init__(self, kernel_size, stride, padding):
        super(Im2col, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        height = x.shape[2]
        # unfold stacks each kernel_size neighborhood into the channel dimension: (N, C*k*k, L)
        x = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
        # with stride 1 and "same" padding, L == H*W, so reshape back to (N, C*k*k, H, W)
        x = x.reshape((x.shape[0], x.shape[1], height, -1))
        return x
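With kernel_size=3, stride=1, and padding=1 the spatial size is preserved and every position's 3x3 neighborhood is stacked into the channel dimension, which is exactly the (1, 512, 37, 56) -> (1, 512*9, 37, 56) step above:

x = torch.randn(1, 512, 37, 56)
im2col = Im2col(kernel_size=3, stride=1, padding=1)
print(im2col(x).shape)  # torch.Size([1, 4608, 37, 56]), i.e. (1, 512*9, 37, 56)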
2. Positive/negative anchor sample assignment
Anchor settings
CTPN uses 10 anchor scales: every anchor is 16 pixels wide, with heights ranging from 11 up to 283. The anchor width is set to 16 because CTPN downsamples by a factor of 16: a 600*900 image yields a 37*56 feature map, so each feature-map pixel corresponds to a 16*16 region of the original image.
The anchors are laid out as shown below: 10 anchors at every feature-map pixel, for a total of 37 * 56 * 10 = 20720 anchors:
The code to generate the anchors is as follows:

#coding:utf-8
import numpy as np
from gluoncv.nn.coder import SigmoidClassEncoder, NumPyNormalizedBoxCenterEncoder
import mxnet as mx
from mxnet import gluon

try:
    import cython_bbox
except ImportError:
    cython_bbox = None


class AnchorGenerator(gluon.HybridBlock):
    def __init__(self, anchor_height=[11, 16, 23, 33, 48, 68, 97, 139, 198, 283],
                 anchor_width=16, stride=16, img_size=(), alloc_size=(128, 128), clip=False):
        super(AnchorGenerator, self).__init__()
        # anchor_height = [11, 16, 22, 32, 46, 66, 94, 134, 191, 273]  # the original paper uses these (from 11 to 273, divide by 0.7 each time)
        self.anchor_height = anchor_height
        self.anchor_width = anchor_width
        self.stride = stride
        self.alloc_size = alloc_size
        self._im_size = img_size
        self.base_size = stride
        anchors = self.generate_anchor()
        self.anchors = self.params.get_constant('anchor', anchors)
        self._clip = clip

    def generate_anchor(self):
        base_anchors = self.generate_base_anchors()
        # print(base_anchors)
        # propagate to all locations by shifting offsets
        height, width = self.alloc_size
        offset_x = np.arange(0, width * self.stride, self.stride)
        offset_y = np.arange(0, height * self.stride, self.stride)
        offset_x, offset_y = np.meshgrid(offset_x, offset_y)
        offsets = np.stack((offset_x.ravel(), offset_y.ravel(),
                            offset_x.ravel(), offset_y.ravel()), axis=1)
        # broadcast_add (1, N, 4) + (M, 1, 4)
        anchors = (base_anchors.reshape((1, -1, 4)) + offsets.reshape((-1, 1, 4)))  # (37*56)*10*4
        anchors = anchors.reshape((1, 1, height, width, -1)).astype(np.float32)  # (1, 1, 37, 56, 40)
        # print(anchors.shape)
        return mx.nd.array(anchors)

    def generate_base_anchors(self):
        base_anchor = np.array([1, 1, self.base_size, self.base_size], dtype=np.float64) - 1
        anchors = np.zeros((len(self.anchor_height), 4), np.float64)
        for i, h in enumerate(self.anchor_height):
            anchors[i] = self.scale_anchor(base_anchor, h, self.anchor_width)
        return anchors

    def scale_anchor(self, base_anchor, h, w):
        center_x = (base_anchor[0] + base_anchor[2]) * 0.5
        center_y = (base_anchor[1] + base_anchor[3]) * 0.5
        scaled_anchor = np.zeros_like(base_anchor, dtype=np.int32)  # note the integer type here
        scaled_anchor[0] = center_x - w / 2
        scaled_anchor[2] = center_x + w / 2
        scaled_anchor[1] = center_y - h / 2
        scaled_anchor[3] = center_y + h / 2
        return scaled_anchor

    def hybrid_forward(self, F, x, anchors):
        a = F.slice_like(anchors, x * 0, axes=(2, 3))
        a = a.reshape((1, -1, 4))
        if self._clip:
            cx, cy, cw, ch = a.split(axis=-1, num_outputs=4)
            H, W = self._im_size
            a = F.concat(*[cx.clip(0, W), cy.clip(0, H), cw.clip(0, W), ch.clip(0, H)], dim=-1)
        return a.reshape((1, -1, 4))


if __name__ == "__main__":
    import cv2
    import random
    ag = AnchorGenerator()
    print(ag.anchors.shape)
    x = mx.nd.uniform(shape=(1, 3, 37, 56))
    ag.initialize()
    anchor = ag(x)
    img = np.ones(shape=(600, 900, 3), dtype=np.uint8) * 255
    for i in range(0, 2000):  # only draw the first 2000 anchors
        print(anchor[0, i, :])
        box = anchor[0, i, :]
        box = box.asnumpy()
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, 2)
    cv2.imshow("img", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
Positive/negative sample assignment
CTPN uses the same sample-assignment rule as the RPN in Faster R-CNN: based on the IoU between the anchors and the gt_boxes, 256 anchors are selected as samples for the network to learn from. Pay attention to the number of selected anchor samples: the original Faster R-CNN selects 256 samples, half positive and half negative, but in CTPN the original text annotation boxes are split into small boxes 16 pixels wide, so there are many more samples, and you can set the total number of selected anchors to suit your own data. Taking 256 anchors as the example, the selection procedure is as follows:
1. Discard anchors whose coordinates fall outside the image boundary (the image is 600*900).
2. Compute the IoU between every anchor and the gt_boxes. For each gt_box, the anchor with the highest IoU is positive (whether or not it satisfies IoU > 0.7); among the remaining anchors, those with IoU > 0.7 are positive and those with 0 < IoU < 0.3 are negative.
3. Select 256 samples, 128 positive and 128 negative. (If there are fewer than 128 positives, take as many as there are; if there are more than 128, randomly pick 128 and mark the rest as ignored samples. There are usually more than 128 negatives, so randomly pick 128 and mark the rest as ignored samples.) (Two outcomes are possible: either 128 positives and 128 negatives, 256 samples in total; or fewer than 128 positives (say 50) plus 128 negatives, for fewer than 256 samples in total.)
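Below is a minimal NumPy sketch of this selection rule. It assumes the IoU matrix between the in-image anchors and the gt_boxes has already been computed; the function and parameter names are illustrative:

import numpy as np

def sample_anchors(ious, num_samples=256, pos_iou=0.7, neg_iou=0.3):
    # ious: (num_anchors, num_gt) IoU matrix for anchors already inside the image.
    # Returned labels: 1 = positive, 0 = negative, -1 = ignored.
    labels = np.full(ious.shape[0], -1, dtype=np.int64)
    max_iou = ious.max(axis=1)
    labels[max_iou < neg_iou] = 0            # IoU < 0.3 -> negative
    labels[max_iou >= pos_iou] = 1           # IoU >= 0.7 -> positive
    labels[ious.argmax(axis=0)] = 1          # the best anchor of each gt_box is always positive
    half = num_samples // 2
    pos_idx = np.where(labels == 1)[0]
    if len(pos_idx) > half:                  # keep at most 128 positives, ignore the rest
        labels[np.random.choice(pos_idx, len(pos_idx) - half, replace=False)] = -1
    neg_idx = np.where(labels == 0)[0]
    if len(neg_idx) > half:                  # keep at most 128 negatives, ignore the rest
        labels[np.random.choice(neg_idx, len(neg_idx) - half, replace=False)] = -1
    return labels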
3. Annotation data preprocessing
The original annotations are single large text boxes, which must be split into small boxes 16 pixels wide before they can be used to train CTPN, so the annotation data needs preprocessing. The rough steps are:
1. Find the center of the original annotation box (big_box), take 8 pixels on each side of it as the first 16-pixel-wide small_box, then keep extending outward 16 pixels at a time, each width forming one small_box, until reaching the two edges of the big_box (at the edges, remnants narrower than 16 pixels, not enough to form a small_box, are discarded).
2. The top and bottom boundaries of the split boxes are hard to determine directly. One way is to draw the big_box (in white) onto an all-black mask, then scan from the top down and from the bottom up for the first white pixel, and use those positions as the small box's top and bottom boundaries.
Splitting into 16-pixel-wide small_boxes looks like this:
Reference code follows (adapted from https://www.cnblogs.com/skyfsm/p/10054386.html):

#coding:utf-8
import os
import cv2
import math
import numpy as np


def get_line_func(point1, point2):
    assert point2[0] - point1[0] != 0
    a = (point2[1] - point1[1]) / (point2[0] - point1[0])
    b = point2[1] - a * point2[0]
    return a, b


def get_top_bottom1(top_a, top_b, bottom_a, bottom_b, left_x, right_x):
    top_y = math.ceil(max(top_a * left_x + top_b, top_a * right_x + top_b))
    bottom_y = math.floor(min(bottom_a * left_x + bottom_b, bottom_a * right_x + bottom_b))
    return top_y, bottom_y


def get_top_bottom(height, width, points, left_x, right_x):
    # Draw the labeled text polygon (in white) on an all-black mask, then scan from the top down
    # and from the bottom up for the first white pixel to get the small box's top/bottom boundaries
    mask = np.zeros((height, width), dtype=np.uint8)
    points = np.array([int(i) for i in points])
    min_y = min(points[1::2])
    max_y = max(points[1::2])
    points = points.reshape(4, 2)
    for i in range(4):
        cv2.line(mask, (points[i][0], points[i][1]),
                 (points[(i + 1) % 4][0], points[(i + 1) % 4][1]), 255, 2)
    flag = False
    top_y, bottom_y = 0, 0
    for y in range(min_y, min(max_y + 1, height)):
        # for y in range(0, height):
        for x in range(left_x, min(right_x + 1, width)):
            if mask[y, x] == 255:
                top_y = y
                flag = True
                break
        if flag:
            break
    flag = False
    for y in range(min(max_y, height - 1), min_y - 1, -1):
        # for y in range(height-1, -1, -1):
        for x in range(left_x, min(right_x + 1, width)):
            if mask[y, x] == 255:
                bottom_y = y
                flag = True
                break
        if flag:
            break
    # cv2.imshow("mask", mask)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
    return top_y, bottom_y


def make_ctpn_data(img_file, anno_file, save_dir):
    try:
        img = cv2.imread(img_file)
        height, width = img.shape[:2]
        total_box_list = []
        with open(anno_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
            for line in lines:
                small_box_list = []
                line_list = line.strip().split(",")
                points = [float(i) for i in line_list[:8]]
                validate_clockwise_points(points)  # verify the points are in counter-clockwise order; raise otherwise
                left_x = min(points[0], points[2])
                right_x = max(points[4], points[6])
                center_x = int((left_x + right_x) / 2)
                l_temp, r_temp = center_x - 8, center_x + 8
                # top_line_a, top_line_b = get_line_func(points[:2], points[6:])          # line equation of the big box's top edge
                # bottom_line_a, bottom_line_b = get_line_func(points[2:4], points[4:6])  # line equation of the big box's bottom edge
                # top_y, bottom_y = get_top_bottom(top_line_a, top_line_b, bottom_line_a, bottom_line_b, l_temp, r_temp)
                top_y, bottom_y = get_top_bottom(height, width, points, l_temp, r_temp)
                small_box_list.append([center_x - 8, top_y, center_x + 8, bottom_y, 0])
                while l_temp - 16 >= left_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, l_temp - 16, l_temp)
                    small_box_list.insert(0, [l_temp - 16, top_y, l_temp, bottom_y, 0])  # 0 means a middle box with no side offset
                    l_temp = l_temp - 16
                if l_temp - 16 >= 0 and l_temp > left_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, l_temp - 16, l_temp)
                    small_box_list.insert(0, [l_temp - 16, top_y, l_temp, bottom_y, (left_x - (l_temp - 16))])  # box at the left boundary: compute its side offset
                else:
                    # discard boundary remnants narrower than 16 pixels
                    small_box_list[0][-1] = left_x - (l_temp - 16)  # box at the left boundary: compute its side offset
                while r_temp + 16 <= right_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, r_temp, r_temp + 16)
                    small_box_list.append([r_temp, top_y, r_temp + 16, bottom_y, 0])
                    r_temp += 16
                if r_temp + 16 <= width - 1 and r_temp < right_x:
                    top_y, bottom_y = get_top_bottom(height, width, points, r_temp, r_temp + 16)
                    small_box_list.append([r_temp, top_y, r_temp + 16, bottom_y, (right_x - r_temp)])  # box at the right boundary: compute its side offset
                else:
                    # discard boundary remnants narrower than 16 pixels
                    small_box_list[-1][-1] = right_x - r_temp  # box at the right boundary: compute its side offset
                # print(small_box_list)
                total_box_list.extend(small_box_list)
    except Exception as e:
        print(e)
        print(anno_file)
        return
    name = os.path.basename(anno_file)
    with open(os.path.join(save_dir, name), "w", encoding="utf-8") as f:
        for box in total_box_list:
            box = [str(i) for i in box]
            f.write(",".join(box) + "\n")


def validate_clockwise_points(points):
    """
    Validates that the 4 points that delimit a polygon are in counter-clockwise order;
    raises if they are clockwise.
    """
    # The Shoelace Theorem computes the area of any polygon from its vertex coordinates:
    # the signed area is negative for clockwise ordering and positive for counter-clockwise ordering
    if len(points) != 8:
        raise Exception("Points list not valid." + str(len(points)))
    point = [
        [int(points[0]), int(points[1])],
        [int(points[2]), int(points[3])],
        [int(points[4]), int(points[5])],
        [int(points[6]), int(points[7])]
    ]
    edge = [
        (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
        (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
        (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
        (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])
    ]
    summatory = edge[0] + edge[1] + edge[2] + edge[3]
    if summatory < 0:
        raise Exception("Points are not counter_clockwise.")
        # alternatively, convert to counter-clockwise order:
        # print('points in wrong direction')
        # poly = np.array(points).reshape((4, 2))
        # poly = poly[(0, 3, 2, 1), :]


if __name__ == "__main__":
    img_dir = r"E:\data\image_9000"
    src_label_dir = r"E:\data\txt_9000"
    dst_label_dir = r"E:\data\txt_ctpn"
    for file in os.listdir(img_dir):
        if file.endswith(".jpg"):
            img_file = os.path.join(img_dir, file)
            name, _ = os.path.splitext(file)
            # anno_file = os.path.join(src_label_dir, file.replace(".jpg", ".txt"))
            anno_file = os.path.join(src_label_dir, name + ".txt")
            make_ctpn_data(img_file, anno_file, dst_label_dir)
4. Loss function
The loss in the original paper has three parts: the text/background classification loss cls_loss, the vertical loss vertical_loss on each box's center y-coordinate and height, and the side-refinement loss side_refinement_loss on the x-offsets of the boxes at the two sides. The classification loss is cross-entropy, and the box regression losses are smooth L1.
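For reference, the multi-task loss in the paper has the form

$$L(s_i, v_j, o_k) = \frac{1}{N_s}\sum_i L^{cls}(s_i, s_i^*) + \frac{\lambda_1}{N_v}\sum_j L^{re}_v(v_j, v_j^*) + \frac{\lambda_2}{N_o}\sum_k L^{re}_o(o_k, o_k^*)$$

where the starred quantities are the targets, $\lambda_1, \lambda_2$ are balancing weights, and $N_s, N_v, N_o$ are normalizing counts. The vertical coordinates are encoded relative to the matched anchor as $v_c = (c_y - c_y^a)/h^a$ and $v_h = \log(h/h^a)$, and the side refinement as $o = (x_{side} - c_x^a)/w^a$.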
In current CTPN implementations the box's center point, height, and width are regressed directly, so the loss has two parts, a classification loss and a box regression loss; the classification loss is cross-entropy and the box regression loss is smooth L1.
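A minimal sketch of this two-part loss, assuming predictions and targets have already been matched per sampled anchor (the -1 ignore label follows the sampling rule above; all names are illustrative):

import torch
import torch.nn.functional as F

def ctpn_loss(scores, locs, labels, loc_targets, lambda_reg=1.0):
    # scores: (N, 2) class logits; locs/loc_targets: (N, 4) encoded (cx, cy, w, h) offsets
    # labels: (N,) with 1 = text, 0 = background, -1 = ignored
    keep = labels >= 0
    cls_loss = F.cross_entropy(scores[keep], labels[keep])
    pos = labels == 1
    # regress only the positive anchors
    reg_loss = F.smooth_l1_loss(locs[pos], loc_targets[pos]) if pos.any() else locs.sum() * 0
    return cls_loss + lambda_reg * reg_loss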
5. Text-line construction algorithm
The text-line construction algorithm has two main parts. The first is text proposal connection: merging the boxes output by the network into one large box. The second is text-line refinement: correcting that box's top and bottom boundaries, and deriving the final rectangle from the corrected parallelogram.
Text proposal connection
See https://zhuanlan.zhihu.com/p/34757009 together with the code, which should make it clear. The steps from the paper, in brief: for each box_i, look up to 50 pixels to its right for boxes whose vertical overlap with box_i exceeds 0.7, and take the highest-scoring one as its successor box_j; symmetrically find the best predecessor of box_j, and if it is box_i, connect the pair; finally, chains of connected boxes form the text lines. A sketch follows:
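Below is a simplified sketch of this pairing rule (the thresholds are the ones from the paper; canonical implementations also check height similarity between neighbors, which is omitted here, and the function names are my own):

import numpy as np

MAX_GAP = 50     # max horizontal distance between connected proposals (pixels)
MIN_V_IOU = 0.7  # min vertical overlap between connected proposals

def v_overlap(a, b):
    # a, b are (x1, y1, x2, y2); vertical intersection over the smaller height
    inter = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    return inter / min(a[3] - a[1], b[3] - b[1])

def neighbours(i, boxes, right=True):
    # indices of boxes within 50 px to the right (or left) of box i with enough vertical overlap
    d = lambda j: boxes[j][0] - boxes[i][0] if right else boxes[i][0] - boxes[j][0]
    return [j for j in range(len(boxes))
            if 0 < d(j) <= MAX_GAP and v_overlap(boxes[i], boxes[j]) >= MIN_V_IOU]

def build_text_lines(boxes, scores):
    n = len(boxes)
    graph = np.zeros((n, n), dtype=bool)
    for i in range(n):
        succ = neighbours(i, boxes, right=True)
        if not succ:
            continue
        j = max(succ, key=lambda k: scores[k])            # best successor of box i
        if max(neighbours(j, boxes, right=False), key=lambda k: scores[k]) == i:
            graph[i, j] = True                            # i is also the best predecessor of j
    lines = []
    for i in range(n):
        if graph[:, i].any():
            continue                                      # i has a predecessor: not a chain head
        line = [i]
        while graph[line[-1]].any():
            line.append(int(np.where(graph[line[-1]])[0][0]))
        lines.append(line)
    return lines                                          # each chain of indices is one text line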
Text-line refinement
See https://zhuanlan.zhihu.com/p/137540923 together with the code, which should make it clear. In brief: fit straight lines through the top and bottom edges of the small boxes in a text line (e.g. by least squares), use the fitted boundaries to form a corrected parallelogram, and take its enclosing rectangle as the final result. A sketch follows:
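A rough sketch of the boundary-fitting idea, assuming each text line's small boxes have already been grouped by the connection step (this is a simplification: real implementations also apply the predicted side refinement to the line ends and clip to the image; the function name is my own):

import numpy as np

def refine_text_line(small_boxes):
    # small_boxes: (n, 4) array of (x1, y1, x2, y2) proposals already grouped into one
    # text line; assumes n >= 2 so the degree-1 fits are well defined
    b = np.asarray(small_boxes, dtype=np.float64)
    x_center = (b[:, 0] + b[:, 2]) * 0.5
    x0, x1 = b[:, 0].min(), b[:, 2].max()
    # least-squares lines through the small boxes' top and bottom edges
    k_top, c_top = np.polyfit(x_center, b[:, 1], 1)
    k_bot, c_bot = np.polyfit(x_center, b[:, 3], 1)
    # evaluate the fitted boundaries at both ends -> a corrected parallelogram,
    # then return its enclosing rectangle
    top = min(k_top * x0 + c_top, k_top * x1 + c_top)
    bottom = max(k_bot * x0 + c_bot, k_bot * x1 + c_bot)
    return [x0, top, x1, bottom]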
References:
https://zhuanlan.zhihu.com/p/34757009
https://zhuanlan.zhihu.com/p/137540923
https://www.cnblogs.com/skyfsm/p/9776611.html