前提: 模型參數和結構是分別保存的
1、 構建模型(# load model graph)
model = MODEL()
2、加載模型參數(# load model state_dict)
model.load_state_dict
(
{
k.replace('module.',''):v for k,v in
torch.load(config.model_path, map_location=config.device).items()
}
)
model = self.model.to(config.device)
* config.device 指定使用哪塊GPU或者CPU
*k.replace('module.',''):v 防止torch.DataParallel訓練的模型出現加載錯誤
(解決RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1問題)
3、設置當前階段為inference(# predict)
model.eval()