0705-深度網絡模型持久化


0705-深度網絡模型持久化

pytorch完整教程目錄:https://www.cnblogs.com/nickchen121/p/14662511.html

一、持久化概述

在 torch 中,以下對象可以持久化到硬盤,並可以通過相應的方法把這些對象持久化到內存中:

  • Tensor
  • Variable
  • nn.Module
  • Optimizer

上述對象本質上最后都是保存為 Tensor。並且 Tensor 的保存和加載非常簡單,使用 t.savet.load 即可。

在 save/load 時可指定使用的 pickle 模塊,在 load 時還可以把 GPU tensor 映射到 CPU 或者其他 GPU 上。

我們可以通過 t.save(obj, file_name) 保存任意可序列化的對象,然后通過 obj=t.load(file_name) 方法加載保存的數據。

對於 Module 和 Optimizer 對象,建議保存為對應的 state_dict,而不是直接保存整個 Module/Optimizer 對象。Optimizer 對象保存的是參數和動量信息,通過加載之前的動量信息,能夠很有效地減少模型震盪。

二、tensor 對象的保存和加載

import torch as t

a = t.Tensor(3, 4)
if t.cuda.is_available():
    a = a.cuda(1)  # 把 a 轉為 GPU1 上的 tensor
    t.save(a, 'a.pth')

    # 加載為 b,存儲於 GPU1 上(因為保存時 tensor 就在 GPU1 上)
    b = t.load('a.pth')

    # 加載為 c,存儲於 CPU
    c = t.load('a.pth', map_location=lambda storage, loc: storage)

    # 加載為 d,存儲於 GPU0 上
    d = t.load('a.pth', map_location={'cuda:1': 'cuda:0'})

三、Module 對象的保存和加載

t.set_default_tensor_type('torch.FloatTensor')
from torchvision.models import AlexNet

model = AlexNet()
# module 的 state_dict 是一個字典
model.state_dict().keys()

t.save(model.state_dict(), 'alexnet.pth')
model.load_state_dict(t.load('alexnet.pth'))
<All keys matched successfully>

四、Optimizer 對象的保存和加載

optimizer = t.optim.Adam(model.parameters(), lr=0.1)
t.save(optimizer.state_dict(), 'optimizer.pth')
optimizer.load_state_dict(t.load('optimizer.pth'))

五、所有對象集合的保存和加載

all_data = dict(optimizer=optimizer.state_dict(),
                model=model.state_dict(),
                info=u'模型和優化器的所有參數')
t.save(all_data, 'all.pth')

all_data = t.load('all.pth')
all_data.keys()
dict_keys(['optimizer', 'model', 'info'])

六、第七章總結

本章介紹了 torch 的很多工具模塊,主要涉及數據加載、可視化和 GPU 加速相關的內容,合理地使用這些模塊可以極大地提升我們的編碼效率。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM