PyTorch state_dict: saving a model and its parameters


In PyTorch, a state_dict is a plain Python dictionary object that maps each layer to its corresponding parameters (e.g. each layer's weights and biases).

(Note: only layers with learnable parameters, such as convolutional and linear layers, appear in a model's state_dict.)

An Optimizer object also has a state_dict, which contains the optimizer's internal state as well as the hyperparameters in use (e.g. lr, momentum, weight_decay).

 

Notes:

1) The state_dict is generated automatically by PyTorch once a model or optimizer has been defined, and can be accessed directly. The usual convention is to save a state_dict to a ".pt" or ".pth" file, i.e. PATH="./***.pt" in the command below:

torch.save(model.state_dict(), PATH)

2) load_state_dict is likewise a function that every model or optimizer has automatically, and can be called directly:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
Note the importance of model.eval(): it is called at the end of 2) because only after this command do the "dropout" and "batch normalization" layers switch into evaluation mode. These two layer types behave differently in training mode and evaluation mode.
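The effect of model.eval() can be seen with a minimal sketch (a hypothetical two-layer module, not the model from this post): dropout randomly zeroes activations in training mode but is a no-op in evaluation mode, so eval-mode outputs are deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny network with a dropout layer
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()
out_train = net(x)      # some activations randomly zeroed by dropout

net.eval()
out_eval1 = net(x)      # dropout disabled: deterministic output
out_eval2 = net(x)

# In eval mode, repeated forward passes give identical results.
assert torch.equal(out_eval1, out_eval2)
```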

-------------------------------------------------------------------------------------------------------------------------------

Saving the state_dict (model is an instance of a network class)

1.1) To save only the learned parameters, use:

    torch.save(model.state_dict(), PATH)

1.2) To load model.state_dict, use:

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()

    Note: model.load_state_dict operates on an actual state_dict object, not on a filename.
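A quick sketch of this point (using a hypothetical nn.Linear model and file name): the file must first be deserialized with torch.load, and the resulting dict is what gets passed to load_state_dict.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)
torch.save(model.state_dict(), 'tmp_linear.pt')

# torch.load returns a dict: {'weight': tensor, 'bias': tensor}
state = torch.load('tmp_linear.pt')

model2 = nn.Linear(3, 2)
model2.load_state_dict(state)   # pass the dict object, not the filename

assert torch.equal(model.weight, model2.weight)
```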

-----------

2.1) To save the entire model, use:

    torch.save(model, PATH)

2.2) To load the entire model, use:

    # Model class must be defined somewhere
    model = torch.load(PATH)
    model.eval()

--------------------------------------------------------------------------------------------------------------------------------------

state_dict is a Python dictionary: it is stored as a dict and loaded as a dict, and only the entries whose keys match are loaded.
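This key-matching behavior can be sketched with two hypothetical models that share one layer name. With strict=False, load_state_dict ignores non-matching keys instead of raising, and reports them in its return value.

```python
import torch
import torch.nn as nn

# Hypothetical models: Small has only fc1; Big has fc1 and fc2.
class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)

class Big(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)

big, small = Big(), Small()
# Only the matching 'fc1.*' keys are loaded; 'fc2.*' keys are skipped.
result = small.load_state_dict(big.state_dict(), strict=False)

assert sorted(result.unexpected_keys) == ['fc2.bias', 'fc2.weight']
assert torch.equal(small.fc1.weight, big.fc1.weight)
```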

----------------------------------------------------------------------------------------------------------------------

How to load only one layer's trained parameters (one layer's state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
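The key-renaming idea from the paragraph above can be sketched as follows (the source and target models here are hypothetical): build a new dict whose keys match the target model's naming, then load it.

```python
import torch
import torch.nn as nn

# Source model saved its conv layer under keys '0.weight' / '0.bias'
src = nn.Sequential(nn.Conv2d(3, 6, 5))

# Target model expects keys 'conv1.weight' / 'conv1.bias'
class Target(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)

old_state = src.state_dict()
# Rename the keys in the loaded state_dict to match the target model
new_state = {k.replace('0.', 'conv1.'): v for k, v in old_state.items()}

tgt = Target()
tgt.load_state_dict(new_state)
assert torch.equal(tgt.conv1.weight, src[0].weight)
```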
--------------------------------------------------------------------------------------------

After loading model parameters, how to control whether a given layer's parameters are trainable (param.requires_grad)

for param in list(model.pretrained.parameters()):
    param.requires_grad = False

Note: requires_grad is an attribute of tensors.

Question: can requires_grad be set directly on a layer, e.g. model.conv1.requires_grad = False?

Answer: tested, and no. model.conv1 has no requires_grad attribute, so assigning one does not freeze its parameters.
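A sketch of the working approach, freezing just one layer of a hypothetical model by iterating over its parameter tensors (since requires_grad lives on the parameters themselves):

```python
import torch.nn as nn

# Hypothetical model with a conv layer and a linear layer
model = nn.Sequential()
model.conv1 = nn.Conv2d(3, 6, 5)
model.fc = nn.Linear(10, 2)

# Freeze only conv1 by flipping requires_grad on its parameter tensors
for param in model.conv1.parameters():
    param.requires_grad = False

assert not model.conv1.weight.requires_grad
assert not model.conv1.bias.requires_grad
assert model.fc.weight.requires_grad   # other layers stay trainable
```

As a side note, recent PyTorch versions also provide nn.Module.requires_grad_(False), which applies the same per-parameter change recursively.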

 

---------------------------------------------------------------------------------------------

Full test code:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# initialize the model
model = TheModelClass()

# initialize the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, '\t', model.state_dict()[param_tensor].size())

print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
    print(var_name, '\t', optimizer.state_dict()[var_name])

print("\nprint particular param")
print('\n', model.conv1.weight.size())
print('\n', model.conv1.weight)

print("------------------------------------")
torch.save(model.state_dict(), './model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict.pt'))
# model_2.eval()
# print('\n', model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())

## load only one layer's parameters
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state == model.conv1.weight)

model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad = False   # only sets a plain attribute; does NOT freeze the layer
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)   # still True: the parameters are unaffected
---------------------
Author: wzg2016
Source: CSDN
Original: https://blog.csdn.net/strive_for_future/article/details/83240081
Copyright notice: this is the author's original article; please include a link to the original when reposting.

