這幾天一直在用Pytorch來復現文本檢測領域的CTPN論文,本文章將從數據處理、訓練標簽生成、神經網絡搭建、損失函數設計、訓練主過程編寫等這幾個方面來一步一步復現CTPN。CTPN算法理論可以參考這里。
訓練數據處理
我們的訓練選擇天池ICPR2018和MSRA_TD500兩個數據集,天池ICPR的數據集為網絡圖像,都是一些淘寶商家上傳到淘寶的一些商品介紹圖像,其標簽方式參考了ICDAR2015的數據標簽格式,即一個文本框用4個坐標來表示,即左上、右上、右下、左下四個坐標,共八個值,記作[x1 y1 x2 y2 x3 y3 x4 y4]
天池ICPR2018數據集的風格如下,字體形態格式顏色多變,多嵌套於物體之中,識別難度大:
MSRA_TD500使微軟收集的一個文本檢測和識別的一個數據集,里面的圖像多是街景圖,背景比較復雜,但文本位置比較明顯,一目了然。因為MSRA_TD500的標簽格式不一樣,最后一個參數表示矩形框的旋轉角度。
所以我們第一步就是將這兩個數據集的標簽格式統一,我的做法是將MSRA數據集格式改為ICDAR格式,方便后面的模型訓練。因為MSRA_TD500采取的標簽格式是[index difficulty_label x y w h angle],所以我們需要根據這個文本框的旋轉角度來求得水平文本框旋轉后的4個坐標位置。實現如下:
"""
This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.
MSRA_TD500 format: [index difficulty_label x y w h angle]
ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]
"""
import math
import cv2
import os
# 求旋轉后矩形的4個坐標
def get_box_img(x, y, w, h, angle):
# 矩形框中點(x0,y0)
x0 = x + w/2
y0 = y + h/2
l = math.sqrt(pow(w/2, 2) + pow(h/2, 2)) # 即對角線的一半
# angle小於0,逆時針轉
if angle < 0:
a1 = -angle + math.atan(h / float(w)) # 旋轉角度-對角線與底線所成的角度
a2 = -angle - math.atan(h / float(w)) # 旋轉角度+對角線與底線所成的角度
pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))
pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))
pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2)) # x0+左下點旋轉后在水平線上的投影, y0-左下點在垂直線上的投影,顯然逆時針轉時,左下點上一和左移了。
pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))
else:
a1 = angle + math.atan(h / float(w))
a2 = angle - math.atan(h / float(w))
pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))
pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))
pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))
pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))
return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]]
def read_file(path):
result = []
for line in open(path):
info = []
data = line.split(' ')
info.append(int(data[2]))
info.append(int(data[3]))
info.append(int(data[4]))
info.append(int(data[5]))
info.append(float(data[6]))
info.append(data[0])
result.append(info)
return result
if __name__ == '__main__':
file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'
save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'
save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'
file_list = os.listdir(file_path)
for f in file_list:
if '.gt' in f:
continue
name = f[0:8]
txt_path = file_path + name + '.gt'
im_path = file_path + f
im = cv2.imread(im_path)
coordinate = read_file(txt_path)
# 仿照ICDAR格式,圖片名字寫做img_xx.jpg,對應的標簽文件寫做gt_img_xx.txt
cv2.imwrite(save_img_path + name.lower() + '.jpg', im)
save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w')
for i in coordinate:
box = get_box_img(i[0], i[1], i[2], i[3], i[4])
box = [int(box[i]) for i in range(len(box))]
box = [str(box[i]) for i in range(len(box))]
save_gt.write(','.join(box))
save_gt.write('\n')
經過格式處理后,我們兩份數據集算是整理好了。當然我們還需要對整個數據集划分為訓練集和測試集,我的文件組織習慣如下:train_im, test_im文件夾裝的是訓練和測試圖像,train_gt和test_gt裝的是訓練和測試標簽。
訓練標簽生成
因為CTPN的核心思想也是基於Faster RCNN中的region proposal機制的,所以原始數據標簽需要轉化為
anchor標簽。訓練數據的標簽的生成的代碼是最難寫,因為從一個完整的文本框標簽轉化為一個個小尺度文本框標簽確實有點難度,而且這個anchor標簽的生成方式也與Faster RCNN生成方式略有不同。下面講一講我的實現思路:
第一步我們需要將原先每張圖的bbox標簽轉化為每個anchor標簽。為了實現該功能,我們先將一張圖划分為寬度為16的各個anchor。
- 首先計算一張圖可以分為多少個寬度為16的acnhor(比如一張圖的寬度為w,那么水平anchor總數為w/16),再計算出我們的文本框標簽中含有幾個acnhor,最左和最右的anchor又是哪幾個;
- 計算文本框內anchor的高度和中心是多少:此時我們可以在一個全黑的mask中把文本框label畫上去(白色),然后從上往下和從下往上找到第一個白色像素點的位置作為該anchor的上下邊界;
- 最后將每個anchor的位置(水平ID)、anchor中心y坐標、anchor高度存儲並返回
def generate_gt_anchor(img, box, anchor_width=16):
"""
calsulate ground truth fine-scale box
:param img: input image
:param box: ground truth box (4 point)
:param anchor_width:
:return: tuple (position, h, cy)
"""
if not isinstance(box[0], float):
box = [float(box[i]) for i in range(len(box))]
result = []
# 求解一個bbox下,能分解為多少個16寬度的小anchor,並求出最左和最右的小achor的id
left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width)) # the left side anchor of the text box, downwards
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width)) # the right side anchor of the text box, upwards
# handle extreme case, the right side anchor may exceed the image width
if right_anchor_num * 16 + 15 > img.shape[1]:
right_anchor_num -= 1
# combine the left-side and the right-side x_coordinate of a text anchor into one pair
position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)]
# 計算每個gt anchor的真實位置,其實就是求解gt anchor的上邊界和下邊界
y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)
# 最后將每個anchor的位置(水平ID)、anchor中心y坐標、anchor高度存儲並返回
for i in range(len(position_pair)):
position = int(position_pair[i][0] / anchor_width) # the index of anchor box
h = y_bottom[i] - y_top[i] + 1 # the height of anchor box
cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0 # the center point of anchor box
result.append((position, cy, h))
return result
計算anchor上下邊界的方法:
# cal the gt anchor box's bottom and top coordinate
def cal_y_top_and_bottom(raw_img, position_pair, box):
"""
:param raw_img:
:param position_pair: for example:[(0, 15), (16, 31), ...]
:param box: gt box (4 point)
:return: top and bottom coordinates for y-axis
"""
img = copy.deepcopy(raw_img)
y_top = []
y_bottom = []
height = img.shape[0]
# 設置圖像mask,channel 0為全黑圖
for i in range(img.shape[0]):
for j in range(img.shape[1]):
img[i, j, 0] = 0
top_flag = False
bottom_flag = False
# 根據bbox四點畫出文本框,channel 0下文本框為白色
img = other.draw_box_4pt(img, box, color=(255, 0, 0))
for k in range(len(position_pair)):
# 從左到右遍歷anchor gt,對每個anchor從上往下掃描像素,遇到白色像素點(255)就停下來,此時像素點坐標y就是該anchor gt的上邊界
# calc top y coordinate
for y in range(0, height-1):
# loop each anchor, from left to right
for x in range(position_pair[k][0], position_pair[k][1] + 1):
if img[y, x, 0] == 255:
y_top.append(y)
top_flag = True
break
if top_flag is True:
break
# 從左到右遍歷anchor gt,對每個anchor從下往上掃描像素,遇到白色像素點(255)就停下來,此時像素點坐標y就是該anchor gt的下邊界
# calc bottom y coordinate, pixel from down to top loop
for y in range(height - 1, -1, -1):
# loop each anchor, from left to right
for x in range(position_pair[k][0], position_pair[k][1] + 1):
if img[y, x, 0] == 255:
y_bottom.append(y)
bottom_flag = True
break
if bottom_flag is True:
break
top_flag = False
bottom_flag = False
return y_top, y_bottom
經過上面的標簽處理,我們已經將原先的標准的文本框標簽轉化為一個一個小尺度anchor標簽,以下是標簽轉化后的效果:
以上標簽可視化后看來anchor標簽做得不錯,但是這里需要提出的是,我發現這種anchor生成方法是不太精准的,比如一個文本框邊緣像素剛好落在一個新的anchor上,那么我們就要為這個像素分配一個16像素的anchor,顯然導致了文本框標簽的不准確,引入了15像素的誤差,這個是需要思考的。這個問題我們先不做處理,繼續下面的工作。
當然轉化期間我們也遇到很多奇怪的問題,比如下圖這種標簽都已經超出圖像范圍的,我們必須做相應的特殊處理,比如限定標簽橫坐標的最大尺寸為圖像寬度。
left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width)) # the left side anchor of the text box, downwards
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width)) # the right side anchor of the text box, upwards
CTPN網絡結構
因為CTPN用到了CNN+雙向LSTM的網絡結構,所以我們分步實現CTPN架構。
CNN部分CTPN采取了VGG16進行底層特征提取。
class VGG_16(nn.Module):
"""
VGG-16 without pooling layer before fc layer
"""
def __init__(self):
super(VGG_16, self).__init__()
self.convolution1_1 = nn.Conv2d(3, 64, 3, padding=1)
self.convolution1_2 = nn.Conv2d(64, 64, 3, padding=1)
self.pooling1 = nn.MaxPool2d(2, stride=2)
self.convolution2_1 = nn.Conv2d(64, 128, 3, padding=1)
self.convolution2_2 = nn.Conv2d(128, 128, 3, padding=1)
self.pooling2 = nn.MaxPool2d(2, stride=2)
self.convolution3_1 = nn.Conv2d(128, 256, 3, padding=1)
self.convolution3_2 = nn.Conv2d(256, 256, 3, padding=1)
self.convolution3_3 = nn.Conv2d(256, 256, 3, padding=1)
self.pooling3 = nn.MaxPool2d(2, stride=2)
self.convolution4_1 = nn.Conv2d(256, 512, 3, padding=1)
self.convolution4_2 = nn.Conv2d(512, 512, 3, padding=1)
self.convolution4_3 = nn.Conv2d(512, 512, 3, padding=1)
self.pooling4 = nn.MaxPool2d(2, stride=2)
self.convolution5_1 = nn.Conv2d(512, 512, 3, padding=1)
self.convolution5_2 = nn.Conv2d(512, 512, 3, padding=1)
self.convolution5_3 = nn.Conv2d(512, 512, 3, padding=1)
def forward(self, x):
x = F.relu(self.convolution1_1(x), inplace=True)
x = F.relu(self.convolution1_2(x), inplace=True)
x = self.pooling1(x)
x = F.relu(self.convolution2_1(x), inplace=True)
x = F.relu(self.convolution2_2(x), inplace=True)
x = self.pooling2(x)
x = F.relu(self.convolution3_1(x), inplace=True)
x = F.relu(self.convolution3_2(x), inplace=True)
x = F.relu(self.convolution3_3(x), inplace=True)
x = self.pooling3(x)
x = F.relu(self.convolution4_1(x), inplace=True)
x = F.relu(self.convolution4_2(x), inplace=True)
x = F.relu(self.convolution4_3(x), inplace=True)
x = self.pooling4(x)
x = F.relu(self.convolution5_1(x), inplace=True)
x = F.relu(self.convolution5_2(x), inplace=True)
x = F.relu(self.convolution5_3(x), inplace=True)
return x
再實現雙向LSTM,增強關聯序列的信息學習。
class BLSTM(nn.Module):
def __init__(self, channel, hidden_unit, bidirectional=True):
"""
:param channel: lstm input channel num
:param hidden_unit: lstm hidden unit
:param bidirectional:
"""
super(BLSTM, self).__init__()
self.lstm = nn.LSTM(channel, hidden_unit, bidirectional=bidirectional)
def forward(self, x):
"""
WARNING: The batch size of x must be 1.
"""
x = x.transpose(1, 3)
recurrent, _ = self.lstm(x[0])
recurrent = recurrent[np.newaxis, :, :, :]
recurrent = recurrent.transpose(1, 3)
return recurrent
這里實現多一層中間層,用於連接CNN和LSTM。將VGG最后一層卷積層輸出的feature map轉化為向量形式,用於接下來的LSTM訓練。
class Im2col(nn.Module):
def __init__(self, kernel_size, stride, padding):
super(Im2col, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
def forward(self, x):
height = x.shape[2]
x = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
x = x.reshape((x.shape[0], x.shape[1], height, -1))
return x
最后將以上三部分拼接成一個完整的CTPN網絡:底層使用VGG16做特征提取->lstm序列信息學習->output每個anchor分數,h, y, side_refinement
class CTPN(nn.Module):
def __init__(self):
super(CTPN, self).__init__()
self.cnn = nn.Sequential()
self.cnn.add_module('VGG_16', VGG_16())
self.rnn = nn.Sequential()
self.rnn.add_module('im2col', Net.Im2col((3, 3), (1, 1), (1, 1)))
self.rnn.add_module('blstm', BLSTM(3 * 3 * 512, 128))
self.FC = nn.Conv2d(256, 512, 1)
self.vertical_coordinate = nn.Conv2d(512, 2 * 10, 1) # 最終輸出2K個參數(k=10),10表示anchor的尺寸個數,2個參數分別表示anchor的h和dy
self.score = nn.Conv2d(512, 2 * 10, 1) # 最終輸出是2K個分數(k=10),2表示有無字符,10表示anchor的尺寸個數
self.side_refinement = nn.Conv2d(512, 10, 1) # 最終輸出1K個參數(k=10),該參數表示該anchor的水平偏移,用於精修文本框水平邊緣精度,,10表示anchor的尺寸個數
def forward(self, x, val=False):
x = self.cnn(x)
x = self.rnn(x)
x = self.FC(x)
x = F.relu(x, inplace=True)
vertical_pred = self.vertical_coordinate(x)
score = self.score(x)
if val:
score = score.reshape((score.shape[0], 10, 2, score.shape[2], score.shape[3]))
score = score.squeeze(0)
score = score.transpose(1, 2)
score = score.transpose(2, 3)
score = score.reshape((-1, 2))
#score = F.softmax(score, dim=1)
score = score.reshape((10, vertical_pred.shape[2], -1, 2))
vertical_pred = vertical_pred.reshape((vertical_pred.shape[0], 10, 2, vertical_pred.shape[2], vertical_pred.shape[3]))
side_refinement = self.side_refinement(x)
return vertical_pred, score, side_refinement
損失函數設計
CTPN的LOSS分為三部分:
- h,y的regression loss,用的是SmoothL1Loss;
- score的classification loss,用的是CrossEntropyLoss;
- side refinement loss,用的是用的是SmoothL1Loss。
先定義好一些固定參數
class CTPN_Loss(nn.Module):
def __init__(self, using_cuda=False):
super(CTPN_Loss, self).__init__()
self.Ns = 128
self.ratio = 0.5
self.lambda1 = 1.0
self.lambda2 = 1.0
self.Ls_cls = nn.CrossEntropyLoss()
self.Lv_reg = nn.SmoothL1Loss()
self.Lo_reg = nn.SmoothL1Loss()
self.using_cuda = using_cuda
首先設計classification loss
cls_loss = 0.0
if self.using_cuda:
for p in positive_batch:
cls_loss += self.Ls_cls(score[0, p[2] * 2: ((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
torch.LongTensor([1]).cuda())
for n in negative_batch:
cls_loss += self.Ls_cls(score[0, n[2] * 2: ((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
torch.LongTensor([0]).cuda())
else:
for p in positive_batch:
cls_loss += self.Ls_cls(score[0, p[2] * 2: ((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
torch.LongTensor([1]))
for n in negative_batch:
cls_loss += self.Ls_cls(score[0, n[2] * 2: ((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
torch.LongTensor([0]))
cls_loss = cls_loss / self.Ns
然后是vertical coordinate regression loss,反映的是y和h的偏差
# calculate vertical coordinate regression loss
v_reg_loss = 0.0
Nv = len(vertical_reg)
if self.using_cuda:
for v in vertical_reg:
v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2: ((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
torch.FloatTensor([v[3], v[4]]).unsqueeze(0).cuda())
else:
for v in vertical_reg:
v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2: ((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
torch.FloatTensor([v[3], v[4]]).unsqueeze(0))
v_reg_loss = v_reg_loss / float(Nv)
最后計算side refinement regression loss,用於修正邊緣精度
# calculate side refinement regression loss
o_reg_loss = 0.0
No = len(side_refinement_reg)
if self.using_cuda:
for s in side_refinement_reg:
o_reg_loss += self.Lo_reg(side_refinement[0, s[2]: s[2] + 1, s[1], s[0]].unsqueeze(0),
torch.FloatTensor([s[3]]).unsqueeze(0).cuda())
else:
for s in side_refinement_reg:
o_reg_loss += self.Lo_reg(side_refinement[0, s[2]: s[2] + 1, s[1], s[0]].unsqueeze(0),
torch.FloatTensor([s[3]]).unsqueeze(0))
o_reg_loss = o_reg_loss / float(No)
當然最后還有個total loss,匯總整個訓練過程中的loss
loss = cls_loss + v_reg_loss * self.lambda1 + o_reg_loss * self.lambda2
訓練過程設計
訓練:優化器我們選擇SGD,learning rate我們設置了兩個,前N個epoch使用較大的lr,后面的epoch使用較小的lr以更好地收斂。訓練過程我們定義了4個loss,分別是total_cls_loss,total_v_reg_loss, total_o_reg_loss, total_loss(前面三個loss相加)。
net = Net.CTPN() # 獲取網絡結構
for name, value in net.named_parameters():
if name in no_grad:
value.requires_grad = False
else:
value.requires_grad = True
# for name, value in net.named_parameters():
# print('name: {0}, grad: {1}'.format(name, value.requires_grad))
net.load_state_dict(torch.load('./lib/vgg16.model'))
# net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
lib.utils.init_weight(net)
if using_cuda:
net.cuda()
net.train()
print(net)
criterion = Loss.CTPN_Loss(using_cuda=using_cuda) # 獲取loss
train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val() # 獲取訓練、測試數據
total_iter = len(train_im_list)
print("total training image num is %s" % len(train_im_list))
print("total val image num is %s" % len(val_im_list))
train_loss_list = []
test_loss_list = []
# 開始迭代訓練
for i in range(epoch):
if i >= change_epoch:
lr = lr_behind
else:
lr = lr_front
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
#optimizer = optim.Adam(net.parameters(), lr=lr)
iteration = 1
total_loss = 0
total_cls_loss = 0
total_v_reg_loss = 0
total_o_reg_loss = 0
start_time = time.time()
random.shuffle(train_im_list) # 打亂訓練集
# print(random_im_list)
for im in train_im_list:
root, file_name = os.path.split(im)
root, _ = os.path.split(root)
name, _ = os.path.splitext(file_name)
gt_name = 'gt_' + name + '.txt'
gt_path = os.path.join(root, "train_gt", gt_name)
if not os.path.exists(gt_path):
print('Ground truth file of image {0} not exists.'.format(im))
continue
gt_txt = lib.dataset_handler.read_gt_file(gt_path) # 讀取對應的標簽
#print("processing image %s" % os.path.join(img_root1, im))
img = cv2.imread(im)
if img is None:
iteration += 1
continue
img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt) # 圖像和標簽做歸一化
tensor_img = img[np.newaxis, :, :, :]
tensor_img = tensor_img.transpose((0, 3, 1, 2))
if using_cuda:
tensor_img = torch.FloatTensor(tensor_img).cuda()
else:
tensor_img = torch.FloatTensor(tensor_img)
vertical_pred, score, side_refinement = net(tensor_img) # 正向計算,獲取預測結果
del tensor_img
# transform bbox gt to anchor gt for training
positive = []
negative = []
vertical_reg = []
side_refinement_reg = []
visual_img = copy.deepcopy(img) # 該圖用於可視化標簽
try:
# loop all bbox in one image
for box in gt_txt:
# generate anchors from one bbox
gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img) # 獲取圖像的anchor標簽
positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box) # 計算預測值反映在anchor層面的數據
positive += positive1
negative += negative1
vertical_reg += vertical_reg1
side_refinement_reg += side_refinement_reg1
except:
print("warning: img %s raise error!" % im)
iteration += 1
continue
if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
iteration += 1
continue
cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img)
optimizer.zero_grad()
# 計算誤差
loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
negative, vertical_reg, side_refinement_reg)
# 反向傳播
loss.backward()
optimizer.step()
iteration += 1
# save gpu memory by transferring loss to float
total_loss += float(loss)
total_cls_loss += float(cls_loss)
total_v_reg_loss += float(v_reg_loss)
total_o_reg_loss += float(o_reg_loss)
if iteration % display_iter == 0:
end_time = time.time()
total_time = end_time - start_time
print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.
format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter,
total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im))
logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch))
logger.info('loss: {0}'.format(total_loss / display_iter))
logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))
train_loss_list.append(total_loss)
total_loss = 0
total_cls_loss = 0
total_v_reg_loss = 0
total_o_reg_loss = 0
start_time = time.time()
# 定期驗證模型性能
if iteration % val_iter == 0:
net.eval()
logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list)
logger.info('End evaluate.')
net.train()
start_time = time.time()
test_loss_list.append(val_loss)
# 定期存儲模型
if iteration % save_iter == 0:
print('Model saved at ./model/ctpn-{0}-{1}.model'.format(i, iteration))
torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))
print('Model saved at ./model/ctpn-{0}-end.model'.format(i))
torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i))
# 畫出loss的變化圖
draw_loss_plot(train_loss_list, test_loss_list)
縮放圖像具有一定規則:首先要保證文本框label的最短邊也要等於600。我們通過scale = float(shortest_side)/float(min(height, width))
來求得圖像的縮放系數,對原始圖像進行縮放。同時我們也要對我們的label也要根據該縮放系數進行縮放。
def scale_img(img, gt, shortest_side=600):
height = img.shape[0]
width = img.shape[1]
scale = float(shortest_side)/float(min(height, width))
img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
if img.shape[0] < img.shape[1] and img.shape[0] != 600:
img = cv2.resize(img, (600, img.shape[1]))
elif img.shape[0] > img.shape[1] and img.shape[1] != 600:
img = cv2.resize(img, (img.shape[0], 600))
elif img.shape[0] != 600:
img = cv2.resize(img, (600, 600))
h_scale = float(img.shape[0])/float(height)
w_scale = float(img.shape[1])/float(width)
scale_gt = []
for box in gt:
scale_box = []
for i in range(len(box)):
# x坐標
if i % 2 == 0:
scale_box.append(int(int(box[i]) * w_scale))
# y坐標
else:
scale_box.append(int(int(box[i]) * h_scale))
scale_gt.append(scale_box)
return img, scale_gt
驗證集評估:
def val(net, criterion, batch_num, using_cuda, logger):
img_root = '../dataset/OCR_dataset/ctpn/test_im'
gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
img_list = os.listdir(img_root)
total_loss = 0
total_cls_loss = 0
total_v_reg_loss = 0
total_o_reg_loss = 0
start_time = time.time()
for im in random.sample(img_list, batch_num):
name, _ = os.path.splitext(im)
gt_name = 'gt_' + name + '.txt'
gt_path = os.path.join(gt_root, gt_name)
if not os.path.exists(gt_path):
print('Ground truth file of image {0} not exists.'.format(im))
continue
gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
img = cv2.imread(os.path.join(img_root, im))
img, gt_txt = Dataset.scale_img(img, gt_txt)
tensor_img = img[np.newaxis, :, :, :]
tensor_img = tensor_img.transpose((0, 3, 1, 2))
if using_cuda:
tensor_img = torch.FloatTensor(tensor_img).cuda()
else:
tensor_img = torch.FloatTensor(tensor_img)
vertical_pred, score, side_refinement = net(tensor_img)
del tensor_img
positive = []
negative = []
vertical_reg = []
side_refinement_reg = []
for box in gt_txt:
gt_anchor = Dataset.generate_gt_anchor(img, box)
positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
positive += positive1
negative += negative1
vertical_reg += vertical_reg1
side_refinement_reg += side_refinement_reg1
if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
batch_num -= 1
continue
loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
negative, vertical_reg, side_refinement_reg)
total_loss += loss
total_cls_loss += cls_loss
total_v_reg_loss += v_reg_loss
total_o_reg_loss += o_reg_loss
end_time = time.time()
total_time = end_time - start_time
print('#################### Start evaluate ####################')
print('loss: {0}'.format(total_loss / float(batch_num)))
logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))
print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
print('##################### Evaluate end #####################')
print('\n')
訓練過程:
訓練效果與預測效果
測試效果:輸入一張圖片,給出最后的檢測結果
def infer_one(im_name, net):
im = cv2.imread(im_name)
im = lib.dataset_handler.scale_img_only(im) # 歸一化圖像
img = copy.deepcopy(im)
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :]
img = torch.Tensor(img)
v, score, side = net(img, val=True) # 送入網絡預測
result = []
# 根據分數獲取有文字的anchor
for i in range(score.shape[0]):
for j in range(score.shape[1]):
for k in range(score.shape[2]):
if score[i, j, k, 1] > THRESH_HOLD:
result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))
# nms過濾
for_nms = []
for box in result:
pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]])
for_nms = np.array(for_nms, dtype=np.float32)
nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH)
out_nms = []
for i in nms_result:
out_nms.append(for_nms[i, 0:8])
# 確定哪幾個anchors是屬於一組的
connect = get_successions(v, out_nms)
# 將一組anchors合並成一條文本線
texts = get_text_lines(connect, im.shape)
for box in texts:
box = np.array(box)
print(box)
lib.draw_image.draw_ploy_4pt(im, box[0:8])
_, basename = os.path.split(im_name)
cv2.imwrite('./infer_'+basename, im)
推斷時提到了get_successions
用於獲取一個預測文本行里的所有anchors,換句話說,我們得到的很多預測有字符的anchor,但是我們怎么知道哪些acnhors可以組成一個文本線呢?所以我們需要實現一個anchor合並算法,這也是CTPN代碼實現中最為困難的一步。
CTPN論文提到,文本線構造法如下:文本行構建很簡單,通過將那些text/no-text score > 0.7的連續的text proposals相連接即可。文本行的構建如下。
- 首先,為一個proposal Bi定義一個鄰居(Bj):Bj−>Bi,其中:
- Bj在水平距離上離Bi最近
- 該距離小於50 pixels
- 它們的垂直重疊(vertical overlap) > 0.7
一看理論很簡單,但是一到自己實現就困難重重了。真是應了那句“紙上得來終覺淺,絕知此事要躬行”啊!get_successions
傳入的參數是v代表每個預測anchor的h和y信息,anchors代表每個anchors的四個頂點坐標信息。
def get_successions(v, anchors=[]):
texts = []
for i, anchor in enumerate(anchors):
neighbours = [] # 記錄每組的anchors
neighbours.append(i)
center_x1 = (anchor[2] + anchor[0]) / 2
h1 = get_anchor_h(anchor, v) # 獲取該anchor的高度
# find i's neighbour
# 遍歷余下的anchors,找出鄰居
for j in range(i + 1, len(anchors)):
center_x2 = (anchors[j][2] + anchors[j][0]) / 2 # 中心點X坐標
h2 = get_anchor_h(anchors[j], v)
# 如果這兩個Anchor間的距離小於50,而且他們的它們的垂直重疊(vertical overlap)大於一定閾值,那就是鄰居
if abs(center_x1 - center_x2) < NEIGHBOURS_MIN_DIST and \
meet_v_iou(max(anchor[1], anchors[j][1]), min(anchor[3], anchors[j][3]), h1, h2): # less than 50 pixel between each anchor
neighbours.append(j)
if len(neighbours) != 0:
texts.append(neighbours)
# 通過上面的步驟,我們已經把每一個anchor的鄰居都找到並加入了對應的集合中了,現在我們
# 通過一個循環來不斷將每個小組合並
need_merge = True
while need_merge:
need_merge = False
# ok, we combine again.
for i, line in enumerate(texts):
if len(line) == 0:
continue
for index in line:
for j in range(i+1, len(texts)):
if index in texts[j]:
texts[i] += texts[j]
texts[i] = list(set(texts[i]))
texts[j] = []
need_merge = True
result = []
#print(texts)
for text in texts:
if len(text) < 2:
continue
local = []
for j in text:
local.append(anchors[j])
result.append(local)
return result
當我們得到一個文本框的anchors組合后,接下來要做的就是將組內的anchors串聯成一個文本框。get_text_lines
函數做的就是這個功能。
def get_text_lines(text_proposals, im_size, scores=0):
"""
text_proposals:boxes
"""
text_lines = np.zeros((len(text_proposals), 8), np.float32)
for index, tp_indices in enumerate(text_proposals):
text_line_boxes = np.array(tp_indices) # 每個文本行的全部小框
#print(text_line_boxes)
#print(type(text_line_boxes))
#print(text_line_boxes.shape)
X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2 # 求每一個小框的中心x,y坐標
Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2
#print(X)
#print(Y)
z1 = np.polyfit(X, Y, 1) # 多項式擬合,根據之前求的中心店擬合一條直線(最小二乘)
x0 = np.min(text_line_boxes[:, 0]) # 文本行x坐標最小值
x1 = np.max(text_line_boxes[:, 2]) # 文本行x坐標最大值
offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5 # 小框寬度的一半
# 以全部小框的左上角這個點去擬合一條直線,然后計算一下文本行x坐標的極左極右對應的y坐標
lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
# 以全部小框的左下角這個點去擬合一條直線,然后計算一下文本行x坐標的極左極右對應的y坐標
lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)
#score = scores[list(tp_indices)].sum() / float(len(tp_indices)) # 求全部小框得分的均值作為文本行的均值
text_lines[index, 0] = x0
text_lines[index, 1] = min(lt_y, rt_y) # 文本行上端 線段 的y坐標的小值
text_lines[index, 2] = x1
text_lines[index, 3] = max(lb_y, rb_y) # 文本行下端 線段 的y坐標的大值
text_lines[index, 4] = scores # 文本行得分
text_lines[index, 5] = z1[0] # 根據中心點擬合的直線的k,b
text_lines[index, 6] = z1[1]
height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1])) # 小框平均高度
text_lines[index, 7] = height + 2.5
text_recs = np.zeros((len(text_lines), 9), np.float32)
index = 0
for line in text_lines:
b1 = line[6] - line[7] / 2 # 根據高度和文本行中心線,求取文本行上下兩條線的b值
b2 = line[6] + line[7] / 2
x1 = line[0]
y1 = line[5] * line[0] + b1 # 左上
x2 = line[2]
y2 = line[5] * line[2] + b1 # 右上
x3 = line[0]
y3 = line[5] * line[0] + b2 # 左下
x4 = line[2]
y4 = line[5] * line[2] + b2 # 右下
disX = x2 - x1
disY = y2 - y1
width = np.sqrt(disX * disX + disY * disY) # 文本行寬度
fTmp0 = y3 - y1 # 文本行高度
fTmp1 = fTmp0 * disY / width
x = np.fabs(fTmp1 * disX / width) # 做補償
y = np.fabs(fTmp1 * disY / width)
if line[5] < 0:
x1 -= x
y1 += y
x4 += x
y4 -= y
else:
x2 += x
y2 += y
x3 -= x
y3 -= y
# clock-wise order
text_recs[index, 0] = x1
text_recs[index, 1] = y1
text_recs[index, 2] = x2
text_recs[index, 3] = y2
text_recs[index, 4] = x4
text_recs[index, 5] = y4
text_recs[index, 6] = x3
text_recs[index, 7] = y3
text_recs[index, 8] = line[4]
index = index + 1
text_recs = clip_boxes(text_recs, im_size)
return text_recs
檢測效果和總結
首先看一下訓練出來的模型的文字檢測效果,為了便於觀察,我把anchor和最終合並好的文本框一並畫出:
下面再看看一些比較好的文字檢測效果吧:
在實現過程中的一些總結和想法:
- CTPN對於帶旋轉角度的文本的檢測效果不好,其實這是CTPN的算法特點決定的:一個個固定寬度的四邊形是很難合並出一個准確的文本框,比如一些anchors很難組成一組,即使組成一組了也很難精確恢復成完整的精確的文本矩形框(推斷階段的缺點)。當然啦,對於水平排布的文本檢測,個人認為這個算法思路還是很奏效的。
- CTPN中的side-refinement其實作用不大,如果我們檢測出來的文本是直接拿出識別,這個side-refinement優化的幾個像素差別其實可以忽略;
- CTPN的中間步驟有點多:從anchor標簽的生成到中間計算loss再到最后推斷的文本線生成步驟,都會引入一定的誤差,這個缺點也是EAST論文中所提出的。訓練的步驟越簡潔,中間過程越少,精度更有保障。
- CTPN的算法得出的效果可以看出,准確率低但召回率高。這種基於16像素的anchor識別感覺對於一些大的非文字圖標(比如路標)誤判率相當高,這是源於其anchor的寬度實在太小了,盡管使用了lstm關聯周圍anchor,但是我還是認為有點“一葉障目”的感覺。所以CTPN對於過大或過小的文字檢測效果不會太好。
- CTPN是個比較老的算法了(2016年),其思路在當年還是很創新的,但是也有很多弊端。現在提出的新方法已經基本解決了這些不足之處,比如EAST,PixelNet都是一些很優秀的新算法。
CTPN的完整實現可以參考我的Github