CutMix
CutMix是在論文《CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features》被提出的數據增強方式,常用於分類任務和檢測任務。
什么是CutMix
Cut指切割出圖片中的一小塊,MIx指將這一小塊貼到其他圖片中,並且label也會進行混合。
從下圖可以看出CutMix對模型分類准確率和定位准確率有明顯的提升。

CutMix的操作可以用如下公式表示:
\[\begin{align} \bar x &= M \odot x_A + (1-M)\odot x_B \\ \bar y &= \lambda y_A + (1-\lambda)y_B \end{align}\]
其中的符號解釋如下:
- \(M\)是一個二值Mask。對於\(x_A\),\(M=1\)部分的圖像會被保留。對於\(x_B\),\(M=0\)的部分會被保留
- \(x_A,x_B\)分別是兩張圖片
- \(y_A,y_B\)是對應的label
- \(\bar x ,\bar y\)是CutMix后的圖像
- \(\odot\)表示按元素相乘
- \(\lambda\)和Mixup中的一樣,服從\(\beta(\alpha,\alpha)\)分布(論文中設置\(\alpha\)為1)
Mask的生成
\(M\)的取值是隨機生成一個bounding box來得到的,這個bbox的參數為\(B=(r_x,r_y,r_w,r_y)\),通過下面公式計算得到
\[r_x \sim \text{Unif}(0,W) \\ r_y \sim \text{Unif}(0,H) \\ r_w=W\sqrt{1-\lambda} \\ r_h=H\sqrt{1-\lambda}\]
\(M\)這個矩陣的大小和圖像一樣,bbox內的值為0,其他值為1。
label的融合
當前圖片內容在融合后面積的占比決定了label的值,假設分別用兩張圖的30%和70%融合在一起,原始label分別是\([1,0]\)和\([0,1]\),則融合label為\([0.3,0.7]\)
從上面公式可以計算出生成的bbox大小為
\[r_w*r_h= W\sqrt{1-\lambda}*H\sqrt{1-\lambda} =WH(1-\lambda)\]
bbox和原圖的面積比例就為
\[WH(1-\lambda)/(WH) = 1-\lambda \]
從公式(1)可以看出圖A保留了bbox以外的部分,因此\(y_A\)的系數為\(\lambda\)。
代碼實現
代碼實現中有一些不同的是,生成bbox的中心點是在全圖范圍隨機,如果中心點靠近圖像邊緣,那么bbox的面積和原圖的比可能就不是\(1-\lambda\)。因此這個面積比例是重新計算的。
圖像之間的對應關系是隨機的,有可能對應到自己本身,就不會進行cutmix,多執行幾次能看到效果。
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['figure.figsize'] = [10, 10]
import cv2
def rand_bbox(size, lamb):
"""
生成隨機的bounding box
:param size:
:param lamb:
:return:
"""
W = size[0]
H = size[1]
# 得到一個bbox和原圖的比例
cut_ratio = np.sqrt(1.0 - lamb)
cut_w = int(W * cut_ratio)
cut_h = int(H * cut_ratio)
# 得到bbox的中心點
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
def cutmix(image_batch, image_batch_labels, alpha=1.0):
# 決定bbox的大小,服從beta分布
lam = np.random.beta(alpha, alpha)
# permutation: 如果輸入x是一個整數,那么輸出相當於打亂的range(x)
rand_index = np.random.permutation(len(image_batch))
# 對應公式中的y_a,y_b
target_a = image_batch_labels
target_b = image_batch_labels[rand_index]
# 根據圖像大小隨機生成bbox
bbx1, bby1, bbx2, bby2 = rand_bbox(image_batch[0].shape, lam)
image_batch_updated = image_batch.copy()
# image_batch的維度分別是 batch x 寬 x 高 x 通道
# 將所有圖的bbox對應位置, 替換為其他任意一張圖像
# 第一個參數rand_index是一個list,可以根據這個list里索引去獲得image_batch的圖像,也就是將圖片亂序的對應起來
image_batch_updated[:, bbx1: bbx2, bby1:bby2, :] = image_batch[rand_index, bbx1:bbx2, bby1:bby2, :]
# 計算 1 - bbox占整張圖像面積的比例
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1)) / (image_batch.shape[1] * image_batch.shape[2])
# 根據公式計算label
label = target_a * lam + target_b * (1. - lam)
return image_batch_updated, label
if __name__ == '__main__':
cat = cv2.cvtColor(cv2.imread("data/neko.png"), cv2.COLOR_BGR2RGB)
dog = cv2.cvtColor(cv2.imread("data/inu.png"), cv2.COLOR_BGR2RGB)
updated_img, label = cutmix(np.array([cat, dog]), np.array([[0, 1], [1, 0]]), 0.5)
print(label)
fig, axs = plt.subplots(nrows=1, ncols=2, squeeze=False)
ax1 = axs[0, 0]
ax2 = axs[0, 1]
ax1.imshow(updated_img[0])
ax2.imshow(updated_img[1])
plt.show()
