pytorch的數據增強功能並非是事先對整個數據集進行數據增強處理,而是在從dataloader中獲取訓練數據的時候(獲取每個epoch的時候)才進行數據增強。
舉個例子,如下面的數據增強代碼:
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 對圖像四周各填充4個0像素,然后隨機裁剪成32*32
transforms.RandomHorizontalFlip(), # 按0.5的概率水平翻轉圖片
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
假設數據集一共有100張圖片,pytorch並非對數據集中的每張圖片進行隨機裁剪,再隨機翻轉,將數據集擴增到200張,然后用這固定的200張圖來訓練網絡,這是錯誤的理解。
正確的理解應該是dataloader在每次生成epoch時才對數據集進行以上數據增強操作。由於數據增強有些操作是具有隨機性的(例如上面的隨機裁剪和隨機翻轉),導致每次epoch產生的數據都不相同,例如同一張圖片在有的epoch翻轉了,在有的epoch沒有翻轉,或者同一張圖片在各個epoch裁剪的位置不一樣,所以每次用來訓練的數據不相同,到達了數據增強的目的。
當然,有些數據增強操作不具有隨機性,如CenterCrop,每次都是對圖片中間位置進行裁剪,不管在哪個epoch,裁剪出來的圖片都一樣。
————————————————
版權聲明:本文為CSDN博主「Mr.Jcak」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/weixin_38314865/article/details/104318112