數據增強
transforms是pytorch中用於數據增強的模塊,首先再簡單描述下數據增強的概念:
數據增強又稱為數據增廣,數據擴增,它是對訓練集進行變換,使訓練集更豐富,從而讓模型根據泛化能力
舉個非常生動形象的例子,五年高考三年模擬相信大家都知道,其實這就是一個學習模型,其中的三年模擬試題就是數據增強
具體的各類方法
裁剪 crop
transforms.CenterCrop
功能:從圖像中心裁剪圖片
size:所需裁剪的圖片尺寸(如果size比原來的圖尺寸大,則會在原圖周圍填充黑色帶至此size)
或者
transforms.RandomCrop
功能:從圖片中隨機裁剪出尺寸為size的圖片
size:所需裁剪圖片尺寸
padding:設置填充大小
- 當為a時,上下左右均填充a個像素
- 當為(a,b)時,上下填充b個像素,左右填充a個像素
- 當為(a,b,c,d)時,左上右下各填充a、b、c、d個像素
pad_if_need:若圖像小於設定size,則填充
padding_mode:填充模式,有4種模式
- constant:像素值由fill設定
- edge:像素值由圖像邊緣像素決定
- reflect:鏡像填充,最后一個像素不鏡像
- symmetric:鏡像填充,最后一個像素鏡像
fill:constant時,設置填充的像素值
transforms.RandomResizedCrop
功能:隨機大小、長寬比裁剪圖片
size:所需裁剪圖片
scale:隨機裁剪面積比例,默認(0.08, 1)
ratio:隨機長寬比,默認(3/4, 4/3)
interpolation:插值方法
- PIL.Image.NERAREST
- PIL.Image.BILINEAR
- PIL.Image.BICUBIC
transforms.FiveCrop
transforms.TenCrop
功能:在圖像的上下左右以及中心裁剪出尺寸為size的5張圖片,
TenCrop對這5張圖片進行水平或者垂直鏡像獲得10張圖片
size:所需裁剪圖片尺寸
vertical_filp:是否垂直翻轉(TenCrop)
翻轉和旋轉 flip and rotation
transforms.RandomHorizontalFlip
transforms.RandomVerticalFlip
功能:依概率水平(左右)或垂直(上下)翻轉圖片
p:翻轉概率
transforms.RandomRotation
功能:隨機旋轉圖片
degrees:旋轉角度
- 當為a時,在(-a, a)之間選擇旋轉角度
- 當為(a, b)時,在(-a, b)之間選擇旋轉角度
resample:重采樣方法
expand:是否擴大圖片,以保持原來所有的圖片信息(注意:如果使用expand,會擴大原來的圖片size,當處理的batch_size不為1時,無法將之后圖片統一形成batch,所以使用后,需要resize成原來的大小)
center:旋轉點設置,默認中心旋轉
圖像變換 image transforms
transforms.Pad
功能:對圖像邊緣進行填充
padding:設置填充大小
- 當為a時,上下左右均填充a個像素
- 當為(a,b)時,上下填充b個像素,左右填充a個像素
- 當為(a,b,c,d)時,左上右下各填充a、b、c、d個像素
padding_mode:填充模式,有4種模式
- constant:像素值由fill設定
- edge:像素值由圖像邊緣像素決定
- reflect:鏡像填充,最后一個像素不鏡像
- symmetric:鏡像填充,最后一個像素鏡像
fill:constant時,設置填充的像素值,(R,G,B)or (Gray)【可以理解為用具體數值表示的color】
transforms.ColorJitter
功能:調整亮度、對比度、飽和度和色相
brightness:亮度調整因子
- 當為a時,從[max(0, 1-a), 1+a]中隨機選擇
- 當為(a, b)時,從[a, b]中隨機選擇
contrast:對比度參數,同brightness
對比度:調高:偏黑的更黑,偏白的更白;調低:發灰,因為黑白都趨近中間值
saturation:飽和度(強調色彩鮮艷或黯淡)參數,同brightness
hue:色相參數,改變圖像原來的色彩,顯得“違和”
- 當為a時,從[-a, a]中隨機選擇參數,注:0 <= a <= 0.5
- 當為(a, b)時,從[a, b]中隨機選擇參數,注:-0.5 <= a <= b <= 0.5
transforms.Grayscale
transforms.RandomGrayscale
功能:依概率將圖像轉換為灰度圖
num_output_channels:輸出通道數,只能設1或3
p:圖像轉為灰度圖的概率,Grayscale即為RandomGrayscale的p為1時
transforms.RandomAffine
功能:對圖像進行仿射變換。仿射變換是二維的線性變換,由五種基本原子構成,分別是旋轉、平移、縮放、錯切和翻轉
degrees:旋轉角度設置
translate:平移區間設置,如(a, b),a設置寬(width),b設置高(height),圖像在寬維度平移的區間為±img_width*a之間
scale:縮放比例(以面積為單位)
fill_color:填充顏色設置
shear:錯切角度設置,有水平錯切和垂直錯切(類似於平行四邊形似的拉伸)
- 若為a,則僅在x軸錯切,錯切角度在(-a,a)之間
- 若為(a,b),則a設置x軸角度,b設置y軸角度
- 若為(a,b,c,d),則a,b設置x軸角度,c,d設置y軸角度
resample:重采樣方式,即三類插值方式
transforms.Erasing
功能:對圖像進行隨機的遮擋,注意:輸入對象為Tensor而不是PIL,需先使用ToTensor()
p:執行該操作的概率
scale:遮擋區域的面積
ratio:遮擋區域的長寬比
value:設置遮擋區域的像素值(顏色),(R,G,B)或(Gray)而且此時對象為tensor,所以應該設置歸一化值;如果設置value的值為字符串,則遮擋區域的顏色會變成彩色(雪花屏)
(具體算法可見論文《Data Augmentation Random Erasing》)
transforms.Lambda
功能:用戶自定義lambda方法,用於簡單實現函數功能
lambd:填寫lambda匿名函數表達式
格式:lambda [arg1 arg2... argn] :expression (arg為input,expression為進行的處理)
對transform方法的操作
transforms.RandomChoice
功能:從一系列transforms方法中隨機挑選一個執行
transforms.RandomChoice([transforms1, transforms2, transforms3...])
transforms.RandomApply
功能:依據概率執行一組transforms操作
transforms.RandomApply([transforms1, transforms2, transforms3...],p=0.5)
transforms.RandomOrder
功能:對一組transforms操作打亂順序后執行
transforms.RandomOrder([transforms1, transforms2, transforms3...])
自定義transforms方法
從compose源碼中可以得到,transforms有一些固定的收參與格式:
1.僅接受一個參數,返回一個參數
2.注意上下游的輸出和輸入
通過類實現多參數傳入,下為自定義方法的基本結構:
其中init函數指定需要的參數,如概率、信噪比等等,call函數就是調用時,所執行的具體操作
下面以椒鹽噪聲的transforms方法自定義實現來為例:
椒鹽噪聲又稱脈沖噪聲,是一種隨機出現的白點或黑點,白點稱為鹽噪聲,黑點稱為椒噪聲
其主要參數有信噪比(signal-noise rate,SNR),用以衡量噪聲的比例,在圖像中為圖像像素的占比
由上述模板可以構建函數思路:
附上代碼可進一步理解
class AddPepperNoise(object):
"""增加椒鹽噪聲
Args:
snr (float): Signal Noise Rate
p (float): 概率值,依概率執行該操作
"""
def __init__(self, snr, p=0.9):
assert isinstance(snr, float) or (isinstance(p, float))
self.snr = snr
self.p = p
def __call__(self, img):
"""
Args:
img (PIL Image): PIL Image
Returns:
PIL Image: PIL image.
"""
if random.uniform(0, 1) < self.p:
img_ = np.array(img).copy()
h, w, c = img_.shape
signal_pct = self.snr
noise_pct = (1 - self.snr)
mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])
mask = np.repeat(mask, c, axis=2)
img_[mask == 1] = 255 # 鹽噪聲
img_[mask == 2] = 0 # 椒噪聲
return Image.fromarray(img_.astype('uint8')).convert('RGB')
else:
return img
總之,數據增強是為了使得訓練集更好的接近測試集,針對測試集的一些特點或是易混點來選擇有效的transforms方法加以突出或是消除,以達到更好的結果。
一個簡單例子,原模型的訓練數據為第四套一元和一百元RMB,而現在測試的是第五套一百元人民幣,如果不加變換,大概率會識別成一元,因為第四套一元和第五套一百元顏色很接近。這時如果做一個灰度變換,測試結果就會識別正確,判定為100元。