在訓練過程中,往往會遇到中斷,如在
Colab和Kaggle中,由於網絡不穩定,很容易就斷開了連接。然而,即使可以穩定訓練,但是訓練的時長往往是有上限的,此時我們的網絡參數訓練的可能還未收斂仍然需要訓練,所以,應該加載原訓練基礎上再進行訓練是十分很重要的。比如,要訓練1000代才能收斂,但是目前只訓練的100代就中斷了,所以要加載第100代訓練的模型參數,然后訓練接下來的900代
pytorch模型的保存機制
修改訓練代碼
中斷的訓練代碼最簡單的修改方式便是復制一份訓練的代碼,然后在其基礎上進行修改,涉及到最重要的部分就是模型的保存與加載
🅰若優化器optimizer不需要隨着訓練的修改,那么直接加載模型、優化器,之后進行訓練即可
🅱若優化器需要訓練,那么可以進行一下修改:
if epoch == epochs_g + 1:
optimizer_r.load_state_dict(checkpoint_r['optimizer'])
optimizer_g.load_state_dict(checkpoint_g['optimizer'])
lr_r = checkpoint_r['lr']
lr_g = checkpoint_g['lr']
else:
optimizer_r = optim.Adagrad(model_r.parameters(), lr = lr_r, weight_decay = 1e-5)
optimizer_g = optim.Adagrad(model_g.parameters(), lr = lr_g, weight_decay = 1e-5)
- 繼續訓練的第一次是利用模型保存下來的,而之后則是修改的優化器
如:我的模型每訓練50次進行learning rate減半,初始學習率為0.001,而我的模型訓練到第40代中斷,所以加載第40代模型繼續進行訓練
python "train_continue.py" --pre_model_r './LapSRN_r_epoch_40.pt' --pre_model_g './LapSRN_g_epoch_40.pt' --nEpochs 60 --cuda --batchSize 1 --dataset "../../DataSet_test/"
可以看看優化器的變化如下:
Namespace(batchSize=1, cuda=True, dataset='../../DataSet_test/', lr=0.001, nEpochs=60, pre_model_g='./LapSRN_g_epoch_40.pt', pre_model_r='./LapSRN_r_epoch_40.pt', save_models='./', save_train_csv='./train.csv', save_val_csv='/val.csv', seed=123, valBatchSize=1)
===> Loading datasets
===> Loading pre_train model and Building model
Adagrad (
Parameter Group 0
eps: 1e-10
initial_accumulator_value: 0
lr: 0.001
lr_decay: 0
weight_decay: 1e-05
)
===> Epoch 41 Complete: Avg. Loss: 0.0381
===> Avg. PSNR1: 26.2686 dB
===> Avg. PSNR2: 25.1278 dB
Adagrad (
Parameter Group 0
eps: 1e-10
initial_accumulator_value: 0
lr: 0.001
lr_decay: 0
weight_decay: 1e-05
)
===> Epoch 42 Complete: Avg. Loss: 0.0789
===> Avg. PSNR1: 13.8764 dB
===> Avg. PSNR2: 16.7824 dB
.........省略部分..........
Adagrad (
Parameter Group 0
eps: 1e-10
initial_accumulator_value: 0
lr: 0.001
lr_decay: 0
weight_decay: 1e-05
)
===> Epoch 49 Complete: Avg. Loss: 0.0749
===> Avg. PSNR1: 25.5121 dB
===> Avg. PSNR2: 25.1218 dB
Adagrad (
Parameter Group 0
eps: 1e-10
initial_accumulator_value: 0
lr: 0.001
lr_decay: 0
weight_decay: 1e-05
)
===> Epoch 50 Complete: Avg. Loss: 0.0877
===> Avg. PSNR1: 28.2393 dB
===> Avg. PSNR2: 26.6869 dB
Checkpoint saved to ./LapSRN_r_epoch_50.pt and ./LapSRN_g_epoch_50.pt
Adagrad (
Parameter Group 0
eps: 1e-10
initial_accumulator_value: 0
lr: 0.0005
lr_decay: 0
weight_decay: 1e-05
)
===> Epoch 51 Complete: Avg. Loss: 0.2914
===> Avg. PSNR1: 27.3521 dB
===> Avg. PSNR2: 25.3298 dB
Adagrad (
Parameter Group 0
eps: 1e-10
initial_accumulator_value: 0
lr: 0.0005
lr_decay: 0
weight_decay: 1e-05
)
===> Epoch 52 Complete: Avg. Loss: 0.0505
===> Avg. PSNR1: 21.9110 dB
===> Avg. PSNR2: 21.8041 dB
Adagrad (
Parameter Group 0
eps: 1e-10
initial_accumulator_value: 0
lr: 0.0005
lr_decay: 0
weight_decay: 1e-05
)
樣例學習
中斷訓練train_continue.py代碼如下,可供參考學習:
