深度學習ocr交流qq群:1020395892
論文地址:https://link.zhihu.com/?target=https%3A//arxiv.org/pdf/1911.08947.pdf
github:https://github.com/MhLiao/DB
搗鼓DB有一段時間了,年前就開始訓練了。
問題1:不收斂,原來是我的數據標簽有問題,雙cuda
剛開始是訓練死活不收斂,訓練ic15數據集也不收斂,官方數據集都不收斂??不由懷疑肯定是哪里搞錯了。我的是cuda8的,pytorch1.1還是1.2的,懷疑可能是一定是需要cuda10,
於是冒着重裝系統的危險搗鼓雙cuda,首先升級驅動418,之前是384,cuda10需要驅動418,418也向下兼容cuda8.一頓操作倒是很順利,/usr/local/下面是cuda-8.0 cuda-10.0 還有一個軟鏈接,想用哪一個就修改一下軟鏈接就可以了。
於是歡快的用anconda裝pytorch-cuda10版本的。現在裝的是pytorch 1.3.1 py3.7_cuda10.0.130_cudnn7.6.3_0 pytorch;之前敲裝pytorch默認安裝cuda10.1的pytorch版本,但是運行DB報錯,報cuda的問題,弄了很久,發現我本地是10.0版本的,然后試着找cuda10.0版本的pytorch,找到,裝好再運行就不報錯了。
但是不收斂的問題依舊存在,不知道咋搞的。后來看源碼數據處理那塊,發現是會去掉最后一位,因為ic15數據標簽格式是:
58,80,191,71,194,114,61,123,fusionopolis
147,21,176,21,176,36,147,36,###
去掉最后一位的文本內容,而我的數據集只有坐標,並且有4個點有14個點的。
在data->image_dataset.py line71
num_points = math.floor((len(line) - 1) / 2) * 2 ###去掉“-”就可以
還有line41
gt_path=[self.data_dir[i]+'/train_gts/'+timg.strip().replace(".jpg","")+'.txt' for timg in image_list] ##.replace(".jpg","")表示gt與img名字一樣
然后訓練我40多萬的數據集,訓練幾天loss維持在1左右吧,測試也可以而且測其他的文本魯棒性也很好。
問題2:--polygon效果沒有四個點的好 --image_short_side(需要是32的倍數)
有一個問題是這些都是4個點,也支持輪廓點的,需要加--polygon 但是加了這個效果不好,后面再看看。
CUDA_VISIBLE_DEVICES=0 python demo.py DB-master/experiments/seg_detector/merge_data_resnet50_deform_thre-SRC.yaml --visualize --resume /DB-master/myfile/model_epoch_13_minibatch_396000-20200220 --image_path /data_2/everyday/0220/snapshot13.png --polygon --box_thresh 0.35
還可以加一個參數--image_short_side,默認是736,這個參數需要是32的倍數。
效果圖,這些圖是不在數據集里面的其他圖:
********************************示例1:
********************************示例2:
********************************示例3:
論文、源碼理解:
作者的源碼實在是太復雜了啊,直接看蒙圈了。各種動態類啊,動態創建啊,不好調試啊,斷點不好跟蹤。
源碼我是看了很久吧,靠近一個月,加上春節疫情這段時間在家,效率很不高,特別是碰到看不懂的。
我現在都不明白哪個類是什么時候就創建好了的,我只是把每個文件都看的很熟了。
根據yaml動態創建類
比如文件夾concern里面有個config.py
class State:
def __init__(self, autoload=True, default=None):
self.autoload = autoload
self.default = default
class StateMeta(type):
def __new__(mcs, name, bases, attrs):
....
class Configurable(metaclass=StateMeta):
....
然后后面所有的類都是繼承Configurable這個類。ヾ(。`Д´。),metaclass是叫元類的一個東東,https://www.cnblogs.com/yssjun/p/9832526.html
所有的類都是通過getattr(self, name)這個玩意動態創建,之所以要動態創建,是為了方面配置yaml可以多做實驗,可以對於我們就看起來懵逼了。看yaml文件:
import:
- 'experiments/seg_detector/base_totaltext.yaml'
package: []
define:
- name: 'Experiment'
class: Experiment
structure:
class: Structure
builder:
class: Builder
model: SegDetectorModel
model_args:
backbone: deformable_resnet50
decoder: SegDetector
decoder_args:
adaptive: True
in_channels: [256, 512, 1024, 2048]
k: 50
loss_class: L1BalanceCELoss
representer:
class: SegDetectorRepresenter
max_candidates: 1000
measurer:
class: QuadMeasurer
visualizer:
class: SegDetectorVisualizer
train:
class: TrainSettings
data_loader:
class: DataLoader
...
各種類,程序運行的時候都是讀取的這些來創建與初始化類的。
數據預處理
數據處理經過了7個步驟對應7個類!需要經過什么處理在base_***.ymal和base.ymal指定數據處理的類和參數,比如yaml文件中:
processes:
- class: AugmentDetectionData
augmenter_args:
- ['Fliplr', 0.5]
- {'cls': 'Affine', 'rotate': [-10, 10]}
- ['Resize', [0.5, 3.0]]
only_resize: False
keep_ratio: False
- class: RandomCropData
size: [640, 640]
max_tries: 10
- class: MakeICDARData
- class: MakeSegDetectionData
- class: MakeBorderMap
- class: NormalizeImage
- class: FilterKeys
superfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training']
讀源碼的時候我並不知道在哪里創建與初始化了這些類,后面再看吧。我只是在data->image_dataset.py文件的def getitem(self, index, retry=0):函數打斷點:
可以看到,循環在預處理,一個接這一個。想看哪個就提前去哪個類打上斷點。
其中,make_border_map.py這個是為了做threshold的標簽的,沒有看懂,但是看效果圖是高亮文字塊邊緣,其余部分都賦值0.3,后面再說這塊東西。
data文件夾下面有一些py文件和data文件夾下面的processes文件夾下面的py貌似是一樣的,實際運行的時候發現有些運行的是data下面的py有些是processes文件夾下面的,懵圈+10086
model.forward()函數執行步驟
trainer.py里面的一個函數:
def train_step(self, model, optimizer, batch, epoch, step, **kwards):
optimizer.zero_grad()
results = model.forward(batch, training=True)
....
results = model.forward(batch, training=True)后面是跑到哪里呢?
然后我看這個model怎么初始化的,該文件上面:
def init_model(self):
model = self.structure.builder.build(
self.device, self.experiment.distributed, self.experiment.local_rank)
return model
然后:
structure->Builder的build函數如下:
def build(self, device, distributed=False, local_rank: int = 0):
Model = getattr(structure.model,self.model)
model = Model(self.model_args, device,
distributed=distributed, local_rank=local_rank)
return model
---structure.model在yaml文件中指定:
model: SegDetectorModel
所以我就去找類SegDetectorModel
class SegDetectorModel(nn.Module):
def __init__(self, args, device, distributed: bool = False, local_rank: int = 0):
super(SegDetectorModel, self).__init__()
from decoders.seg_detector_loss import SegDetectorLossBuilder
self.model = BasicModel(args)
再繼續:
class BasicModel(nn.Module):
def __init__(self, args):
nn.Module.__init__(self)
self.backbone = getattr(backbones, args['backbone'])(**args.get('backbone_args', {}))
self.decoder = getattr(decoders, args['decoder'])(**args.get('decoder_args', {}))
def forward(self, data, *args, **kwargs):
returbone: den self.decoder(self.backbone(data), *args, **kwargs)
上面的:在yaml文件中寫了:
backbone: deformable_resnet50
decoder: SegDetector
所以return self.decoder(self.backbone(data), *args, **kwargs) 這一句就跑了兩個類里面的forward()函數
總結:results = model.forward(batch, training=True)執行步驟是:
step1:
SegDetectorModel下面的forward:
if isinstance(batch, dict):
data = batch['image'].to(self.device)
else:
data = batch.to(self.device)
data = data.float()
pred = self.model(data, training=self.training)
step2:
然后調用 BasicModel的forward:
backbone就是deformable_resnet50
decoder就是SegDetector
def forward(self, data, *args, **kwargs):
return self.decoder(self.backbone(data), *args, **kwargs)
step3:resnet50的forward ##self.backbone(data) == resnet50
step4:SegDetector的forward ##self.decoder == SegDetector
我就是先在一個類中打斷點,然后我感覺接下來是跑到這個類的forward函數,就在這打斷點,這樣是可以的,我就是這么摸索出來的。
網絡的流程
所以摸索出網絡的大概:
先是通過resnet+可變形卷積得到feature_map X2,X3,X4,X5 (注意resnet中嵌套了可變形卷積---可以參考https://www.cnblogs.com/yanghailin/p/12321832.html)
然后送到SegDetector的forward數概率圖函數,一頓卷積-池化-上采樣-bn-relu,累加合並
c2, c3, c4, c5 = features
in5 = self.in5(c5)
in4 = self.in4(c4)
in3 = self.in3(c3)
in2 = self.in2(c2)
out4 = self.up5(in5) + in4 # 1/16
out3 = self.up4(out4) + in3 # 1/8
out2 = self.up3(out3) + in2 # 1/4
p5 = self.out5(in5)
p4 = self.out4(out4)
p3 = self.out3(out3)
p2 = self.out2(out2)
fuse = torch.cat((p5, p4, p3, p2), 1)
p5,p4,p3,p2的尺寸都是[n,64,160,160],fuse的尺寸是[n,256,160,160];再然后:
binary = self.binarize(fuse)
thresh = self.thresh(fuse)
再一頓卷積、bn、relu、反卷積、sigmoid操作得到binary,其尺寸是[n,1,640,640]和輸入尺寸一樣
再一頓卷積、bn、relu、上采樣、sigmoid操作得到thresh,其尺寸是[n,1,640,640]和輸入尺寸一樣
再計算:
thresh_binary = torch.reciprocal(1 + torch.exp(-self.k * (binary - thresh))) 論文中的那個公式(如上公式)
binary是學到的分數概率圖
thresh是學到的文字塊邊界圖
thresh_binary是由binary和thresh根據公式計算出來的
后面就是loss約束,L1BalanceCELoss
def forward(self, pred, batch):
bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])
metrics = dict(bce_loss=bce_loss)
if 'thresh' in pred:
l1_loss, l1_metric = self.l1_loss(pred['thresh'], batch['thresh_map'], batch['thresh_mask'])
dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], batch['mask'])
metrics['thresh_loss'] = dice_loss
loss = dice_los數概率圖s + self.l1_scale * l1_loss + bce_loss * self.bce_scale
metrics.update(**l1_metric)
else:
loss = bce_loss
return loss, metrics
可以看到:
binary與thresh_binary的標簽都是用的gt
thresh的標簽用的thresh_map
自適應閾值
這個問題困擾我很久,單看這個公式:
p可以理解,就是有文字的區域有值,0.9以上,沒有文字區域黑的,為0
T呢,T是一個只有文字邊界才有值的,其他地方為0,那所有的像素都是經過這個公式,得到thresh_binary,這個合適嗎?
然后自己慢慢從一開始制作的標簽入手,gt就是我們標注好的,p就是gt,那個T的標簽threshold map是根據文字邊界做的,T的標簽threshold map到底是啥,
threshold map是將文本框分別向內向外收縮和擴張d(根據第一步收縮時計算得到)個像素,然后計算收縮框和擴張框之間差集部分里每個像素點到原始圖像邊界的歸一化距離。是根據一個算法跑出來的,看了源碼就是一堆計算,沒有細看。然后我就把gt與threshold map顯示出來更加直觀。
其實一開始就有個問題,gt是標注好的,為啥還要經過psenet里面的縮水操作?
看到圖自然就會明白了,這就是這個算法的特別之處了。
分別是原圖,gt圖,threshold map圖。
這里再說下threshold map圖,非文字邊界處都是灰色的,這是因為統一加了0.3,所有最小值是0.3,這是為了后面有用的。
這里其實還看不清,我們把src+gt+threshold map看看。
可以看到:
p的ground truth是標注縮水之后
T的ground truth是文字塊邊緣分別向內向外收縮和擴張
p與T是公式里面的那兩個變量。
再看這個公式與曲線圖:
P和T我們就用ground truth帶入來理解:
p網絡學的文字塊內部,
T網絡學的文字邊緣,兩者計算得到B。
B的ground truth也是標注縮水之后,和p用的同一個。
在實際操作中,作者把除了文字塊邊緣的區域置為0.3.應該就是為了當在非文字區域,
P=0,T=0.3,x=p-T<0這樣拉到負半軸更有利於區分。可以看上面的曲線圖。
同時,作者在論文中也寫了之所以這么做的原因:
首先:
Threshold map本身可以在沒有監督的情況下學到。通過可視化的觀察,發現threshold map會highlight文字區域的邊緣。因此作者利用文字區域的標注對threshold進行監督以獲得更好的結果。如下論文中的圖:
c圖是沒有監督的效果,d是有監督的
其次:求導,更容易區分正負樣本
(b) Derivative of l+ . (c) Derivative of l− .
x=p-T,我們上面討論的
x>0是縮水之后的文字塊內部
x<0是縮水之后的文字塊外部
正負樣本的導數在x>0與x<0處有較大的區別,k=1時區別不大,當k=50時,可以看到放大了這種區別。
作者論文中說可微分二值化的好處:
The differentiable binarization with adaptive thresholds can not only help differentiate text regions
from the background, but also separate text instances which are closely jointed.
數據增強的過程中可能存在bug
我把標簽圖與threshold map圖show出來看,偶爾會異常,由於存在隨機的過程,我沒法復現。此問題已經在github上面向作者提問了,沒有回復。https://github.com/MhLiao/DB/issues/120
也有可能是我把代碼改了一點點導致的,后面有時間再排查。
暫時想到的就是這些了,后面的有想法再補充,歡迎一起討論。
小弟不才,同時謝謝友情贊助!