EAST結構分析+pytorch源碼實現



EAST結構分析+pytorch源碼實現

一. U-Net的前車之鑒

在介紹EAST網絡之前我們先介紹一下前面的幾個網絡,看看這個EAST網絡怎么來的?為什么來的?

當然這里的介紹僅僅是引出EAST而不是詳細的講解其他網絡,有需要的讀者可以去看看這三個優秀網絡。

1.1 FCN網絡結構

​ FCN網絡,在之前FCN從原理到代碼的理解已經詳細分析了,有需要的可以去看看,順便跑一跑代碼。

圖1-1

  • 網絡的由來

不管是識別(傳統機器學習、CNN)還是檢測(SSD、YOLO等),都只是基於大塊的特征進行的,檢測之后都是以長方形去表示檢測結果,由於這是其算法內部回歸的結果導致,而且feature map經過卷積一直減小,如果強行進行256X256512X512的插值,那么結果可以想象,邊界非常不好。

那么如何實現圖1-1所示的結果呢?把每個像素都進行分割?

  • 網絡的成果

FCN給出的方法是使用反卷積進行上采樣操作,使得經過CNN之后減小的圖能夠恢復大小。

當然作者還提出一個好方法,不同的feature map進行組合,使得感受野進行擴充。

注釋:筆者認為使用反卷積有兩個作用,其一是使得計算LOSS比較方便,標簽和結果可以直接進行計算。其二是可以進行參數的學習,更為智能化。

1.2 U-NET網絡

U-net網絡之前沒怎么看過,現在也僅僅是大概看了論文和相關資料,內部實現不是很了解。

圖1-2

  • 網絡的由來

FCN完全可以做到基於像素點的分割,為什么還要這個U-net網絡啊?

FCN網絡檢測的效果還可以,但是其邊緣的處理就特別的差。雖然說多個層進行合並,但是合並的內容雜亂無章,導致最后的信息沒有完全得到。

總的來說FCN分割的效果不夠,精度也不夠。

  • 網絡的成果

U-net提出了對稱的網絡結構,使得網絡參數的學習效果更好(為什么對稱網絡學習更好,這個理解不透,如果是結果再放大一倍使得不對稱不也一樣嗎?感覺還是網絡結構設計的好,而不是對稱)

不同feature map合並的方式更加優化,使得在邊緣分割(細節)上更加優秀。

網絡架構清晰明了,分割效果也很好,現在醫學圖像分割領域還能看見身影。

1.3 CTPN網絡

剛開始准備使用CTPN進行文本的檢測,所以看了一些相關資料,致命缺點是不能檢測帶角度文字和網絡比較復雜。

圖1-3

  • 網絡的由來

文本檢測和其他檢測卻別很大,比如用SSD檢測文本就比較困難(邊緣檢測不好),如何針對文本進行檢測?

  • 網絡的成果

CTPN網絡有很多創造的想法-->>

目標分割小塊,然后一一進行檢測,針對文本分割成height>width的方式,使得檢測的邊緣更為精確。

使用BiLSTM對小塊進行連接,針對文本之間的相關性。

CTPN想法具有創造性,但是太過復雜。

  1. 首先樣本的制作麻煩
  2. 每個小框進行回歸,框的大小自己定義
  3. 邊緣特意進行偏移處理
  4. 使用RNN進行連接

檢測水平效果還是不錯的,但是對於傾斜的文本就不行了。

為什么不加一個angle進行回歸?

本就很復雜的網絡,如果再給每個小box加一個angle參數會更復雜,當然是可以實施的。

二. EAST結構分析

2.1 結構簡述

EAST原名為: An Efficient and Accurate Scene Text Detector

結構:檢測層(PVANet) + 合並層 + 輸出層

圖2-1

下圖圖2-2是檢測效果,任意角度的文本都可以檢測到。

注意:EAST只是一個檢測網絡,如需識別害的使用CRNN等識別網絡進行后續操作。

圖2-2

具體網絡在2-2節進行詳細介紹=====>>>

2.2 結構詳解

  • 整體結構

EAST根據他的名字,我們知道就是高效的文本檢測方法。

上面我們介紹了CTPN網絡,其標簽制作很麻煩,結構很復雜(分割成小方框然后回歸還要RNN進行合並)

