filename = 'cvae_' + str(epoch+1) + '.pkl'
save_path = save_dir / Path(filename)
states = {}
states['model'] = cvae.state_dict() # 模型參數
states['z_dim'] = args.z_dim
states['x_dim'] = args.x_dim
states['s_dim'] = args.s_dim
states['optim'] = cvae.state_dict()
torch.save(states, str(save_path)) #檢查點:將states字典存放在save_path文件下
保存和加載模型的時候,配對的函數:
對於僅保存state_dict()的方式,那保存和加載模型的方式為:
保存:torch.save(model.state_dict(), PATH)
加載:model.laod_state_dict(torch.load(PATH))
一般加載模型是在訓練完成后用模型做測試,這時候加載模型記得要加上model.eval(),把模型切換到evaluation模式,這時候會調整dropout和bactch的模式。
對於保存和加載整個模型的情況:
torch.save(model, PATH)
model = torch.load(PATH)
可以看到,前面的model.load_state_dict()和這里的不同,前面的情況需要你先定義一個模型,然后再load_state_dict()
但是這里load整個模型,會把模型的定義一起load進來。完成了模型的定義和加載參數的兩個過程。
詳細代碼
1 def save(self): 2 save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 3
4 if not os.path.exists(save_dir): 5 os.makedirs(save_dir) 6
7 torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) 8 torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) 9
10 with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: 11 pickle.dump(self.train_hist, f) 12 # 使用方法:對模型初始化以后,使用以下方法,加載模型的參數,以至於不用再對數據集進行訓練
13 def load(self): 14 save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) 15
16 self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) 17 self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))
note:
pickle.dump(obj, file, [,protocol]) 序列化對象,將對象obj保存到文件file中去。self.train_hist用於存放模型文件
pickle.load(file) 反序列化對象,將文件中的數據解析為一個python對象。file中有read()接口和readline()接口