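Training data processing:
Two datasets are used, Tianchi ICPR2018 and MSRA_TD500:
1) The Tianchi ICPR dataset consists of web images, mostly product-description pictures that Taobao sellers uploaded to Taobao. Its annotations follow the ICDAR2015 label format: each text box is described by four corner points in the order top-left, top-right, bottom-right, bottom-left, i.e. eight values written as [x1 y1 x2 y2 x3 y3 x4 y4].
2) MSRA_TD500 is a text detection and recognition dataset collected by Microsoft. Most of its images are street scenes with fairly complex backgrounds, but the text regions are prominent and easy to spot.
MSRA_TD500 uses a different label format, in which the last parameter is the rotation angle of the rectangle.
So the first step is to unify the label formats of the two datasets. I converted the MSRA dataset to the ICDAR format, which makes the later model training easier.
MSRA_TD500 labels have the form [index difficulty_label x y w h angle], so we need to use the rotation angle to compute the four corner coordinates of the horizontal box after rotation. The implementation is as follows: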
"""
This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.

MSRA_TD500 format: [index difficulty_label x y w h angle]

ICDAR2015 format: [left_top_x left_top_y right_top_x right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]
"""
import math
import os

import cv2


# Compute the four corner coordinates of the rotated rectangle
def get_box_img(x, y, w, h, angle):
    # Center of the rectangle (x0, y0)
    x0 = x + w / 2.0
    y0 = y + h / 2.0
    l = math.sqrt(pow(w / 2.0, 2) + pow(h / 2.0, 2))  # half of the diagonal
    # angle < 0: counter-clockwise rotation
    if angle < 0:
        a1 = -angle + math.atan(h / float(w))  # rotation angle + angle between the diagonal and the base
        a2 = -angle - math.atan(h / float(w))  # rotation angle - angle between the diagonal and the base
        pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))
        pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))
        # x0 plus the horizontal projection of the rotated half-diagonal,
        # y0 minus its vertical projection
        pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2))
        pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))
    else:
        a1 = angle + math.atan(h / float(w))
        a2 = angle - math.atan(h / float(w))
        pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))
        pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))
        pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))
        pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))
    return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]]


def read_file(path):
    # Each returned entry is [x, y, w, h, angle, index]
    result = []
    with open(path) as gt_file:
        for line in gt_file:
            info = []
            data = line.split(' ')
            info.append(int(data[2]))
            info.append(int(data[3]))
            info.append(int(data[4]))
            info.append(int(data[5]))
            info.append(float(data[6]))
            info.append(data[0])
            result.append(info)
    return result


if __name__ == '__main__':
    file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'
    save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'
    save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'
    file_list = os.listdir(file_path)
    for f in file_list:
        if '.gt' in f:
            continue
        name = f[0:8]
        txt_path = file_path + name + '.gt'
        im_path = file_path + f
        im = cv2.imread(im_path)
        coordinate = read_file(txt_path)
        # Following the ICDAR convention, the image is saved as img_xx.jpg
        # and its label file as gt_img_xx.txt
        cv2.imwrite(save_img_path + name.lower() + '.jpg', im)
        with open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w') as save_gt:
            for i in coordinate:
                box = get_box_img(i[0], i[1], i[2], i[3], i[4])
                box = [str(int(v)) for v in box]
                save_gt.write(','.join(box))
                save_gt.write('\n')
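The corner computation above can also be written as a plain 2D rotation about the box center. The following is a minimal sketch rather than the author's code: the name `rotate_box` is hypothetical, and it assumes the angle is given in radians and that image coordinates have y pointing down. Under those assumptions it should agree with `get_box_img` up to floating-point error for both positive and negative angles.

```python
import math


def rotate_box(x, y, w, h, angle):
    """Rotate the axis-aligned box (x, y, w, h) about its center by `angle` radians.

    Returns [x1, y1, x2, y2, x3, y3, x4, y4] in the order top-left, top-right,
    bottom-right, bottom-left. Hypothetical helper for illustration only.
    """
    cx, cy = x + w / 2.0, y + h / 2.0
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    # Corner offsets from the center before rotation.
    offsets = [(-w / 2.0, -h / 2.0), (w / 2.0, -h / 2.0),
               (w / 2.0, h / 2.0), (-w / 2.0, h / 2.0)]
    coords = []
    for dx, dy in offsets:
        # Standard rotation matrix applied to each offset.
        coords.append(cx + dx * cos_a - dy * sin_a)
        coords.append(cy + dx * sin_a + dy * cos_a)
    return coords
```

With `angle = 0` the function returns the original horizontal box, which is a quick sanity check before converting a whole dataset.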
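After this format conversion, both datasets are ready. We still need to split the whole dataset into a training set and a test set. I usually organize the files as follows:
The train_im and test_im folders hold the training and test images, and train_gt and test_gt hold the training and test labels.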
Generating the training labels
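The core idea of CTPN is also based on the region proposal mechanism of Faster RCNN, so the original labels have to be converted into anchor labels. The code that generates the training labels is the hardest part to write, because turning one complete text box label into many small fine-scale box labels is genuinely tricky, and the way these anchor labels are generated also differs slightly from Faster RCNN. Here is how I approached it:
The first step is to convert the bbox labels of each image into per-anchor labels. To do this, we first divide the image into anchors of width 16.
- First compute how many anchors of width 16 fit into the image (for an image of width w, the number of horizontal anchors is w/16), then compute how many anchors the text box label contains and which anchors are the leftmost and rightmost ones;
- Compute the height and center of each anchor inside the text box: we draw the text box label in white on an all-black mask, then scan from top to bottom and from bottom to top, and take the first white pixel found in each direction as the top and bottom boundary of that anchor;
- Finally, store and return the position (horizontal ID), the center y coordinate and the height of each anchor.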
import math


def generate_gt_anchor(img, box, anchor_width=16):
    """
    calculate ground truth fine-scale box
    :param img: input image
    :param box: ground truth box (4 point)
    :param anchor_width: width of one fine-scale anchor in pixels
    :return: list of tuples (position, cy, h)
    """
    if not isinstance(box[0], float):
        box = [float(box[i]) for i in range(len(box))]
    result = []
    # Work out how many anchors of width anchor_width one bbox splits into,
    # and find the ids of the leftmost and rightmost anchors
    # the left side anchor of the text box, rounded down
    left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))
    # the right side anchor of the text box, rounded up
    right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))
    # handle extreme case, the right side anchor may exceed the image width
    if right_anchor_num * anchor_width + (anchor_width - 1) > img.shape[1]:
        right_anchor_num -= 1
    # combine the left-side and the right-side x_coordinate of a text anchor into one pair
    position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1)
                     for i in range(left_anchor_num, right_anchor_num)]
    # Compute the true position of each gt anchor, i.e. its top and bottom boundary
    y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)
    # Finally store and return each anchor's position (horizontal ID), center y and height
    for i in range(len(position_pair)):
        position = int(position_pair[i][0] / anchor_width)  # the index of anchor box
        h = y_bottom[i] - y_top[i] + 1  # the height of anchor box
        cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0  # the center point of anchor box
        result.append((position, cy, h))
    return result
How the top and bottom boundaries of each anchor are computed:
import copy


# cal the gt anchor box's bottom and top coordinate
def cal_y_top_and_bottom(raw_img, position_pair, box):
    """
    :param raw_img: input image (height x width x channels)
    :param position_pair: for example:[(0, 15), (16, 31), ...]
    :param box: gt box (4 point)
    :return: top and bottom coordinates for y-axis
    """
    img = copy.deepcopy(raw_img)
    y_top = []
    y_bottom = []
    height = img.shape[0]
    # Use channel 0 as the mask and set it to all black
    img[:, :, 0] = 0
    top_flag = False
    bottom_flag = False
    # Draw the text box from its four points; in channel 0 the box is white
    # (other.draw_box_4pt is a helper from the author's project)
    img = other.draw_box_4pt(img, box, color=(255, 0, 0))
    for k in range(len(position_pair)):
        # Go through the gt anchors from left to right. For each anchor, scan the
        # pixels from top to bottom and stop at the first white pixel (255);
        # its y coordinate is the top boundary of this gt anchor.
        # calc top y coordinate
        for y in range(0, height):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_top.append(y)
                    top_flag = True
                    break
            if top_flag is True:
                break
        # Same scan from bottom to top; the first white pixel (255) found
        # gives the bottom boundary of this gt anchor.
        # calc bottom y coordinate, pixel from down to top loop
        for y in range(height - 1, -1, -1):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_bottom.append(y)
                    bottom_flag = True
                    break
            if bottom_flag is True:
                break
        top_flag = False
        bottom_flag = False
    return y_top, y_bottom
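The per-pixel scan can also be expressed with NumPy array operations, which avoids the nested Python loops. This is only a sketch under stated assumptions: `mask` is a 2D uint8 array that is 255 where the text box was drawn and 0 elsewhere, the name `column_top_bottom` is hypothetical, and a strip without any white pixel is reported as `None` instead of being skipped, so the result always has one entry per anchor.

```python
import numpy as np


def column_top_bottom(mask, position_pair):
    """For each (x_start, x_end) pair, find the first and last row containing a 255 pixel.

    Returns a list with one entry per pair: (y_top, y_bottom), or None when the
    strip contains no white pixel. Hypothetical helper for illustration only.
    """
    result = []
    for x_start, x_end in position_pair:
        strip = mask[:, x_start:x_end + 1]
        # Indices of the rows that contain at least one white pixel.
        rows = np.flatnonzero((strip == 255).any(axis=1))
        if rows.size == 0:
            result.append(None)
        else:
            result.append((int(rows[0]), int(rows[-1])))
    return result
```

Returning `None` for empty strips keeps the top and bottom values aligned with `position_pair`, which matters when the two lists are later indexed together.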
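After the label processing above, the original standard text box labels have been converted into many small fine-scale anchor labels. The result of the conversion looks like this:
Visualized, the anchor labels look good, but it should be pointed out that this way of generating anchors is not very precise. For example, if the edge pixels of a text box just fall into a new anchor, we have to assign a whole 16-pixel anchor to those pixels, which makes the text box label inaccurate and introduces an error of up to 15 pixels. This deserves some thought. We leave the problem as it is for now and continue with the work below.
Of course, we also ran into plenty of odd problems during the conversion, such as the label in the figure below that extends beyond the image. Such cases need special handling, for example limiting the maximum x coordinate of a label to the image width.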
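To make the size of this error concrete, the sketch below reproduces the floor/ceil column arithmetic for a box whose right edge just crosses an anchor boundary. The helper name `anchor_span` and the numbers are illustrative assumptions, and the extra check for anchors that exceed the image width is left out for brevity.

```python
import math


def anchor_span(x_left, x_right, image_width, anchor_width=16):
    """Return (first_anchor, last_anchor, covered_x_min, covered_x_max) for one box.

    Mirrors the rounding used for the left (floor) and right (ceil) anchor indices.
    Hypothetical helper for illustration only.
    """
    first = int(math.floor(max(x_left, 0) / anchor_width))
    end = int(math.ceil(min(x_right, image_width) / anchor_width))  # exclusive index
    return first, end - 1, first * anchor_width, end * anchor_width - 1


# A box spanning x = 32..161 in a 640-pixel-wide image.
first, last, x_min, x_max = anchor_span(x_left=32, x_right=161, image_width=640)
extra = x_max - 161  # columns of the last anchor that lie beyond the text box
```

Here the last anchor (index 10) covers columns 160 to 175 although the box only reaches column 161, so almost the entire anchor is background.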
# the left side anchor of the text box, rounded down
left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))
# the right side anchor of the text box, rounded up
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))
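Training process:
Training: we use SGD as the optimizer and set two learning rates. The first N epochs use a larger lr, and the later epochs use a smaller lr for better convergence.
During training we track four losses: total_cls_loss, total_v_reg_loss, total_o_reg_loss and total_loss (the sum of the first three).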
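The two-stage learning rate and the summed loss can be sketched in a few lines. The function names below are hypothetical, and the hyperparameter values are placeholders chosen for illustration, not the author's settings.

```python
def learning_rate_for_epoch(epoch_index, change_epoch, lr_front, lr_behind):
    """Return the larger rate before `change_epoch` and the smaller one from then on."""
    return lr_behind if epoch_index >= change_epoch else lr_front


def combine_losses(cls_loss, v_reg_loss, o_reg_loss):
    """Total loss as the plain sum of the three component losses."""
    return cls_loss + v_reg_loss + o_reg_loss


# Placeholder hyperparameters for illustration only.
schedule = [learning_rate_for_epoch(e, change_epoch=3, lr_front=1e-3, lr_behind=1e-4)
            for e in range(5)]
```

With these placeholder values the first three epochs run at 1e-3 and the remaining ones at 1e-4.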
import copy
import os
import random
import time

import cv2
import numpy as np
import torch
import torch.optim as optim

# Net, Loss, lib, evaluate, logger, create_train_val, draw_loss_plot and the
# hyperparameters (epoch, change_epoch, lr_front, ...) come from the author's project.

net = Net.CTPN()  # build the network
for name, value in net.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True
# for name, value in net.named_parameters():
#     print('name: {0}, grad: {1}'.format(name, value.requires_grad))
net.load_state_dict(torch.load('./lib/vgg16.model'))
# net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
lib.utils.init_weight(net)
if using_cuda:
    net.cuda()
net.train()
print(net)

criterion = Loss.CTPN_Loss(using_cuda=using_cuda)  # loss function

train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val()  # training and validation data
total_iter = len(train_im_list)
print("total training image num is %s" % len(train_im_list))
print("total val image num is %s" % len(val_im_list))

train_loss_list = []
test_loss_list = []

# start the training iterations
for i in range(epoch):
    if i >= change_epoch:
        lr = lr_behind
    else:
        lr = lr_front
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    # optimizer = optim.Adam(net.parameters(), lr=lr)
    iteration = 1
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()

    random.shuffle(train_im_list)  # shuffle the training set
    # print(random_im_list)
    for im in train_im_list:
        root, file_name = os.path.split(im)
        root, _ = os.path.split(root)
        name, _ = os.path.splitext(file_name)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(root, "train_gt", gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = lib.dataset_handler.read_gt_file(gt_path)  # read the matching labels
        # print("processing image %s" % os.path.join(img_root1, im))
        img = cv2.imread(im)
        if img is None:
            iteration += 1
            continue

        img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt)  # rescale image and labels
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)  # forward pass, get predictions
        del tensor_img

        # transform bbox gt to anchor gt for training
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []

        visual_img = copy.deepcopy(img)  # image used to visualize the labels

        try:
            # loop all bbox in one image
            for box in gt_txt:
                # generate anchors from one bbox
                gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img)  # anchor labels of the image
                positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box)  # map the predictions onto the anchors
                positive += positive1
                negative += negative1
                vertical_reg += vertical_reg1
                side_refinement_reg += side_refinement_reg1
        except Exception:
            print("warning: img %s raise error!" % im)
            iteration += 1
            continue

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            iteration += 1
            continue

        cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img)
        optimizer.zero_grad()
        # compute the loss
        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        # backpropagation
        loss.backward()
        optimizer.step()
        iteration += 1
        # save gpu memory by transferring loss to float
        total_loss += float(loss)
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)

        if iteration % display_iter == 0:
            end_time = time.time()
            total_time = end_time - start_time
            print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.
                  format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter,
                         total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im))

            logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch))
            logger.info('loss: {0}'.format(total_loss / display_iter))
            logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
            logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
            logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))

            train_loss_list.append(total_loss)

            total_loss = 0
            total_cls_loss = 0
            total_v_reg_loss = 0
            total_o_reg_loss = 0
            start_time = time.time()

        # validate the model periodically
        if iteration % val_iter == 0:
            net.eval()
            logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
            val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list)
            logger.info('End evaluate.')
            net.train()
            start_time = time.time()
            test_loss_list.append(val_loss)

        # save the model periodically
        if iteration % save_iter == 0:
            print('Model saved at ./model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))
            torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))

    print('Model saved at ./model/ctpn-msra_ali-{0}-end.model'.format(i))
    torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i))

    # plot how the losses change
    draw_loss_plot(train_loss_list, test_loss_list)
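Image scaling follows a fixed rule: the shorter side of the image must end up equal to 600. We obtain the scale factor with
scale = float(shortest_side)/float(min(height, width))
and resize the original image by it. The labels must be scaled by the same factors as the image.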
import cv2


def scale_img(img, gt, shortest_side=600):
    height = img.shape[0]
    width = img.shape[1]
    scale = float(shortest_side)/float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # Rounding in the resize above can leave the shorter side slightly off,
    # so force it to shortest_side. Note that cv2.resize takes dsize as (width, height).
    if img.shape[0] < img.shape[1] and img.shape[0] != shortest_side:
        img = cv2.resize(img, (img.shape[1], shortest_side))
    elif img.shape[0] > img.shape[1] and img.shape[1] != shortest_side:
        img = cv2.resize(img, (shortest_side, img.shape[0]))
    elif img.shape[0] == img.shape[1] and img.shape[0] != shortest_side:
        img = cv2.resize(img, (shortest_side, shortest_side))
    h_scale = float(img.shape[0])/float(height)
    w_scale = float(img.shape[1])/float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i in range(len(box)):
            # x coordinate
            if i % 2 == 0:
                scale_box.append(int(int(box[i]) * w_scale))
            # y coordinate
            else:
                scale_box.append(int(int(box[i]) * h_scale))
        scale_gt.append(scale_box)
    return img, scale_gt
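As a worked example of the scaling rule, the sketch below computes the target size and scale factors for an 800x1200 image and applies them to one label, without touching any pixels. The helper names `scale_factors` and `scale_box` and the sample coordinates are assumptions made for illustration.

```python
def scale_factors(height, width, shortest_side=600):
    """Return (new_height, new_width, h_scale, w_scale) so that the shorter side
    equals `shortest_side`. Hypothetical helper for illustration only."""
    scale = float(shortest_side) / float(min(height, width))
    if height <= width:
        new_height, new_width = shortest_side, int(round(width * scale))
    else:
        new_height, new_width = int(round(height * scale)), shortest_side
    return new_height, new_width, new_height / float(height), new_width / float(width)


def scale_box(box, h_scale, w_scale):
    """Scale [x1, y1, ..., x4, y4]: even indices are x values, odd indices are y values."""
    return [int(v * (w_scale if i % 2 == 0 else h_scale)) for i, v in enumerate(box)]


new_h, new_w, h_scale, w_scale = scale_factors(height=800, width=1200)
scaled = scale_box([100, 200, 500, 200, 500, 280, 100, 280], h_scale, w_scale)
```

For this image both factors are 0.75, so the image becomes 600x900 and every label coordinate shrinks by a quarter.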
Evaluation on the validation set:
def val(net, criterion, batch_num, using_cuda, logger):
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    for im in random.sample(img_list, batch_num):
        name, _ = os.path.splitext(im)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(gt_root, gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += positive1
            negative += negative1
            vertical_reg += vertical_reg1
            side_refinement_reg += side_refinement_reg1

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            batch_num -= 1
            continue

        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        # convert to float so the computation graphs are not kept alive
        total_loss += float(loss)
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)

    end_time = time.time()
    total_time = end_time - start_time
    print('#################### Start evaluate ####################')
    print('loss: {0}'.format(total_loss / float(batch_num)))
    logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))

    print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / float(batch_num)))

    print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))

    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))

    print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
    print('##################### Evaluate end #####################')
    print('\n')
Training and prediction results
Testing: given an input image, output the final detection result
def infer_one(im_name, net):
    im = cv2.imread(im_name)
    im = lib.dataset_handler.scale_img_only(im)  # rescale the image
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)  # run the network
    result = []
    # keep the anchors whose text score is above the threshold
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > THRESH_HOLD:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))

    # filter with nms
    for_nms = []
    for box in result:
        pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
        for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]])
    for_nms = np.array(for_nms, dtype=np.float32)
    nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH)

    out_nms = []
    for i in nms_result:
        out_nms.append(for_nms[i, 0:8])

    # decide which anchors belong to the same group
    connect = get_successions(v, out_nms)
    # merge each group of anchors into one text line
    texts = get_text_lines(connect, im.shape)

    for box in texts:
        box = np.array(box)
        print(box)
        lib.draw_image.draw_ploy_4pt(im, box[0:8])

    _, basename = os.path.split(im_name)
    cv2.imwrite('./infer_'+basename, im)
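Before NMS, each selected anchor is turned into a two-point box by `lib.utils.trans_to_2pt`, whose source is not shown here. The following is therefore only a guess at what such a conversion might look like: the name `anchor_to_two_points` is hypothetical, and it assumes the same conventions as the label generation, namely h = y_bottom - y_top + 1 and cy = (y_top + y_bottom) / 2.

```python
def anchor_to_two_points(position, cy, h, anchor_width=16):
    """Convert (column index, center y, height) into (x1, y1, x2, y2).

    Hypothetical stand-in for illustration; the real lib.utils.trans_to_2pt
    may use different rounding or conventions.
    """
    x_left = position * anchor_width
    x_right = (position + 1) * anchor_width - 1
    # Invert h = y_bottom - y_top + 1 and cy = (y_top + y_bottom) / 2.
    y_top = cy - (h - 1) / 2.0
    y_bottom = cy + (h - 1) / 2.0
    return x_left, y_top, x_right, y_bottom
```

Under these assumptions, an anchor in column 3 with center y 100 and height 21 spans x = 48..63 and y = 90..110.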
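The inference code calls get_successions, which collects all the anchors that belong to one predicted text line. In other words, we obtain many anchors that are predicted to contain characters, but how do we know which of them form a text line? We need an anchor-merging algorithm, and this is the hardest step in implementing CTPN.
The CTPN paper describes the construction of text lines as follows: text lines are built simply by connecting consecutive text proposals whose text/no-text score is > 0.7. The construction works like this.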
- First, define a neighbor (Bj) for a proposal Bi, written Bj -> Bi, where:
- Bj is the proposal nearest to Bi in horizontal distance
- that distance is less than 50 pixels
- their vertical overlap is > 0.7
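The neighbor rule above can be sketched as follows. This is a simplified illustration, not the author's get_successions: all function names are hypothetical, boxes are assumed to be (x1, y1, x2, y2) tuples, vertical overlap is assumed to be the overlap of the y-ranges divided by the height of the shorter box, and only the search to the right is shown, without the mutual check in the opposite direction.

```python
def vertical_overlap(a, b):
    """Overlap of the y-ranges of two boxes, relative to the shorter box (an assumption)."""
    inter = min(a[3], b[3]) - max(a[1], b[1])
    shorter = min(a[3] - a[1], b[3] - b[1])
    return max(inter, 0.0) / shorter if shorter > 0 else 0.0


def find_successors(boxes, max_gap=50, min_overlap=0.7):
    """For each box i, return the index of its nearest right-hand neighbor, or None."""
    successors = []
    for i, bi in enumerate(boxes):
        best, best_dist = None, None
        for j, bj in enumerate(boxes):
            dist = bj[0] - bi[0]  # horizontal distance between the left edges
            if j == i or dist <= 0 or dist >= max_gap:
                continue
            if vertical_overlap(bi, bj) <= min_overlap:
                continue
            if best_dist is None or dist < best_dist:
                best, best_dist = j, dist
        successors.append(best)
    return successors


def build_chains(successors):
    """Follow the successor links, starting from every box that is nobody's successor."""
    targets = {s for s in successors if s is not None}
    chains = []
    for start in range(len(successors)):
        if start in targets:
            continue
        chain, current = [], start
        while current is not None and current not in chain:
            chain.append(current)
            current = successors[current]
        chains.append(chain)
    return chains
```

Each chain is a list of anchor indices ordered from left to right. Because the mutual check is omitted, two chains can end up sharing an anchor, which a full implementation would have to resolve.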
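The theory looks simple at first glance, but implementing it yourself is full of difficulties. It really bears out the saying "What is learned from books always feels shallow; to truly understand something you must practice it yourself." The arguments passed to get_successions are v, which holds the h and y information of each predicted anchor, and anchors, which holds the four corner coordinates of each anchor.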
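Once the anchors of one text line are known, the simplest way to turn them into a text box is to take the rectangle that encloses them. The sketch below shows only that idea: `merge_chain` is a hypothetical name, boxes are assumed to be (x1, y1, x2, y2) tuples, and the author's get_text_lines may well do something more refined.

```python
def merge_chain(boxes, chain):
    """Return the axis-aligned box (x1, y1, x2, y2) enclosing all anchors in `chain`.

    `boxes` is a list of (x1, y1, x2, y2) tuples and `chain` a list of indices
    into it. Hypothetical helper for illustration only.
    """
    x1 = min(boxes[i][0] for i in chain)
    y1 = min(boxes[i][1] for i in chain)
    x2 = max(boxes[i][2] for i in chain)
    y2 = max(boxes[i][3] for i in chain)
    return x1, y1, x2, y2
```

An enclosing rectangle like this is always horizontal, so it cannot follow a rotated text line.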
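Detection results and summary
First, let's look at the text detection results of the trained model. To make them easier to inspect, I drew both the anchors and the final merged text boxes:
Some conclusions and thoughts from the implementation:
- CTPN does not detect rotated text well, and this follows from the design of the algorithm: fixed-width quadrilaterals are hard to merge into an accurate text box. Some anchors are hard to group together, and even when they are grouped it is hard to recover a complete, precise text rectangle from them (a weakness of the inference stage). For horizontally laid out text, though, I think the approach works well.
- The side-refinement in CTPN does not help much. If the detected text is passed straight on to recognition, the few pixels that side-refinement adjusts can be ignored;
- CTPN has quite a few intermediate steps: generating the anchor labels, computing the loss, and finally building text lines at inference time all introduce some error. This drawback is also pointed out in the EAST paper. The simpler the training pipeline and the fewer the intermediate steps, the easier it is to keep the accuracy.
- The results show that CTPN has low precision but high recall. Detection based on 16-pixel anchors seems to have a very high false-positive rate on large non-text icons such as road signs, because the anchors are so narrow. Even though an LSTM links neighboring anchors, I still feel the model "cannot see the forest for the trees". For the same reason, CTPN does not do well on text that is very large or very small.
- CTPN is a fairly old algorithm (2016). Its ideas were quite innovative at the time, but it also has many drawbacks. Newer methods have largely solved these shortcomings; EAST and PixelNet, for example, are excellent newer algorithms.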
For the complete CTPN implementation, see the author's GitHub.
https://www.cnblogs.com/skyfsm/p/10054386.html