看下圖圖2-3,只要進行類似FCN的結構,計算LOSS就可以進行訓練。測試的時候走過網絡,運行NMS就可以得出結果。太簡單了是不是?

圖2-3

  • 特征提取層

特征的提取可以任意網絡(VGG、RES-NET等檢測網絡),本文以VGG為基礎進行特征提取。這個比較簡單,看一下源碼就可以清楚,見第四章源碼分析

  • 特征合並層

在合並層中,首先在定義特征提取層的時候把需要的輸出給保留下來,通過forward函數把結構進行輸出。之后再合並層調用即可

如下代碼定義,其中合並的過程再下面介紹

#提取VGG模型訓練參數
class extractor(nn.Module):
	def __init__(self, pretrained):
		super(extractor, self).__init__()
		vgg16_bn = VGG(make_layers(cfg, batch_norm=True))
		if pretrained:
			vgg16_bn.load_state_dict(torch.load('./pths/vgg16_bn-6c64b313.pth'))
		self.features = vgg16_bn.features
	
	def forward(self, x):
		out = []
		for m in self.features:
			x = m(x)
			#提取maxpool層為后續合並
			if isinstance(m, nn.MaxPool2d):
				out.append(x)
		return out[1:]
  • 特征合並層

合並特征提取層的輸出,具體的定義如下代碼所示,代碼部分已經注釋.

其中x中存放的是特征提取層的四個輸出

	def forward(self, x):

		y = F.interpolate(x[3], scale_factor=2, mode='bilinear', align_corners=True)
		y = torch.cat((y, x[2]), 1)
		y = self.relu1(self.bn1(self.conv1(y)))		
		y = self.relu2(self.bn2(self.conv2(y)))
		
		y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
		y = torch.cat((y, x[1]), 1)
		y = self.relu3(self.bn3(self.conv3(y)))		
		y = self.relu4(self.bn4(self.conv4(y)))
		
		y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
		y = torch.cat((y, x[0]), 1)
		y = self.relu5(self.bn5(self.conv5(y)))		
		y = self.relu6(self.bn6(self.conv6(y)))
		
		y = self.relu7(self.bn7(self.conv7(y)))
		return y
  • 輸出層

輸出層包括三個部分,這里以RBOX為例子,發現網上都沒有QUAN為例子的?

首先QUAN的計算是為了防止透視變換的存在,正常情況下不存在這些問題,正常的斜框可以解決。

因為QUAN的計算沒啥好處,前者已經完全可以解決正常的檢測問題,后者回歸四個點相對來說較為困難(如果文本變化較大就更困難,所以SSD和YOLO無法檢測文本的原因)。

如果想得到特殊的文本,基本考慮別的網絡了(比如彎曲文字的檢測)

	def forward(self, x):
		score = self.sigmoid1(self.conv1(x))
		loc   = self.sigmoid2(self.conv2(x)) * self.scope
		angle = (self.sigmoid3(self.conv3(x)) - 0.5) * math.pi
		geo   = torch.cat((loc, angle), 1) 
		return score, geo

三. EAST細節分析

3.1 標簽制作

注意:這里是重點和難點!!!

文章說要把標簽向里縮進0.3

筆者認為這樣做的目的是提取到更為准確的信息,不論是人工標注的好與不好,我們按照0.3縮小之后提取的特征都是全部的文本信息。

但是這樣做也會丟失一些邊緣信息,如果按照上述的推斷,那么SSD或YOLO都可以這樣設計標簽了。

作者肯定是經過測試的,有好處有壞處吧!

圖3-1

標簽格式為:5個geometry(4個location+1個angle) + 1個score ==6 × N × M

其中(b)為score圖 ,(d)為四個location圖, (e)為angle圖

上圖可能看的不清楚,下面以手繪圖進行說明:

圖3-2

上圖可能看不清楚,下面再用文字大概說一下吧!

  1. 先進行0.3縮放,這個時候的圖就是score
  2. 沒縮放的圖像為基准,畫最小外接矩形,這個外接矩形的角度就是angle。這個大小是縮放的的圖大小。感覺直接以score圖做角度也一樣的。
  3. score圖的每個像素點到最小外接矩形的距離為四個location圖。

3.2 LOSS計算

