點藍色字關注“機器學習算法工程師”
設為星標,干貨直達!
對於實例分割來說,主流的做法還是基於先檢測后分割的流程,比如最流行的Mask RCNN模型就是構建在Faster RCNN基礎上。目前基於one-stage的物體檢測模型已經在速度和性能上超越two-stage模型,同樣地,大家也希望能找到one-stage的實例分割模型來替換Mask RCNN。目前這方面的工作主要集中在三個方向:
- Mask encoding:對2D mask編碼為1D representation,比如PolarMask基於輪廓構建了polar representation,而MEInst則將mask壓縮成一個1D vector,這樣預測mask就類似於box regress那樣直接加在one-stage檢測模型上;
- 分離檢測和分割:將檢測和分割分離成兩個部分這樣可以並行化,如YOLACT在檢測模型基礎上額外預測了一系列prototype masks,然后檢測部分每個instance會預測mask coeffs來組合masks來產生instance mask,BlendMask是對這一工作的進一步改進;
- 不依賴檢測的實例分割:不依賴檢測框架直接進行實例分割,TensorMask和SOLO屬於此種類型,前者速度太慢,后者速度和效果都非常好;
對於mask encoding方法,雖然實現起來比較容易,但是往往會造成2D mask的細節損失,所以性能上會差一點;分離檢測和分割,對於分割部分可以像語義分割那樣預測global mask,分辨率上會更高(要知道Mask RCNN的mask分辨率只有28x28),但是這種方法需要一種好的方式來產生instance mask;不依賴檢測而直接進行實例分割這可能是未來的趨勢。這里介紹的CondInst,其實屬於第二種,但是它與YOLACT不同,其核心點是檢測部分為每個instance預測不同的mask head,然后基於global mask features來產生instance mask,思路非常簡單,而且實現起來也極其容易(已經開源在AdelaiDet),更重要的是速度和效果上均超越Mask RCNN。
整體結構
CondInst是構建在物體檢測模型FCOS之上的(CondInst和FCOS是同一個作者),所以理解CondInst必須先理解FCOS,可以參考之前關於FCOS的介紹文章(FCOS),但其實CondInst也可以依賴其他的one-stage模型,CondInst整體結構如下圖所示:
相比FCOS,CondInst多了一個mask branch,其得到的mask features將作為mask FCN的輸入來生成最終的instance mask,這個mask features來自於P3,所以大小是輸入圖像的1/8。另外在FCOS的檢測部分增加了controller head(實際上controller head是直接加在box head上的),用來產生每個instance的mask head網絡的參數。這個思想其實是CondConv,傳統的Conv訓練完成后是固定的filters,但是CondConv的filters基於input和一個另外的網路來動態產生的。CondInst用來controller head生成instance-aware的mask FCN head,每個instance都有自己獨有的mask head,instance的形狀和大小等信息都編碼在其中。所以當mask head作用在global mask features上時,就可以區分當前的instance和其它背景信息,從而預測出instance mask。
這樣CondInst就可以實現實例分割了,CondInst的正負樣本策略和FCOS一樣,都是通過center region sampling方式來決定正負樣本,其訓練的loss相比FCOS增加intance mask的loss,這個loss也只計算正樣本部分:
Mask Branch
CondInst的mask branch就和語義分割類似是一個FCN網絡,包括4個channel為128的3x3卷積,然后最后接一個channel為8的1x1卷積。mask branch輸入為FPN的P3特征,所以最終產生的特征為原始輸入圖像的1/8,特征channel為8,之所以用一個較小的channel是為了減少controller head所需生成的參數量,而且實驗中發現采用較小的channel就夠了,當channel為2時mask AP僅掉了0.3%。不過從開源的代碼來看,mask branch的輸入應該是來自於FPN的P3,P4和P5,具體實現上先將P4和P5的特征通過雙線性插值,然后和P3加到一起作為mask branch的輸入。就像YOLACT一樣,mask branch產生的特征還可以額外加上語義分割的loss來進行輔助,這個不會影響inference過程,但是實驗上mask AP大約可以提升1個點,具體實現上如下:
# 額外的語義loss,采用focal loss
if self.training and self.sem_loss_on:
logits_pred = self.logits(self.seg_head(
features[self.in_features[0]]
)) # 預測logits,區分class
# 計算語義分割的gt,這里的原則是合並instance的gt mask,但是當不同instance有重疊時,會取面積最小的instance的class作為gt
semantic_targets = []
for per_im_gt in gt_instances:
h, w = per_im_gt.gt_bitmasks_full.size()[-2:]
areas = per_im_gt.gt_bitmasks_full.sum(dim=-1).sum(dim=-1)
areas = areas[:, None, None].repeat(1, h, w)
areas[per_im_gt.gt_bitmasks_full == 0] = INF
areas = areas.permute(1, 2, 0).reshape(h * w, -1)
min_areas, inds = areas.min(dim=1)
per_im_sematic_targets = per_im_gt.gt_classes[inds] + 1
per_im_sematic_targets[min_areas == INF] = 0
per_im_sematic_targets= per_im_sematic_targets.reshape(h, w)
semantic_targets.append(per_im_sematic_targets)
semantic_targets = torch.stack(semantic_targets, dim=0) # [N, 1, H, W]
# 對gt進行降采樣,為原始的1/8
semantic_targets = semantic_targets[:, None, self.out_stride // 2::self.out_stride, self.out_stride // 2::self.out_stride]
# one-hot gt
num_classes = logits_pred.size(1)
class_range = torch.arange(num_classes, dtype=logits_pred.dtype, device=logits_pred.device)[None, :, None, None]
class_range = class_range + 1
one_hot = (semantic_targets == class_range).float()
num_pos = (one_hot > 0).sum().float().clamp(min=1.0)
# 采用focal loss
loss_sem = sigmoid_focal_loss_jit(
logits_pred, one_hot,
alpha=self.focal_loss_alpha,
gamma=self.focal_loss_gamma,
reduction="sum",
) / num_pos
losses['loss_sem'] = loss_sem
return mask_feats, losses
Controller Head
前面說過,CondInst的核心就在於controller head,其用來產生mask head的網絡參數,這個參數是每個instance所獨有的,所以當輸入為全局mask特征時,可以預測出instance mask。由於controller head會編碼instance的形狀和大小信息,所以它是直接加在FCOS的box head上的,就和centerness head一樣。
controller head的輸出channel數為N,恰好是mask head的網絡參數量。mask head采用一個輕量級的FCN網絡,包含三個channel為8的3x3卷積層,卷積之后接ReLU,最后一層卷積直接加上sigmoid(二分類)就可以預測instance mask。所以mask head的參數量N為169:(#weights = (8 + 2) × 8(conv1) + 8 × 8(conv2) + 8 × 1(conv3) and #biases = 8(conv1) + 8(conv2) + 1(conv3))。這里的輸入channel是8+2,而不是8,是因為送入mask head的輸入除了包括,還包含relative coordinates maps,即相對於當前instance的位置(x,y)的相對位置坐標,在實現上只需要把x,y的relative coordinates maps與拼接在一起即可,如果去掉相對位置maps,CondInst性能下降比較厲害,其實也合理,因為controller head產生的mask head參數是由CNN得到的,它雖然可以編碼instance的shape信息,但是不能准確地學習到instance在圖像中的位置,所以加上relative coordinates maps對准確地分割當前instance比較重要。另外論文中的一個有趣的實驗是mask head只輸入relative coordinates maps就可以得到31.3%的mask AP值,這或許說明controller head產生的mask head足夠強大。下面是對controller head的輸出進行結構化解析的代碼:
def parse_dynamic_params(params, channels, weight_nums, bias_nums):
assert params.dim() == 2
assert len(weight_nums) == len(bias_nums)
assert params.size(1) == sum(weight_nums) + sum(bias_nums)
num_insts = params.size(0)
num_layers = len(weight_nums)
params_splits = list(torch.split_with_sizes(
params, weight_nums + bias_nums, dim=1
))
weight_splits = params_splits[:num_layers]
bias_splits = params_splits[num_layers:]
for l in range(num_layers):
if l < num_layers - 1:
# out_channels x in_channels x 1 x 1
weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1)
bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
else:
# out_channels x in_channels x 1 x 1
weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1)
bias_splits[l] = bias_splits[l].reshape(num_insts)
return weight_splits, bias_splits
由於輸入的是1/8圖片大小,所以產生的instance mask也是1/8圖片大小,為了產生高質量的mask,采用雙線性插值對預測的instance mask上采樣4x,那么最終輸出的instance mask就是1/2圖片大小。mask部分的loss采用的是dice loss,它和focal loss一樣可以解決正負樣本不均衡問題,在計算時將gt mask下采樣2x以和預測mask達到同樣的大小。
Inference
CondInst的inference就比較直接了,首先是檢測部分得到檢測的結果,然后采用box-based NMS來去除重復框,最后選出top 100的檢測框,只有這部分instances會進行instance mask的預測。由於產生的mask head非常小,所以100個instance的mask預測時間只需要4.5ms,那么CondInst的預測時間僅比原始的FCOS增加了約10%。這里額外要說的一點是CondInst的box預測主要用於NMS,但不會參與instance mask的預測中,而Mask R-CNN是需要box來進行ROI croping。CondInst和其它實例分割在COCO上的效果對比如下:
此外,CondInst的作者近期又發布了一篇新的不錯的工作:BoxInst,只用box級別的標注就可以訓練出一個不錯的實例分割模型,這個模型也是構建在CondInst上,只不過設計了兩個新的loss來進行半監督式的訓練。最后放一個BoxInst的一個分割視頻demo:
參考
- Conditional Convolutions for Instance Segmentation
- AdelaiDet
- BoxInst: High-Performance Instance Segmentation with Box Annotations