之前一直自己手寫各種triphard,triplet損失函數, 寫的比較暴力,然后今天一個學長給我在github上看了一個別人的triphard的寫法,一開始沒看懂,用的pytorch函數沒怎么見過,看懂了之后, 被驚艷到了。。因此在此記錄一下,以及詳細注釋一下
class TripletLoss(nn.Module):
def __init__(self, margin=0.3):
super(TripletLoss, self).__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin) # 獲得一個簡單的距離triplet函數
def forward(self, inputs, labels):
n = inputs.size(0) # 獲取batch_size
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) # 每個數平方后, 進行加和(通過keepdim保持2維),再擴展成nxn維
dist = dist + dist.t() # 這樣每個dis[i][j]代表的是第i個特征與第j個特征的平方的和
dist.addmm_(1, -2, inputs, inputs.t()) # 然后減去2倍的 第i個特征*第j個特征 從而通過完全平方式得到 (a-b)^2
dist = dist.clamp(min=1e-12).sqrt() # 然后開方
# For each anchor, find the hardest positive and negative
mask = labels.expand(n, n).eq(labels.expand(n, n).t()) # 這里dist[i][j] = 1代表i和j的label相同, =0代表i和j的label不相同
dist_ap, dist_an = [], []
for i in range(n):
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) # 在i與所有有相同label的j的距離中找一個最大的
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) # 在i與所有不同label的j的距離找一個最小的
dist_ap = torch.cat(dist_ap) # 將list里的tensor拼接成新的tensor
dist_an = torch.cat(dist_an)
# Compute ranking hinge loss
y = torch.ones_like(dist_an) # 聲明一個與dist_an相同shape的全1tensor
loss = self.ranking_loss(dist_an, dist_ap, y)
return loss