A concise, easy-to-use PyTorch training template
Code: https://github.com/KinglittleQ/Pytorch-Template
How to use
1) Modify template.py
Replace the contents of the __init__ method with your own model, optimizer, evaluation metric, and so on.
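import tensorboardX as tX  # tX.SummaryWriter logs scalars for TensorBoard


class Model:
    def __init__(self, args):
        # log_dir=None defaults to runs/<datetime>_<hostname>
        self.writer = tX.SummaryWriter(log_dir=None, comment='')
        self.train_logger = None  # optional
        self.eval_logger = None   # optional
        self.args = args          # optional
        self.step = 0             # global training step counter
        self.epoch = 0            # epoch counter
        self.best_error = float('Inf')  # best evaluation error seen so far
        self.model = None         # your nn.Module
        self.optimizer = None     # e.g. torch.optim.Adam(...)
        self.criterion = None     # training loss, e.g. nn.CrossEntropyLoss()
        self.metric = None        # evaluation metric
        self.train_loader = None
        self.test_loader = None
        self.device = None        # e.g. torch.device('cuda')
        self.ckpt_dir = None      # directory where checkpoints are saved
        self.log_per_step = None  # log the training loss every N steps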
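To show how these fields might work together, here is a minimal sketch of a filled-in template. It is not the repository's code: the tiny linear classifier, the synthetic data standing in for a real dataset, the error metric, and the bodies of `train` and `eval` are all assumptions; only the attribute names come from the snippet above. The writer is left as `None` so the sketch runs without tensorboardX, and `train` bumps the epoch counter first so the first pass logs as "epoch 1", as in the sample log further down.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class Model:
    """Hypothetical filled-in template: a linear classifier on synthetic data."""

    def __init__(self, args=None):
        self.writer = None  # the real template uses tX.SummaryWriter(); None keeps this runnable
        self.args = args
        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')
        self.device = torch.device('cpu')
        self.model = nn.Linear(4, 3).to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.criterion = nn.CrossEntropyLoss()
        # Hypothetical metric: fraction of misclassified samples in a batch
        self.metric = lambda logits, y: (logits.argmax(dim=1) != y).float().mean().item()
        torch.manual_seed(0)
        x = torch.randn(256, 4)
        y = x[:, :3].argmax(dim=1)  # synthetic labels a linear model can learn
        self.train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(x, y), batch_size=64)
        self.ckpt_dir = None
        self.log_per_step = 4

    def train(self):
        """Run one epoch over train_loader, logging the loss every log_per_step steps."""
        self.epoch += 1  # 1-based epoch index
        self.model.train()
        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)
            loss = self.criterion(self.model(x), y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.step += 1
            if self.step % self.log_per_step == 0:
                print(f'epoch {self.epoch} step {self.step} loss {loss.item():.3g}')
                if self.writer is not None:
                    self.writer.add_scalar('train/loss', loss.item(), self.step)

    @torch.no_grad()
    def eval(self):
        """Average the metric over test_loader and keep track of the best error."""
        self.model.eval()
        errors = [self.metric(self.model(x.to(self.device)), y.to(self.device))
                  for x, y in self.test_loader]
        error = sum(errors) / len(errors)
        self.best_error = min(self.best_error, error)
        print(f'epoch {self.epoch} error {error:.4g}')
        return error
```

Keeping `train` to a single epoch is what lets the driver script stay a plain `for` loop over epochs.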
2) Write the training loop
All you need to write is a simple for loop:
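model = Model(args)  # __init__ takes args; args, n_epochs and eval_per_epoch are defined by you
for epoch in range(n_epochs):
    model.train()  # train for one epoch
    if (epoch + 1) % eval_per_epoch == 0:
        model.eval()  # evaluate every eval_per_epoch epochs
print('Done!!!')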
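The loop relies on `args`, `n_epochs` and `eval_per_epoch`, which the template leaves to you. One possible way to supply them is `argparse`; the flag names below are hypothetical, not taken from the repository. `eval_epochs` just spells out which epochs the `(epoch + 1) % eval_per_epoch == 0` test selects.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical command-line options for a training script."""
    parser = argparse.ArgumentParser(description='Train with the PyTorch template')
    parser.add_argument('--n-epochs', type=int, default=10,
                        help='number of epochs to train')
    parser.add_argument('--eval-per-epoch', type=int, default=1,
                        help='evaluate every N epochs')
    parser.add_argument('--model-path', default=None,
                        help='checkpoint to resume from (optional)')
    return parser.parse_args(argv)


def eval_epochs(n_epochs, eval_per_epoch):
    """1-based epochs after which the loop calls model.eval()."""
    return [epoch + 1 for epoch in range(n_epochs)
            if (epoch + 1) % eval_per_epoch == 0]


if __name__ == '__main__':
    args = parse_args(['--n-epochs', '6', '--eval-per-epoch', '2'])
    print(args.n_epochs, args.eval_per_epoch, args.model_path)  # 6 2 None
    print(eval_epochs(args.n_epochs, args.eval_per_epoch))      # [2, 4, 6]
```

In a real script, `args = parse_args()` would then feed `Model(args)`, `args.n_epochs` and `args.eval_per_epoch` into the loop.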
3) Resume training
Resuming training is easy: just load a previously saved checkpoint.
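model = Model(args)
if model_path:
    model.load_state(model_path)  # restore a saved checkpoint
for i in range(n_epochs):
    model.train()
    # model.epoch is restored from the checkpoint, so evaluation stays on schedule
    if model.epoch % eval_per_epoch == 0:
        model.eval()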
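`load_state` is called here but not shown. Below is a minimal sketch of a matching `save_state`/`load_state` pair, assuming the checkpoint is a single dict holding the model and optimizer `state_dict`s plus the `epoch`, `step` and `best_error` counters; the repository's actual checkpoint layout may differ. Restoring the counters is what lets a resumed run continue counting epochs and steps where it left off instead of starting again from 0.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Model:
    """Only the fields needed to illustrate checkpointing (hypothetical)."""

    def __init__(self):
        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')
        self.device = torch.device('cpu')
        self.model = nn.Linear(4, 3)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def save_state(self, path):
        """Write weights, optimizer state and counters to one file."""
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'step': self.step,
            'best_error': self.best_error,
        }, path)
        print(f'save model at {path}')

    def load_state(self, path):
        """Restore everything written by save_state."""
        ckpt = torch.load(path, map_location=self.device)
        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.epoch = ckpt['epoch']
        self.step = ckpt['step']
        self.best_error = ckpt['best_error']
        print(f'load model from {path}')


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as ckpt_dir:
        path = os.path.join(ckpt_dir, '9.pth.tar')
        trained = Model()
        trained.epoch = 9  # pretend nine epochs have finished
        trained.save_state(path)
        resumed = Model()
        resumed.load_state(path)
        print(resumed.epoch)  # 9
```

The optimizer state is saved alongside the weights because optimizers such as Adam keep per-parameter running statistics; reloading only the weights would silently reset them.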
Example
- LeNet: train a LeNet to classify MNIST handwritten digits
- Training output:
    ......
    epoch 1 step 3400 loss 0.0434
    epoch 1 step 3500 loss 0.0331
    epoch 1 step 3600 loss 0.00188
    epoch 1 step 3700 loss 0.00341
    save model at ../models\best.pth.tar
    save model at ../models\1.pth.tar
    epoch 1 error 0.0237
    epoch 2 step 3800 loss 0.0201
    epoch 2 step 3900 loss 0.00523
    epoch 2 step 4000 loss 0.0236
    ......
- Visualize the output with tensorboard:
    tensorboard --logdir example/LeNet/log

- Resuming training:
    load model from checkpoint/9.pth.tar
    epoch 10 step 33800 loss 0.000128
    epoch 10 step 33900 loss 6.64e-06
    epoch 10 step 34000 loss 0.000613
    epoch 10 step 34100 loss 2.41e-05
    ......
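After the first evaluation, the training log shows two files being written: `best.pth.tar` and a per-epoch `1.pth.tar`. One plausible policy behind that output is sketched below in plain Python, assuming `best.pth.tar` is overwritten whenever the error improves and an epoch checkpoint is written at every evaluation; the repository may save on a different schedule. The mixed `../models\` separators in the log are consistent with `os.path.join` running on Windows.

```python
import os


def checkpoint_paths(ckpt_dir, epoch, error, best_error):
    """Return (files to write, updated best error) after evaluating `epoch`.

    Assumed policy: overwrite best.pth.tar only when the new error beats the
    best seen so far, and always write a per-epoch checkpoint.
    """
    paths = []
    if error < best_error:
        best_error = error
        paths.append(os.path.join(ckpt_dir, 'best.pth.tar'))
    paths.append(os.path.join(ckpt_dir, f'{epoch}.pth.tar'))
    return paths, best_error


if __name__ == '__main__':
    best = float('Inf')
    for epoch, error in [(1, 0.05), (2, 0.07), (3, 0.03)]:  # made-up errors
        paths, best = checkpoint_paths('models', epoch, error, best)
        print(epoch, paths)
```

Epoch 2 in this made-up run does not improve on epoch 1, so only `2.pth.tar` is written and `best.pth.tar` still holds the epoch-1 weights.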