LOSS計算就比較簡單的,直接回歸location、angle、score即可。

	def forward(self, gt_score, pred_score, gt_geo, pred_geo, ignored_map):
		#圖像中不存在目標直接返回0
		if torch.sum(gt_score) < 1:
			return torch.sum(pred_score + pred_geo) * 0
		#score loss 采用Dice方式計算,沒有采用log熵計算,為了防止樣本不均衡問題
		classify_loss = get_dice_loss(gt_score, pred_score*(1-ignored_map))
		#geo loss采用Iou方式計算(計算每個像素點的loss)
		iou_loss_map, angle_loss_map = get_geo_loss(gt_geo, pred_geo)
		#計算一整張圖的loss,angle_loss_map*gt_score去除不是目標點的像素(感覺這句話應該放在前面減少計算量,放在這里沒有減少計算loss的計算量)
		angle_loss = torch.sum(angle_loss_map*gt_score) / torch.sum(gt_score)
		iou_loss = torch.sum(iou_loss_map*gt_score) / torch.sum(gt_score)
		geo_loss = self.weight_angle * angle_loss + iou_loss#這里的權重設置為1
		print('classify loss is {:.8f}, angle loss is {:.8f}, iou loss is {:.8f}'.format(classify_loss, angle_loss, iou_loss))
		return geo_loss + classify_loss

注意:這里score的LOSS使用Dice方式,因為普通的交叉熵無法解決樣本不均衡問題!!!

圖3-3

3.3 NMS計算

NMS使用的是locality NMS,也就是為了針對EAST而提出來的。

首先我們先來看看這個LANMS的原理和過程:

import numpy as np
from shapely.geometry import Polygon

def intersection(g, p):
    #取g,p中的幾何體信息組成多邊形
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))

    # 判斷g,p是否為有效的多邊形幾何體
    if not g.is_valid or not p.is_valid:
        return 0

    # 取兩個幾何體的交集和並集
    inter = Polygon(g).intersection(Polygon(p)).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    else:
        return inter/union

def weighted_merge(g, p):
    # 取g,p兩個幾何體的加權(權重根據對應的檢測得分計算得到)
    g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8])
    
    #合並后的幾何體的得分為兩個幾何體得分的總和
    g[8] = (g[8] + p[8])
    return g

def standard_nms(S, thres):
    #標准NMS
    order = np.argsort(S[:, 8])[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
        inds = np.where(ovr <= thres)[0]
        order = order[inds+1]
        
    return S[keep]

def nms_locality(polys, thres=0.3):
    '''
    locality aware nms of EAST
    :param polys: a N*9 numpy array. first 8 coordinates, then prob
    :return: boxes after nms
    '''
    S = []    #合並后的幾何體集合
    p = None   #合並后的幾何體
    for g in polys:
        if p is not None and intersection(g, p) > thres:    #若兩個幾何體的相交面積大於指定的閾值,則進行合並
            p = weighted_merge(g, p)
        else:    #反之,則保留當前的幾何體
            if p is not None:
                S.append(p)
            p = g
    if p is not None:
        S.append(p)
    if len(S) == 0:
        return np.array([])
    return standard_nms(np.array(S), thres)

if __name__ == '__main__':
    # 343,350,448,135,474,143,369,359
    print(Polygon(np.array([[343, 350], [448, 135],
                            [474, 143], [369, 359]])).area)

別看那么多代碼,講的很玄乎,其實很簡單:

  1. 遍歷每個預測的框,然后按照交集大於某個值K就合並相鄰的兩個框。
  2. 合並完之后就按照正常NMS消除不合理的框就行了。

注意: 為什么相鄰的框合並?

  1. 因為每個像素預測一個框(不明白就自己去看上面LOSS計算),一個目標的幾百上千個框基本都是重合的(如果預測的准的話),所以說相鄰的框直接進行合並就行了。
  2. 其實豎直和橫向都合並一次最好,反正原理一樣的。

四. Pytorch源碼分析

源碼就不進行分析了,上面已經說得非常明白了,基本每個難點和重點都說到了。

有一點小bug,現進行說明:

  1. 訓練的時候出現孔樣本跑死
SampleNum = 3400 #定義樣本數量,應對空標簽的文本bug,臨時處理方案
class custom_dataset(data.Dataset):
	def __init__(self, img_path, gt_path, scale=0.25, length=512):
		super(custom_dataset, self).__init__()
		self.img_files = [os.path.join(img_path, img_file) for img_file in sorted(os.listdir(img_path))]
		self.gt_files  = [os.path.join(gt_path, gt_file) for gt_file in sorted(os.listdir(gt_path))]
		self.scale = scale
		self.length = length

	def __len__(self):
		return len(self.img_files)

	def __getitem__(self, index):
		with open(self.gt_files[index], 'r') as f:
			lines = f.readlines()
		while(len(lines)<1):
			index = int(SampleNum*np.random.rand())
			with open(self.gt_files[index], 'r') as f:
				lines = f.readlines()
		vertices, labels = extract_vertices(lines)
		
		img = Image.open(self.img_files[index])
		img, vertices = adjust_height(img, vertices) 
		img, vertices = rotate_img(img, vertices)
		img, vertices = crop_img(img, vertices, labels, self.length,index)
		transform = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.25), \
                                        transforms.ToTensor(), \
                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
		
		score_map, geo_map, ignored_map = get_score_geo(img, vertices, labels, self.scale, self.length)
		return transform(img), score_map, geo_map, ignored_map
  1. 測試的時候讀取PIL會出現RGBA情況
	img_path    = './013.jpg'
	model_path  = './pths/model_epoch_225.pth'
	res_img     = './res.bmp'
	img = Image.open(img_path)
	img = np.array(img)[:,:,:3]
	img = Image.fromarray(img)
  • 后續工作
  1. 這個代碼感覺有點問題,訓練速度很慢,猜測是數據處理部分。
  2. 原版EAST每個點都進行回歸,太浪費時間了,后續參考AdvanceEAST進行修改,同時加個人理解優化
  3. 網絡太大了,只適合服務器或者PC上跑,當前網絡已經修改到15MB,感覺還是有點大。
  4. 后續還要加識別部分,困難重重。。。。。。

這里的代碼都是github上的,筆者只是搬運工而已!!!

原作者下載地址

五. 第一次更新內容

  • 2019-6-30更新

之前提到這個工程的代碼有幾個缺陷,在這里進行詳細的解決

  1. 訓練速度很慢

這是由於源代碼的數據處理部分編寫有問題導致,隨機crop中對於邊界問題處理
以下給出解決方案,具體修改請讀者對比源代碼即可:

def crop_img(img, vertices, labels, length, index):
	'''crop img patches to obtain batch and augment
	Input:
		img         : PIL Image
		vertices    : vertices of text regions <numpy.ndarray, (n,8)>
		labels      : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
		length      : length of cropped image region
	Output:
		region      : cropped image region
		new_vertices: new vertices in cropped region
	'''
	try:
		h, w = img.height, img.width
		# confirm the shortest side of image >= length
		if h >= w and w < length:
			img = img.resize((length, int(h * length / w)), Image.BILINEAR)
		elif h < w and h < length:
			img = img.resize((int(w * length / h), length), Image.BILINEAR)
		ratio_w = img.width / w
		ratio_h = img.height / h
		assert(ratio_w >= 1 and ratio_h >= 1)

		new_vertices = np.zeros(vertices.shape)
		if vertices.size > 0:
			new_vertices[:,[0,2,4,6]] = vertices[:,[0,2,4,6]] * ratio_w
			new_vertices[:,[1,3,5,7]] = vertices[:,[1,3,5,7]] * ratio_h
		#find four limitate point by vertices
		vertice_x = [np.min(new_vertices[:, [0, 2, 4, 6]]), np.max(new_vertices[:, [0, 2, 4, 6]])]
		vertice_y = [np.min(new_vertices[:, [1, 3, 5, 7]]), np.max(new_vertices[:, [1, 3, 5, 7]])]
		# find random position
		remain_w = [0,img.width - length]
		remain_h = [0,img.height - length]
		if vertice_x[1]>length:
			remain_w[0] = vertice_x[1] - length
		if vertice_x[0]<remain_w[1]:
			remain_w[1] = vertice_x[0]
		if vertice_y[1]>length:
			remain_h[0] = vertice_y[1] - length
		if vertice_y[0]<remain_h[1]:
			remain_h[1] = vertice_y[0]

		start_w = int(np.random.rand() * (remain_w[1]-remain_w[0]))+remain_w[0]
		start_h = int(np.random.rand() * (remain_h[1]-remain_h[0]))+remain_h[0]
		box = (start_w, start_h, start_w + length, start_h + length)
		region = img.crop(box)
		if new_vertices.size == 0:
			return region, new_vertices

		new_vertices[:,[0,2,4,6]] -= start_w
		new_vertices[:,[1,3,5,7]] -= start_h
	except IndexError:
		print("\n crop_img function index error!!!\n,imge is %d"%(index))
	else:
		pass
	return region, new_vertices
  1. LOSS剛開始收斂下降,到后面就呈現抖動(像過擬合現象),檢測效果角度很差

