In PyTorch, the state_dict of a torch.nn.Module stores the weights and bias coefficients that are learned during training. As a Python dictionary, state_dict maps each layer's parameters to tensors. Note that a module's state_dict only contains the parameters of layers with learnable weights, such as convolutional and fully connected layers; when the network contains batch normalization (for example, the batchnorm variants of VGG), the state_dict also stores the batchnorm's running_mean. For a detailed explanation of batchnorm, see https://blog.csdn.net/wzy_zju/article/details/81262453
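As a quick illustration of this point, here is a minimal sketch (not from the original post) showing that batchnorm buffers such as running_mean end up in the state_dict alongside the learnable weights:

import torch.nn as nn

# A toy network: one conv layer (learnable weight/bias) followed by batchnorm,
# whose running statistics are registered buffers rather than parameters.
bn_model = nn.Sequential(
    nn.Conv2d(3, 6, 3),
    nn.BatchNorm2d(6),
)
for name, tensor in bn_model.state_dict().items():
    print(name, tuple(tensor.shape))
# Expected keys include '0.weight', '0.bias', '1.weight', '1.bias',
# '1.running_mean', '1.running_var' and '1.num_batches_tracked'.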
The Optimizer objects in torch.optim also expose a state_dict. Here the state_dict is a dictionary with two entries, state and param_groups, where the value under the param_groups key is a list of dictionaries holding the learning rate, momentum, and other optimizer settings.
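For example, a minimal sketch (using a throwaway nn.Linear layer rather than the model from this post) of reading those entries back out:

import torch.nn as nn
import torch.optim as optim

tiny_model = nn.Linear(4, 2)
optimizer = optim.SGD(tiny_model.parameters(), lr=0.001, momentum=0.9)

opt_state = optimizer.state_dict()
print(list(opt_state.keys()))                    # ['state', 'param_groups']
print(opt_state['param_groups'][0]['lr'])        # 0.001
print(opt_state['param_groups'][0]['momentum'])  # 0.9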
Because a state_dict is essentially a Python dictionary, it can be conveniently saved, updated, modified, and restored (inheriting the properties of Python dicts), which adds a great deal of modularity to PyTorch models and optimizers.
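A minimal sketch of that save/restore workflow, assuming an arbitrary file name checkpoint.pth:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# state_dict is an ordinary (ordered) Python dict of tensors, so it can be
# serialized directly with torch.save.
torch.save(model.state_dict(), 'checkpoint.pth')

# Restoring: build a model with the same structure, then copy the saved
# tensors back in with load_state_dict.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load('checkpoint.pth'))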
The following simple example prints the entries stored in these state_dict objects:
#encoding:utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F


# Define the model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def main():
    # Initialize model
    model = TheModelClass()

    # Initialize optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Print the model's state_dict (key -> tensor size)
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, '\t', model.state_dict()[param_tensor].size())

    # Print the optimizer's state_dict
    print("Optimizer's state_dict:")
    for var_name in optimizer.state_dict():
        print(var_name, '\t', optimizer.state_dict()[var_name])


if __name__ == '__main__':
    main()
The output is shown below; the keys and values stored in each state_dict can be seen clearly:
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]
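Note that the optimizer's state entry is empty here only because no update has been performed yet. Below is a minimal sketch (an assumed setup, not part of the original example) showing that once backward() and step() run, SGD with momentum records a momentum buffer for each parameter:

import torch
import torch.nn as nn
import torch.optim as optim

m = nn.Linear(4, 2)
opt = optim.SGD(m.parameters(), lr=0.001, momentum=0.9)
print(opt.state_dict()['state'])   # {} -- empty before the first step

loss = m(torch.randn(8, 4)).sum()
loss.backward()
opt.step()

# After one step, the 'state' entry holds a 'momentum_buffer' tensor per
# parameter (the exact key format differs between PyTorch versions).
print(opt.state_dict()['state'])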
Reposted from: https://blog.csdn.net/bigFatCat_Tom/article/details/90722261