現有的文本檢測方法主要有兩大類,一種是基於回歸框的檢測方法(基於物體檢測的方法),如CTPN,EAST,這類方法很難檢測任意形狀的文本(曲線文本), 一種是基於像素的分割檢測器(基於實例分割的方法),這類方法很難將彼此非常接近的文本實例分開。Psenet文本檢測方法是基於分割的方法,在2019年的論文Shape Robust Text Detection with Progressive Scale Expansion Network 中提出,優化了近距離文本實例的分離。
對於Psenet的學習,主要在於四方面:網絡結構的設計,kernel的生成,漸進尺度擴展算法(progressive scale expansion),loss函數
1. 網絡結構的設計
Psenet網絡采用了resnet+fpn的架構,通過resnet提取特征,取不同層的特征送入fpn進行特征融合,其結構如下圖所示:
上圖中給出了訓練過程中網絡數據流,總結如下:
1. 1*3*640*640的圖片輸入網絡,經過Resnet網絡,將layer1,layer2,layer3,layer4的特征圖p1(1*256*160*160), p2(1*512*80*80), p3(1*1024*40*40), p4(1*2048*20*20)送入fpn
2. 以此對應p1, p2, p3, p4, fpn網絡輸出特征c1(1*256*160*160), c2(1*256*80*80), c3(1*256*40*40), c4(1*256*20*20)
3. c2, c3, c4分別上采樣2,4,8倍后和c1進行concat得到特征1*1024*160*160,再經過兩個卷積輸出1*7*160*160,上采樣4倍得到網絡最終的輸出1*7*640*640。
4.網絡最后輸出了7個640*640的預測圖(map),分別表示預測的text_predict,和6個kernel_predict
另外,上述采用resnet50的典型結構如下:
2. kernel的產生
上面網絡結構中提到模型最后輸出7個640*640的預測圖, 分別是預測的text,和6個kernel,因此在訓練時也需要通過標注數據產生7個640*640的map供網絡學習,即text_gt和6個kernel_gt。其中text_gt就是一張二值圖,白色部分表示img中含有文字的區域,黑色部分表示背景區域,kernel_gt就是在text_gt的基礎上,將白色區域按一定的比例縮小。如下圖所示,根據r計算出d,表示該kernel的白色區域邊緣部分相對於text_gt的白色區域向內部移動了d個像素。
3. 漸進尺度擴展算法(progressive scale expansion)
在進行推理時,需要從網絡輸出的6個kernel中得到需要的box,作者采用了pse(progressive scale exoansion)算法。假設有kernel1,kernel2, kernel3, kernel4, kernel5, kernel6,先從文字區域最小的kernel6開始,遍歷其白色區域的像素點,采用廣度優先法向四周擴展,依次合並kernel2, kernel3, kernel4, kernel5, kernel6, 最后合並得到一個kernel,整個合並算法看代碼比較好理解。取合並后kernel白色區域的矩形框或輪廓線即得到文字檢測框。論文中示意圖如下:
參考python代碼如下:

import numpy as np import cv2 # import Queue from queue import Queue def pse(kernals, min_area): kernal_num = len(kernals) pred = np.zeros(kernals[0].shape, dtype='int32') label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4) for label_idx in range(1, label_num): if np.sum(label == label_idx) < min_area: label[label == label_idx] = 0 queue = Queue.Queue(maxsize = 0) next_queue = Queue.Queue(maxsize = 0) points = np.array(np.where(label > 0)).transpose((1, 0)) for point_idx in range(points.shape[0]): x, y = points[point_idx, 0], points[point_idx, 1] l = label[x, y] queue.put((x, y, l)) pred[x, y] = l dx = [-1, 1, 0, 0] dy = [0, 0, -1, 1] for kernal_idx in range(kernal_num - 2, -1, -1): kernal = kernals[kernal_idx].copy() while not queue.empty(): (x, y, l) = queue.get() is_edge = True for j in range(4): tmpx = x + dx[j] tmpy = y + dy[j] if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]: continue if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: continue queue.put((tmpx, tmpy, l)) pred[tmpx, tmpy] = l is_edge = False if is_edge: next_queue.put((x, y, l)) # kernal[pred > 0] = 0 queue, next_queue = next_queue, queue # points = np.array(np.where(pred > 0)).transpose((1, 0)) # for point_idx in range(points.shape[0]): # x, y = points[point_idx, 0], points[point_idx, 1] # l = pred[x, y] # queue.put((x, y, l)) return pred
4. loss函數理解
psenet的loss包括兩部分,gt_text和kernel的loss,都采用dice loss計算損失值。總的loss計算如公司如下,權重系數一般取λ=0.7
dice loss的計算公式如下,參見代碼比較好理解
dice loss 參考代碼:

def dice_loss(input, target, mask): #input為預測的map #target為標注的map input = torch.sigmoid(input) input = input.contiguous().view(input.size()[0], -1) target = target.contiguous().view(target.size()[0], -1) mask = mask.contiguous().view(mask.size()[0], -1) input = input * mask target = target * mask a = torch.sum(input * target, 1) b = torch.sum(input * input, 1) + 0.001 c = torch.sum(target * target, 1) + 0.001 d = (2 * a) / (b + c) dice_loss = torch.mean(d) return 1 - dice_loss
參考:
https://github.com/whai362/PSENet
https://github.com/WenmuZhou/PSENet.pytorch