PyTorch: resuming training from a checkpoint


checkpoint = torch.load('.pth')
net.load_state_dict(checkpoint['net'])
criterion_mse = torch.nn.MSELoss().to(cfg.device)
criterion_L1 = L1Loss()
optimizer = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=cfg.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.n_steps, gamma=cfg.gamma)
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['lr_schedule'])  # must be a call, not an assignment
start_epoch = checkpoint['epoch']
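When the resuming machine differs from the training machine (e.g. a GPU-saved checkpoint restored on a CPU-only box), `torch.load` takes a `map_location` argument that remaps the stored tensors. A minimal sketch, assuming a throwaway model and a temporary file path (both placeholders, not from the original code):

```python
import os
import tempfile

import torch

# Placeholder model; the original code loads its own `net`.
net = torch.nn.Linear(2, 1)

# Save a checkpoint, then reload it mapped onto the CPU. map_location lets a
# checkpoint saved on one device be restored on another without errors.
path = os.path.join(tempfile.gettempdir(), 'demo_ckpt.pth')
torch.save({'net': net.state_dict()}, path)
ckpt = torch.load(path, map_location=torch.device('cpu'))
net.load_state_dict(ckpt['net'])

print(all(p.device.type == 'cpu' for p in net.parameters()))
```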

for idx_epoch in range(start_epoch + 1, 80):
    for idx_iter, (data, label) in enumerate(train_loader):  # placeholder batch names
        out = net(data)
        loss = criterion_mse(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    scheduler.step()  # since PyTorch 1.1, step the scheduler after the optimizer

    if idx_epoch % 1 == 0:  # save every epoch
        checkpoint = {
            'net': net.state_dict(),               # model weights
            'optimizer': optimizer.state_dict(),   # optimizer state
            'epoch': idx_epoch,                    # epochs completed
            'lr_schedule': scheduler.state_dict()  # how the LR evolves
        }
        torch.save(checkpoint, os.path.join(save_path, filename))
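The save/resume round trip above can be checked end to end with a tiny self-contained sketch. The model, data, and hyperparameters below are illustrative placeholders, not the author's network; the point is that restoring all four pieces of state (`net`, `optimizer`, `epoch`, `lr_schedule`) reproduces the learning-rate schedule exactly:

```python
import os
import tempfile

import torch

# Placeholder model, optimizer, scheduler, and data.
net = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# "Train" for 3 epochs, checkpointing after each one.
path = os.path.join(tempfile.gettempdir(), 'resume_demo.pth')
for epoch in range(3):
    loss = torch.nn.functional.mse_loss(net(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    torch.save({'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'lr_schedule': scheduler.state_dict()}, path)

# Resume: rebuild fresh objects, then restore every piece of state.
net2 = torch.nn.Linear(4, 1)
optimizer2 = torch.optim.Adam(net2.parameters(), lr=1e-3)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=2, gamma=0.5)
ckpt = torch.load(path)
net2.load_state_dict(ckpt['net'])
optimizer2.load_state_dict(ckpt['optimizer'])
scheduler2.load_state_dict(ckpt['lr_schedule'])  # a call, not an assignment
start_epoch = ckpt['epoch'] + 1

print(start_epoch, scheduler2.get_last_lr() == scheduler.get_last_lr())
```

After loading, `scheduler2.get_last_lr()` matches the original scheduler and `start_epoch` is 3, so `range(start_epoch, ...)` continues exactly where the saved run stopped.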
           

 

Training straight through:
a mean psnr:  28.160327919812364
a mean ssim:  0.8067064184409644
b mean psnr:  25.01364162100755
b mean ssim:  0.7600019779915981
c mean psnr:  25.83471135230011
c mean ssim:  0.7774989383731079

Resuming from a checkpoint:
a mean psnr:  28.15391601255439
a mean ssim:  0.8062857339309237
b mean psnr:  25.01115760689137
b mean ssim:  0.7596963993692107
c mean psnr:  25.842269038618145
c mean ssim:  0.7772710729947427

 

Resuming from a checkpoint gives essentially the same results as training straight through, though still slightly worse; I will analyze this further in a later post.
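One plausible source of the small gap (an assumption on my part, not the author's analysis) is that the checkpoint above does not save random-number-generator state, so data shuffling and any random augmentation diverge from the original run after resuming. A hedged sketch of capturing and restoring RNG state alongside the checkpoint (`rng_state`/`load_rng_state` are hypothetical helper names):

```python
import random

import torch

def rng_state():
    """Snapshot the RNG state of the libraries the training loop draws from."""
    return {'torch': torch.get_rng_state(),
            'python': random.getstate()}

def load_rng_state(state):
    """Restore a snapshot taken by rng_state()."""
    torch.set_rng_state(state['torch'])
    random.setstate(state['python'])

# Round-trip check: restoring the state reproduces the same random draws.
snap = rng_state()
a = torch.randn(3)
load_rng_state(snap)
b = torch.randn(3)
print(bool(torch.equal(a, b)))
```

Storing this dict as an extra key in the checkpoint (and restoring it on resume) makes the data-loading order reproducible across the interruption.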

