https://zhuanlan.zhihu.com/p/342105673
The feature-processing part is fairly easy to understand. For the self- and cross-attention over keypoints, it is recommended to read the source code (MultiHeadedAttention):
import torch
from torch import nn
from copy import deepcopy


def attention(query, key, value):
    # Scaled dot-product attention over per-head channels (d) and keypoints (n, m).
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    prob = torch.nn.functional.softmax(scores, dim=-1)
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivity """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
        self.prob = []  # stores attention maps for later inspection

    def forward(self, query, key, value):
        batch_dim = query.size(0)
        # Project q/k/v with 1x1 convolutions and split the channels into heads.
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value)
        self.prob.append(prob)
        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
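As a quick sanity check, here is a minimal usage sketch of the module above. The shapes (256-d descriptors, 500 keypoints, 4 heads) are my own toy choices, not taken from the post; it only assumes the code block above has been run.

# Minimal usage sketch: 4 heads, 256-d descriptors, 500 keypoints per image.
feat0 = torch.randn(1, 256, 500)   # (batch, d_model, num_keypoints) for image A
feat1 = torch.randn(1, 256, 500)   # same layout for image B
attn = MultiHeadedAttention(num_heads=4, d_model=256)
out_self = attn(feat0, feat0, feat0)    # self-attention: query/key/value all from image A
out_cross = attn(feat0, feat1, feat1)   # cross-attention: query from A, key/value from B
print(out_self.shape, out_cross.shape)  # both torch.Size([1, 256, 500])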
Here we jump straight to the final matching logic. The paper describes this part only roughly, and you need to read the source code to understand what it is actually saying (perhaps some experts don't need to).
Look here: the matches extracted at inference time do not take the last row and last column (the dustbins) into account at all; instead, a threshold is set and unqualified matches are filtered out:
# Get the matches with score above "match_threshold".
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)  # drop dustbin row/col, argmax per row / per column
indices0, indices1 = max0.indices, max1.indices
# Mutual check: i's best match j must also have i as its best match.
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)  # [0,0...,1,..0]
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
zero = scores.new_tensor(0)
mscores0 = torch.where(mutual0, max0.values.exp(), zero)  # scores are log-probabilities, exp() maps back to [0, 1]
mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
valid1 = mutual1 & valid0.gather(1, indices1)
indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))  # -1 marks "no match"
indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
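To see what this does, here is a self-contained toy version of the same mutual-check-plus-threshold logic on a 3x4 log-score matrix. The arange_like helper is the one from the SuperGlue repo; the numbers and threshold are made up for illustration.

import torch

def arange_like(x, dim: int):
    # Same helper as in the SuperGlue repo: [0, 1, ..., x.shape[dim]-1] on x's device/dtype.
    return x.new_ones(x.shape[dim]).cumsum(0) - 1

# Toy log-scores for 3 keypoints in A vs 4 keypoints in B (dustbin row/col already removed).
scores = torch.log(torch.tensor([[[0.80, 0.05, 0.05, 0.10],
                                  [0.10, 0.70, 0.10, 0.10],
                                  [0.40, 0.05, 0.45, 0.10]]]))
match_threshold = 0.5

max0, max1 = scores.max(2), scores.max(1)
indices0, indices1 = max0.indices, max1.indices            # row-wise / column-wise argmax
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
zero = scores.new_tensor(0)
mscores0 = torch.where(mutual0, max0.values.exp(), zero)
valid0 = mutual0 & (mscores0 > match_threshold)
matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
print(matches0)  # tensor([[ 0,  1, -1]]): keypoint 2 is a mutual match but falls below the threshold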
The inference-side code is as follows. As you can see, image A's and image B's match results (the argmax index taken along the rows vs. along the columns of the score matrix) need not be entirely consistent with each other, which is exactly why the mutual check above exists:
kpts0, kpts1 = pred['keypoints0'].cpu().numpy()[0], pred['keypoints1'].cpu().numpy()[0]
matches, conf = pred['matches0'].cpu().detach().numpy(), pred['matching_scores0'].cpu().detach().numpy()
image0 = read_image_modified(image0, opt.resize, opt.resize_float)
image1 = read_image_modified(image1, opt.resize, opt.resize_float)
valid = matches > -1
mkpts0 = kpts0[valid]
mkpts1 = kpts1[matches[valid]]
mconf = conf[valid]
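A tiny made-up example of how matches0 is consumed (the keypoints, match indices, and confidences below are invented purely for illustration):

import numpy as np

# Suppose image A has 4 keypoints and the network matched them as follows
# (-1 means "no match"); all values are fabricated for illustration.
kpts0 = np.array([[10., 10.], [20., 15.], [30., 40.], [55., 60.]])
kpts1 = np.array([[12., 11.], [31., 39.], [80., 80.]])
matches = np.array([0, -1, 1, -1])        # A's keypoint i -> B's keypoint matches[i]
conf = np.array([0.92, 0.00, 0.71, 0.00])

valid = matches > -1
mkpts0 = kpts0[valid]                     # matched keypoints in A: rows 0 and 2
mkpts1 = kpts1[matches[valid]]            # their partners in B: rows 0 and 1
mconf = conf[valid]
print(mkpts0, mkpts1, mconf, sep='\n')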
Next, look at the part that solves for the assignment matrix. couplings is the similarity score matrix with an extra last row and column appended. Note that the (m+n+1) entries of this extra row and column all map to the same single memory location (one scalar tensor expanded into views), whose initial value is 1 and which is learnable. Under the constraints stated in the paper, the Sinkhorn algorithm (worth studying separately) is used to solve for the assignment matrix Z:
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    # Z: b(m+1)(n+1), log_mu: b(m+1), log_nu: b(n+1)
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)  # row scaling, shape b(m+1)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)  # column scaling, shape b(n+1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)


def log_optimal_transport(scores, alpha, iters: int):
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    ms, ns = (m*one).to(scores), (n*one).to(scores)

    bins0 = alpha.expand(b, m, 1)  # only a new view, no copy
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)

    # b(m+1)(n+1); the extra row and column are filled with alpha
    couplings = torch.cat([torch.cat([scores, bins0], -1),     # bmn, bm1 -> bm(n+1)
                           torch.cat([bins1, alpha], -1)], 1)  # b1n, b11 -> b1(n+1)

    norm = - (ms + ns).log()
    log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])  # m+1: [-log(m+n), ..., log(n)-log(m+n)]
    log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])  # n+1: [-log(m+n), ..., log(m)-log(m+n)]
    log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)  # b(m+1), b(n+1)

    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
    Z = Z - norm  # multiply probabilities by M+N
    return Z
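A quick numerical check of what Sinkhorn produces, assuming the two functions above are in scope (the sizes and iteration count are arbitrary): after the M+N rescaling at the end, each real row/column of exp(Z) should sum to approximately 1, while the dustbin row sums to n and the dustbin column to m, which is exactly the mass-conservation condition discussed below.

torch.manual_seed(0)
b, m, n = 1, 5, 7
scores = torch.randn(b, m, n)                     # fake similarity scores
bin_score = torch.nn.Parameter(torch.tensor(1.))  # the learnable dustbin score, initialised to 1
Z = log_optimal_transport(scores, bin_score, iters=100)

P = Z.exp()      # (m+1) x (n+1) transport plan, scaled by m+n
print(P.sum(2))  # ~[1, 1, 1, 1, 1, 7]: each real row sums to 1, the dustbin row to n
print(P.sum(1))  # ~[1, 1, 1, 1, 1, 1, 1, 5]: each real column sums to 1, the dustbin column to m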
The loss function simply maximizes the entries of this assignment matrix Z (the scores matrix below) at the ground-truth match positions. The matched pairs themselves surely do not contain dustbin points, which means the handling of the dustbins is embedded inside the Sinkhorn step. Note the argument self.bin_score passed in the call below: it is a learnable parameter of the SuperGlue network:
all_matches = data['all_matches'].permute(1, 2, 0)  # shape=torch.Size([1, 87, 2])
……
# Run the optimal transport.
scores = log_optimal_transport(
    scores, self.bin_score,
    iters=self.config['sinkhorn_iterations'])
……
# check if indexed correctly
loss = []
for i in range(len(all_matches[0])):
    x = all_matches[0][i][0]
    y = all_matches[0][i][1]
    loss.append(-torch.log(scores[0][x][y].exp()))  # check batch size == 1 ?
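Since the scores returned by log_optimal_transport are already in log-space, -torch.log(scores[0][x][y].exp()) is just -scores[0][x][y]. A vectorized equivalent of the loop, assuming batch size 1 and that all_matches holds (row, col) index pairs as above, might look like this sketch:

# all_matches[0]: (num_gt_matches, 2) tensor of (row, col) indices into the (m+1) x (n+1) scores.
gt = all_matches[0].long()
loss = -scores[0, gt[:, 0], gt[:, 1]]  # negative log-likelihood of each ground-truth assignment
loss_mean = loss.mean()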
The loss function itself is easy to understand. Judging from the formula, the match indices in all_matches above should presumably also include the unmatched points (I am not sure about this), e.g., keypoint i being "matched" to the dustbin index N+1; otherwise the last two terms of the loss function would never come into play:
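Since the screenshot of the loss is not reproduced here, this is the negative log-likelihood loss as I recall it from the SuperGlue paper, where M is the set of ground-truth matches and I, J are the unmatched keypoints of images A and B:

$$\text{Loss} = -\sum_{(i,j)\in\mathcal{M}} \log \bar{\mathbf{P}}_{i,j} \;-\; \sum_{i\in\mathcal{I}} \log \bar{\mathbf{P}}_{i,N+1} \;-\; \sum_{j\in\mathcal{J}} \log \bar{\mathbf{P}}_{M+1,j}$$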
The constraints on the augmented assignment matrix in the original paper are as follows:
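The screenshot of the equation is not included here; reproduced from the SuperGlue paper in its own notation, the constraint reads:

$$\bar{\mathbf{P}}\,\mathbf{1}_{N+1} = \mathbf{a} \quad \text{and} \quad \bar{\mathbf{P}}^{\top}\mathbf{1}_{M+1} = \mathbf{b}, \qquad \text{where } \mathbf{a} = \left[\mathbf{1}_M^{\top} \;\; N\right]^{\top},\; \mathbf{b} = \left[\mathbf{1}_N^{\top} \;\; M\right]^{\top}$$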
My personal understanding is that the N and M attached to a and b here are probably typos; what they represent are the unmatched points of images A and B. As far as the Sinkhorn algorithm is concerned, it can be applied as long as the mass-conservation condition is satisfied, so the author pads the assignment matrix P̄ with the extra row and column, and correspondingly the cost matrix S has to be padded as well. To pad S, the author uses the rather odd construction mentioned in the code walkthrough above: a single scalar tensor whose memory is mapped (via expand) onto an entire extra row and column. The paper's explanation of this is very brief; it feels like the author simply tried padding the cost matrix this way, found that it works well, and there is no deeper reasoning behind it.
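A small sketch of why "one value mapped onto a whole row and column" works (the variable names are mine, not from the repo): expand creates views that share the scalar's storage, so every dustbin entry is literally the same learnable number, and the gradients flowing into all of those positions accumulate back into that single parameter.

import torch

bin_score = torch.nn.Parameter(torch.tensor(1.))  # the single learnable dustbin value
b, m, n = 1, 3, 4
bins0 = bin_score.expand(b, m, 1)                 # extra column, a view: no new memory
bins1 = bin_score.expand(b, 1, n)                 # extra row, also a view
print(bins0.data_ptr() == bin_score.data_ptr())   # True: same underlying storage

scores = torch.randn(b, m, n)
couplings = torch.cat([torch.cat([scores, bins0], -1),
                       torch.cat([bins1, bin_score.expand(b, 1, 1)], -1)], 1)
couplings.sum().backward()
print(bin_score.grad)                             # m + n + 1 = 8: one gradient contribution per dustbin entry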
As an aside, this S really is not easy to pad: the keypoint ordering in every image pair is completely random, so giving each position in the extra row/column its own variable would not make much sense either. Using one shared variable instead actually gives it a "threshold"-like flavor, even though at inference time the extra row and column of the computed assignment matrix P̄ are simply discarded.
Correspondingly, the constraint on P (before the dustbins are added) is easy to understand:
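Again reproducing the paper's formulation from memory, since the image is not included: P is a partial assignment, so

$$\mathbf{P} \in [0,1]^{M \times N}, \qquad \mathbf{P}\,\mathbf{1}_N \le \mathbf{1}_M \quad \text{and} \quad \mathbf{P}^{\top}\mathbf{1}_M \le \mathbf{1}_N$$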