Pytorch學習筆記16----CNN或LSTM模型保存與加載


1.三個核心函數

介紹一系列關於 PyTorch 模型保存與加載的應用場景,主要包括三個核心函數:

(1)torch.save

其中,應用了 Python 的 pickle 包,進行序列化,可適用於模型Models,張量Tensors,以及各種類型的字典對象的序列化保存.

(2)torch.load

采用 Python 的 pickle 的 unpickling 函數,對磁盤 pickled 的對象文件進行反序列化(deserialize),加載到內存.

(3)torch.nn.Module.load_state_dict

采用序列化的 state_dict 加載模型參數(字典).

2.state_dict介紹

PyTorch中,torch.nn.Module 模型中的可學習參數(learnable parameters)(如,weights 和 biases),包含在模型參數(model parameters)里(根據 model.parameters() 進行訪問.)

state_dict可以簡單的理解為 Python 的字典對象,其將每一層映射到其參數張量.

注,只有包含待學習參數的網絡層,如卷積層,線性連接層等,會在模型的 state_dict 中有元素值.

優化器對象(Optimizer object,torch.optim) 也有 state_dict,其包含了優化器的狀態信息,以及所使用的超參數.

由於 state_dict 對象時 Python 字典的形式,因此,便於保存,更新,修改與恢復,有利於 PyTorch 模型和優化器的模塊化.

例如,Training a classifier tutorial 中所使用的簡單模型的 state_dict

# 模型定義
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ModelNet(nn.Module):
    def __init__(self):
        super(ModelNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 模型初始化
model = ModelNet()

# 優化器初始化
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, 
          "\t", 
          model.state_dict()[param_tensor].size()
         )

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

輸出如下:

Model's state_dict:
conv1.weight      torch.Size([6, 3, 5, 5])
conv1.bias      torch.Size([6])
conv2.weight      torch.Size([16, 6, 5, 5])
conv2.bias      torch.Size([16])
fc1.weight      torch.Size([120, 400])
fc1.bias      torch.Size([120])
fc2.weight      torch.Size([84, 120])
fc2.bias      torch.Size([84])
fc3.weight      torch.Size([10, 84])
fc3.bias      torch.Size([10])

Optimizer's state_dict:
param_groups      [{'weight_decay': 0, 
                   'dampening': 0, 
                   'params': [140448775121872, 140448775121728,
                              140448775121584, 140448775121440,
                              140448775121296, 140448775121152,
                              140448775121008, 140448775120864,
                              140448775120720, 140448775120576],
                   'nesterov': False, 
                   'momentum': 0.9, 
                   'lr': 0.001}]
state      {}

3.模型的保存與加載

(1)保存/加載 state_dict (推薦)

# 模型保存
torch.save(model.state_dict(), PATH)

# 模型加載
model = ModelNet(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

當保存模型,用於推斷時,只有訓練的模型可學習參數是有必要進行保存的.

采用 torch.save() 函數保存模型的 state_dict,對於應用時,模型恢復具有最好的靈活性,因此推薦采用該方式進行模型保存.

PyTorch 通用模型保存格式為 .pt 和 .pth 文件擴展名形式.

需要注意的時,在運行推斷前,需要調用 model.eval() 函數,以將 dropout 層 和 batch normalization 層設置為評估模式(非訓練模式).

注意:

load_state_dict() 函數的輸入是字典形式,而不是對象保存的文件路徑.

也就是說,在將保存的模型文件送入 load_state_dict() 函數前,必須將保存的 state_dict 進行反序列化.

例如,不能直接應用 model.load_state_dict(PATH),而是,load_state_dict(torch.load(PATH)).

(2)保存/加載全部模型信息

# 保存
torch.save(model, PATH)

# 加載
model = ModelNet(*args, **kwargs) # 必須預先定義過模型.
model = torch.load(PATH)
model.eval()

這種方式是最直觀的語法,包含最少的代碼. 其會采用 Python 的pickle 模塊保存全部的模型模塊.

這種方式的缺點在於,序列化的數據受限於在模型保存時所采用的特定的類和准確的路徑結構(specific classes and the exact directory structure). 其原因是,because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. 因此,再加載后經過許多重構后,或在其它項目中使用時,可能會被打亂.

參考文獻:https://www.aiuai.cn/aifarm743.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM