轉載請注明出處:
http://www.cnblogs.com/darkknightzh/p/8108466.html
參考網址:
http://pytorch.org/docs/master/notes/serialization.html
https://github.com/clcarwin/sphereface_pytorch
有兩種方式保存和載入模型
1. 只保存和載入模型參數
保存:
torch.save(the_model.state_dict(), PATH)
載入:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
當model使用gpu訓練時,可以將數據轉換到cpu中,並保存(載入時,還是上面的方法。需要使用gpu時,加上.cuda()):
def save_model(model, filename): state = model.state_dict() for key in state: state[key] = state[key].clone().cpu() torch.save(state, filename)
2. 保存和載入整個模型
保存:
torch.save(the_model, PATH)
載入:
the_model = torch.load(PATH)
However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
第二種方式,序列化后的數據使用特殊的結構,缺點就是當在其他工程中使用時,可能會碰到各種問題。
因而,官方更建議使用第一種方式。