pytorch-模型保存與加載自己訓練的模型詳解 


filename = 'cvae_' + str(epoch+1) + '.pkl'
save_path = save_dir / Path(filename)
states = {}
states['model'] = cvae.state_dict() # 模型參數
states['z_dim'] = args.z_dim
states['x_dim'] = args.x_dim
states['s_dim'] = args.s_dim
states['optim'] = cvae.state_dict()
torch.save(states, str(save_path)) #檢查點:將states字典存放在save_path文件下 
保存和加載模型的時候,配對的函數:
對於僅保存state_dict()的方式,那保存和加載模型的方式為:
保存:torch.save(model.state_dict(), PATH)
加載:model.laod_state_dict(torch.load(PATH))
一般加載模型是在訓練完成后用模型做測試,這時候加載模型記得要加上model.eval(),把模型切換到evaluation模式,這時候會調整dropout和bactch的模式。

對於保存和加載整個模型的情況:
torch.save(model, PATH)
model = torch.load(PATH)
可以看到,前面的model.load_state_dict()和這里的不同,前面的情況需要你先定義一個模型,然后再load_state_dict()
但是這里load整個模型,會把模型的定義一起load進來。完成了模型的定義和加載參數的兩個過程。
詳細代碼
 1     def save(self):  2         save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)  3 
 4         if not os.path.exists(save_dir):  5  os.makedirs(save_dir)  6 
 7         torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))  8         torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))  9 
10         with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 11  pickle.dump(self.train_hist, f) 12 # 使用方法:對模型初始化以后,使用以下方法,加載模型的參數,以至於不用再對數據集進行訓練
13     def load(self): 14         save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 15 
16         self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 17         self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))

note:

pickle.dump(obj, file, [,protocol]) 序列化對象,將對象obj保存到文件file中去。self.train_hist用於存放模型文件

pickle.load(file) 反序列化對象,將文件中的數據解析為一個python對象。file中有read()接口和readline()接口




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM