在模型訓練過程中,一個 epoch 指遍歷一遍訓練集,而一般的模型訓練也是指定多少個 epoch,每個 epoch 結束后看看模型在驗證集上的效果並保存模型。
但在有些場景下,如半監督學習,有標記的樣本很少,一個 epoch 甚至只有一個 batch 的數據,這個時候頻繁查看驗證集效果很耗時。
當數據集很小時,訓練多久用 epoch 表示不太合適,這個時候使用模型更新次數來表示更加合理,每多少個 steps 查看一次驗證集效果並保存模型。
我們可以通過給 DataLoader 傳入一個重復采樣的隨機采樣器 RandomSampler 來實現這個功能,其它代碼和按照 epoch 訓練一致。
# batch_size = 64
# steps_to_save = 1024,每 1024 個 steps 查看驗證集效果並保存模型,相當於一個 epoch 有 1024 個 steps,只是數據有重復罷了。
trainloader = DataLoader(
dataset,
sampler=torch.utils.data.RandomSampler(
dataset,
replacement=True,
num_samples=64*steps_to_save),
batch_size=64,
num_workers=4)