由於Angle Loss角度計算錯誤導致,請讀者閱讀作者原文進行對比

def find_min_rect_angle(vertices):
	'''find the best angle to rotate poly and obtain min rectangle
	Input:
		vertices: vertices of text region <numpy.ndarray, (8,)>
	Output:
		the best angle <radian measure>
	'''
	angle_interval = 1
	angle_list = list(range(-90, 90, angle_interval))
	area_list = []
	for theta in angle_list: 
		rotated = rotate_vertices(vertices, theta / 180 * math.pi)
		x1, y1, x2, y2, x3, y3, x4, y4 = rotated
		temp_area = (max(x1, x2, x3, x4) - min(x1, x2, x3, x4)) * \
                    (max(y1, y2, y3, y4) - min(y1, y2, y3, y4))
		area_list.append(temp_area)
	
	sorted_area_index = sorted(list(range(len(area_list))), key=lambda k : area_list[k])
	min_error = float('inf')
	best_index = -1
	rank_num = 10
	# find the best angle with correct orientation
	for index in sorted_area_index[:rank_num]:
		rotated = rotate_vertices(vertices, angle_list[index] / 180 * math.pi)
		temp_error = cal_error(rotated)
		if temp_error < min_error:
			min_error = temp_error
			best_index = index

	if angle_list[best_index]>0:
		return (angle_list[best_index] - 90) / 180 * math.pi

	return (angle_list[best_index]+90) / 180 * math.pi
  1. 修改網絡從50MB到15MB,對於小樣本訓練效果很好

這里比較簡單,直接修改VGG和U-NET網絡feature map即可

cfg = [32, 32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M']
#合並不同的feature map
class merge(nn.Module):
	def __init__(self):
		super(merge, self).__init__()

		self.conv1 = nn.Conv2d(512, 128, 1)
		self.bn1 = nn.BatchNorm2d(128)
		self.relu1 = nn.ReLU()
		self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
		self.bn2 = nn.BatchNorm2d(128)
		self.relu2 = nn.ReLU()

		self.conv3 = nn.Conv2d(256, 64, 1)
		self.bn3 = nn.BatchNorm2d(64)
		self.relu3 = nn.ReLU()
		self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
		self.bn4 = nn.BatchNorm2d(64)
		self.relu4 = nn.ReLU()

		self.conv5 = nn.Conv2d(128, 32, 1)
		self.bn5 = nn.BatchNorm2d(32)
		self.relu5 = nn.ReLU()
		self.conv6 = nn.Conv2d(32, 32, 3, padding=1)
		self.bn6 = nn.BatchNorm2d(32)
		self.relu6 = nn.ReLU()

		self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
		self.bn7 = nn.BatchNorm2d(32)
		self.relu7 = nn.ReLU()
		#初始化網絡參數
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
			elif isinstance(m, nn.BatchNorm2d):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)
  1. 小的字體檢測很好,大的字體檢測不到(部分檢測不到)情況

這里是模仿AdvanceEAST的方法進行訓練,先在小圖像進行訓練,然后遷移到大圖像即可。

意思就是先將圖像縮小到254254訓練得到modeul_254.pth
然后在將圖像resize到384
384,網絡參數使用modeul_254.pth,訓練得到modeul_384.pth
。。。一次進行512或者更大的圖像即可

  1. 針對圖像訓練和檢測的慢(相對於其他檢測網絡)

這里需要根據原理來說了,是因為全部的像素都需要預測和計算loss,可以看看AdvanceEAST的網絡進行處理即可

  1. 修改網絡說明

訓練樣本3000
測試樣本100
檢測精度85%,IOU准確度80%
5個epoch收斂結束(這些都是這里測試的)
兩塊1080TI,訓練時間10分鍾左右

這里是我完整的工程


五. 參考文獻


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM