(原+譯)pytorch中保存和載入模型


轉載請注明出處:

http://www.cnblogs.com/darkknightzh/p/8108466.html

參考網址:

http://pytorch.org/docs/master/notes/serialization.html

https://github.com/clcarwin/sphereface_pytorch

 

有兩種方式保存和載入模型

1. 只保存和載入模型參數

保存:

torch.save(the_model.state_dict(), PATH)

載入:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

當model使用gpu訓練時,可以將數據轉換到cpu中,並保存(載入時,還是上面的方法。需要使用gpu時,加上.cuda()):

def save_model(model, filename):
    state = model.state_dict()
    for key in state: state[key] = state[key].clone().cpu()
    torch.save(state, filename)

2. 保存和載入整個模型

保存:

torch.save(the_model, PATH)

載入:

the_model = torch.load(PATH)

However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.

第二種方式,序列化后的數據使用特殊的結構,缺點就是當在其他工程中使用時,可能會碰到各種問題。

因而,官方更建議使用第一種方式。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM