基於pytorch神經網絡模型參數的加載及自定義


最近在訓練MobileNet時經常會對其模型參數進行各種操作,或者替換其中的幾層之類的,故總結一下用到的對神經網絡參數的各種操作方法。

1.將matlab的.mat格式參數整理轉換為tensor類型的模型參數

import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as scio
import os
import numpy as np
from config import Config
import json
config = Config()

Mul = Config.MUL.astype('float32')
Shift = Config.SHIFT.astype('float32')

def load_json(j_fn):
    with open(j_fn,'r') as f:
        data = json.load(f)
    return data

def save_json(dic,j_fn):
    json_str = json.dumps(dic)
    with open(j_fn,'w') as json_file:
        json_file.write(json_str)

w_dic = {}
b_dic = {}
dic_all = {}
for i in range(1,28,2):
    a = 'w'+str(i)    #按順序命名
    b = 'b'+str(i)
    dic_all[a] = torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + str(i)+'.mat')['wei'] * Mul[i-1]/(2**Shift[i-1])).permute(3, 2, 0, 1)
    dic_all[b] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + str(i)+'.mat')['bias'] * Mul[i-1]/(2**Shift[i-1])))
    # print(a, 'Mul'+str(i-1))
    if i == 27:
        break
    a = 'w'+str(i+1)
    b = 'b'+str(i+1)
    dic_all[a] = torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + str(i+1)+'.mat')['wei'] * Mul[i]/(2**Shift[i])).permute(2, 0, 1).unsqueeze(1)
    dic_all[b] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + str(i+1)+'.mat')['bias'] * Mul[i]/(2**Shift[i])))
#此處由於自己之前的命名問題,中間跳過了28層(池化層),直接按照有參數的層存儲了參數,故27后的文件名變成了29
dic_all['w29'] = torch.squeeze(torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + '29.mat')['wei'] * Mul[28]/(2**Shift[28])).permute(3, 2, 0, 1)[1:, :])
dic_all['b29'] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + '29.mat')['bias'] * Mul[28]/(2**Shift[28])))[1:]
#存為.pth文件
param_fn = 'mobilenet_param_float.pth'
torch.save(dic_all,param_fn)

其中,mul和shift為量化后的乘子和移位參數(如果參數是浮點的則可以忽略這部分),另外,我的weight和bias是按照每層單獨存在一個按照層序號命名的.mat文件中。且由於是從matlab的程序得到的,需要對參數的維度進行一下轉換(permute()方法),同時對需要增加或減少維度的用unsqueeze()或torch.squeeze()方法進行改變(注意一定要和網絡需要的輸入維度相同才行)。最后按照原來對參數文件命名的方式保存成字典存成.pth文件(此時的字典還不能直接使用,需要在具體定義的網絡中更換想應的key值)。

*另外,代碼中用來讀取和存儲.json文件的函數可以忽略,在這里沒有用到

2.將自定義網絡的參數替換成自己需要的(DIY模型參數)

from mobilenet_v1 import MobileNet_v1
import torch
from config import Config
from load_data import loadtestdata
from torch.autograd import Variable
import numpy as np
from Mobilenetv1_quantified import MobileNet, MobileNet_Bayer
import json
import matplotlib.pyplot as plt
import numpy as np
import torchvision

param_keys = ['w1', 'b1', 'w2', 'b2', 'w3', 'b3', 'w4', 'b4', 'w5', 'b5', 'w6', 'b6', 'w7', 'b7', 'w8', 'b8', 'w9', 'b9', 'w10', 'b10', 'w11', 'b11', 'w12', 'b12', 'w13', 'b13', 'w14', 'b14', 'w15', 'b15', 'w16', 'b16', 'w17', 'b17', 'w18', 'b18', 'w19', 'b19', 'w20', 'b20', 'w21', 'b21', 'w22', 'b22', 'w23', 'b23', 'w24', 'b24', 'w25', 'b25', 'w26', 'b26', 'w27', 'b27', 'w29', 'b29']
file_name = '/home/wangshuyu/MobileNet_v1/mobilenet_param_float.pth'
dic_param = torch.load(file_name)      # 此處打開上一步存成的參數字典(按照每一層的權重、偏置的順序)
Model = MobileNet()                    # 實例化預定義的MobileNet網絡(網絡結構將在其他文中給出)
net_dic = Model.state_dict()           # 加載預定義網絡的參數字典,用來獲取網絡的鍵值
for i, param_tensor in enumerate(net_dic ,0):
    net_dic[param_tensor] = dic_param[param_keys[i]]
    # print(i,'\t',param_tensor ,net_dic[param_tensor].shape)   #可以用來查看參數的維度
param_fn = 'MobileNet_float.pth'
torch.save(net_dic,param_fn)
# 下面開始是自己定義的另一個網絡,只需要固定MobileNet其中一部分參數,剩下的部分參數用來訓練,因此只從第11個之后的開始取參數
model2 = MobileNet_Bayer()
dic2 = model2.state_dict()
key_list = list(net_dic.keys())
for i, param_tensor in enumerate(dic2 ,0):
    if i > 11:
        dic2[param_tensor] = (net_dic[key_list[i - 2]])
    print(i, '\t', param_tensor, dic2[param_tensor].shape)
param_fn2 = 'MobileNet_Bayer.pth'
torch.save(dic2,param_fn2)

這里主要實現了將之前存好的量化后的mobilenet每層參數根據自己定義的網絡構建了參數字典,在訓練或測試的時候,只需要加載之前存好的預訓練參數就可以了:

from Mobilenetv1_quantified import MobileNet
import torch
from load_data import loadtestdata Net
= MobileNet() param_dic = torch.load('MobileNet_float.pth') Net.load_state_dict(param_dic)
classes = range(0,1000)
test_data = loadtestdate()
test(test_data, Net, classes)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM