將TensorFlow的MobileNetV1參數遷移到PyTorch並保存


因為放棄TensorFlow超級久了,也不想再去用它:明明用PyTorch十幾行就能寫出的代碼,TensorFlow的版本我完全看不懂,我這個菜雞還是老老實實刨地吧。MobileNet的代碼網上一大堆,我把我寫的貼出來吧;論文簡單易讀,連我這種英語渣渣兩天就看完了。

MobileNet的代碼如下。

import torch.nn as nn
import torch
class Conv_bn(nn.Module):
    """Standard 3x3 convolution -> BatchNorm -> ReLU (the MobileNet stem).

    The three layers live in ``self.convBn`` (an ``nn.Sequential``), so the
    state-dict keys are ``convBn.0.weight`` and ``convBn.1.*``; the
    TF-to-PyTorch conversion tables rely on these exact names.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: convolution stride (padding is fixed at 1).
    """

    def __init__(self, inp, oup, stride):
        super().__init__()
        conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        norm = nn.BatchNorm2d(oup)
        act = nn.ReLU(inplace=True)
        self.convBn = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply conv, batch norm and ReLU to ``x``."""
        return self.convBn(x)

class Conv_depth(nn.Module):
    """Depthwise-separable block: depthwise 3x3 + BN + ReLU, then pointwise 1x1 + BN + ReLU.

    The six layers live in ``self.convDepthwise`` (an ``nn.Sequential``).
    Indices 0/1 are the depthwise conv and its BN, index 2 is a ReLU, and 3/4
    are the pointwise conv and its BN. The TF-to-PyTorch conversion tables
    address parameters by these indices.

    Args:
        inp: number of input channels (also the depthwise group count).
        oup: number of output channels of the pointwise conv.
        stride: stride of the depthwise conv; the pointwise conv uses stride 1.
    """

    def __init__(self, inp, oup, stride):
        super().__init__()
        layers = [
            # depthwise 3x3: one filter per input channel (groups=inp)
            nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
            nn.BatchNorm2d(inp),
            nn.ReLU(inplace=True),
            # pointwise 1x1: mixes channels inp -> oup
            nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
            nn.ReLU(inplace=True),
        ]
        self.convDepthwise = nn.Sequential(*layers)

    def forward(self, x):
        """Run the depthwise-separable block on ``x``."""
        return self.convDepthwise(x)


class MobileNet(nn.Module):
    """MobileNetV1 (width multiplier 1.0) image classifier.

    ``self.mobelnet`` holds a stride-2 3x3 conv stem, 13 depthwise-separable
    blocks and a global average pool. There are five stride-2 layers in total,
    so the spatial size shrinks by 32. ``self.fc`` maps the pooled 1024-d
    feature to class logits.

    The attribute names (including the ``mobelnet`` spelling) and the layer
    order define the state-dict keys used by the TF-to-PyTorch conversion
    tables, so they must not change.

    NOTE(review): the TF-slim MobileNetV1 checkpoint is believed to use ReLU6,
    BatchNorm eps=1e-3 and 'SAME' (asymmetric) padding on stride-2 convs.
    Conv_bn/Conv_depth instead use ReLU, PyTorch's default eps and symmetric
    padding=1. Verify these before expecting outputs numerically identical to TF.

    Args:
        num_classes: output size of the final linear layer (default 1000).
    """

    def __init__(self, num_classes=1000):
        super(MobileNet, self).__init__()
        self.mobelnet = nn.Sequential(
            Conv_bn(3, 32, 2),
            Conv_depth(32, 64, 1),
            Conv_depth(64, 128, 2),
            Conv_depth(128, 128, 1),
            Conv_depth(128, 256, 2),
            Conv_depth(256, 256, 1),
            Conv_depth(256, 512, 2),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 512, 1),
            Conv_depth(512, 1024, 2),
            Conv_depth(1024, 1024, 1),
            # Adaptive pooling always yields 1x1, so any input size works.
            # For a 224x224 input (7x7 feature map) it matches AvgPool2d(7).
            nn.AdaptiveAvgPool2d(1),
        )

        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an NCHW image batch."""
        x = self.mobelnet(x)
        # Keep the batch dimension explicit: view(-1, 1024) would silently fold
        # any leftover spatial positions into extra "batch" rows.
        x = x.view(x.size(0), -1)
        return self.fc(x)

媽呀,簡單吧,但是你不知道tensorflow的版本有多長啊。

然后轉參數把我難住了,之前沒做過。我參考了 https://www.jianshu.com/p/0a61caeb693b 這位同學的MobileNetV3版本的改法,但是我真的不懂他那個字典是怎么定義的:我每次寫 model.層名,編輯器就開始給我出紅杠杠、報錯,我估計可能是他把層都封裝成了對象吧,如果有懂的同學希望能給我講講哈。下面貼我自己的代碼吧。

import json
import tensorflow as tf
import os
from MobileNet.mobilenet_v1 import MobileNet
import numpy as np
import torch
import os
# Lower TensorFlow's C++ logging verbosity (presumably to hide INFO/WARNING noise).
# NOTE(review): this is set after `import tensorflow` above; it may need to be
# set before that import to take full effect -- verify.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Path of the TF checkpoint read by load_parameter; adjust for your machine.
CHECKPOINT_PATH='/Users/wenyu/Desktop/TorchProject/MobileNet/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'

def new_dict(checkpoint_path,json_path):
    """Write the checkpoint's variable names to a JSON file, once.

    Every variable name in the TF checkpoint at ``checkpoint_path`` becomes a
    key mapped to the placeholder value ``1``; only the keys carry
    information. Names containing ``/ExponentialMovingAverage`` or
    ``/RMSProp`` are skipped (presumably EMA/optimizer slot variables).

    If ``json_path`` already exists, nothing is read or written and a message
    is printed instead.
    """
    if os.path.exists(json_path):
        print('the json file has been write!')
        return

    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    skip_tags = ('/ExponentialMovingAverage', '/RMSProp')
    # Read the variable map once and filter while building, instead of
    # building everything and deleting unwanted keys afterwards.
    weights_small = {
        name: 1
        for name in reader.get_variable_to_shape_map()
        if not any(tag in name for tag in skip_tags)
    }
    with open(json_path, 'w', encoding='utf-8') as writer:
        json.dump(weights_small, fp=writer, sort_keys=True)

def get_convbn_convert_dict(layer_num):
    """Map a standard conv block's PyTorch keys to TF checkpoint variable names.

    Args:
        layer_num: index of the ``Conv_bn`` block inside ``MobileNet.mobelnet``,
            which is also the TF scope suffix ``MobilenetV1/Conv2d_<layer_num>``.

    Returns:
        dict from PyTorch state-dict key to TF variable name.

    PyTorch's ``BatchNorm2d.weight`` is the scale (TF ``gamma``) and ``.bias``
    is the shift (TF ``beta``).
    """
    pth = 'mobelnet.{}.convBn.'.format(layer_num)
    tf_scope = 'MobilenetV1/Conv2d_{}/'.format(layer_num)
    convert_dict = {
        pth + '0.weight': tf_scope + 'weights',
        pth + '1.weight': tf_scope + 'BatchNorm/gamma',
        pth + '1.bias': tf_scope + 'BatchNorm/beta',
        pth + '1.running_mean': tf_scope + 'BatchNorm/moving_mean',
        pth + '1.running_var': tf_scope + 'BatchNorm/moving_variance',
    }
    return convert_dict

def get_dpwise_convert_dict(layer_num):
    """Map a depthwise-separable block's PyTorch keys to TF checkpoint variable names.

    Inside ``convDepthwise``, index 0/1 are the depthwise conv and its BN, and
    index 3/4 are the pointwise conv and its BN. Index 2 is a ReLU with no
    parameters. They correspond to the TF scopes
    ``MobilenetV1/Conv2d_<layer_num>_depthwise`` and ``..._pointwise``.

    Args:
        layer_num: index of the ``Conv_depth`` block inside ``MobileNet.mobelnet``.

    Returns:
        dict from PyTorch state-dict key to TF variable name.

    PyTorch's ``BatchNorm2d.weight`` is the scale (TF ``gamma``) and ``.bias``
    is the shift (TF ``beta``).
    """
    pth = 'mobelnet.{}.convDepthwise.'.format(layer_num)
    convert_dict = {}
    # (TF scope suffix, TF kernel variable name, conv index, BN index)
    sub_layers = (
        ('depthwise', 'depthwise_weights', 0, 1),
        ('pointwise', 'weights', 3, 4),
    )
    for part, kernel_name, conv_idx, bn_idx in sub_layers:
        tf_scope = 'MobilenetV1/Conv2d_{}_{}/'.format(layer_num, part)
        bn_pth = '{}{}.'.format(pth, bn_idx)
        convert_dict['{}{}.weight'.format(pth, conv_idx)] = tf_scope + kernel_name
        convert_dict[bn_pth + 'weight'] = tf_scope + 'BatchNorm/gamma'
        convert_dict[bn_pth + 'bias'] = tf_scope + 'BatchNorm/beta'
        convert_dict[bn_pth + 'running_mean'] = tf_scope + 'BatchNorm/moving_mean'
        convert_dict[bn_pth + 'running_var'] = tf_scope + 'BatchNorm/moving_variance'
    return convert_dict

def get_model_dict(layers_num):
    """Build the complete PyTorch-key -> TF-variable-name conversion table.

    Block 0 of ``MobileNet.mobelnet`` is the standard conv stem
    (``get_convbn_convert_dict``). Blocks 1 .. layers_num-1 are
    depthwise-separable blocks (``get_dpwise_convert_dict``). The script's
    entry point passes 14 (the stem plus 13 depthwise blocks).

    Returns:
        dict mapping PyTorch state-dict keys to TF checkpoint variable names.
    """
    conversion_table = dict(get_convbn_convert_dict(0))
    # Update in place instead of rebuilding a merged copy on every iteration.
    for layer_num in range(1, layers_num):
        conversion_table.update(get_dpwise_convert_dict(layer_num))
    return conversion_table
def write_json(conversion_table,json_path):
    """Save ``conversion_table`` as key-sorted JSON unless ``json_path`` already exists.

    An existing file is never overwritten; a message is printed instead.
    """
    if os.path.exists(json_path):
        print('the conversion table has been wirten!')
        return
    with open(json_path, 'w') as out_file:
        json.dump(conversion_table, out_file, sort_keys=True)

def load_parameter(conversion_table, save_path='/Users/wenyu/Desktop/TorchProject/MobileNet/tf_to_torch.pth'):
    """Copy TF checkpoint tensors into a fresh MobileNet state dict and save it.

    Args:
        conversion_table: dict mapping PyTorch state-dict keys to TF variable
            names in the checkpoint at ``CHECKPOINT_PATH`` (e.g. the output of
            ``get_model_dict``).
        save_path: file the converted state dict is written to with
            ``torch.save``; defaults to the original hard-coded location.

    Returns:
        The converted state dict. Entries not listed in ``conversion_table``
        (e.g. ``fc.*`` and ``num_batches_tracked``) keep MobileNet's initial
        values, so the classifier is NOT converted by this function.

    Raises:
        KeyError: a key in ``conversion_table`` is not in the model's state dict.
        ValueError: a converted tensor's shape differs from the model's.
    """
    model_state = MobileNet().state_dict()
    # Open the checkpoint once instead of re-opening it for every variable.
    reader = tf.compat.v1.train.NewCheckpointReader(CHECKPOINT_PATH)
    for pth_key, ckpt_key in conversion_table.items():
        value = reader.get_tensor(ckpt_key)
        if value.ndim == 4:
            kernel_h, kernel_w = value.shape[0], value.shape[1]
            if 'depthwise' in ckpt_key:
                # TF depthwise kernel (H, W, in, multiplier) -> PyTorch grouped
                # conv weight (in * multiplier, 1, H, W).
                value = np.transpose(value, (2, 3, 0, 1)).reshape(-1, 1, kernel_h, kernel_w)
            else:
                # TF conv kernel (H, W, in, out) -> PyTorch (out, in, H, W).
                value = np.transpose(value, (3, 2, 0, 1))
        # BatchNorm vectors are 1-D and need no reordering.
        target = model_state[pth_key]
        if tuple(target.shape) != tuple(value.shape):
            raise ValueError('shape mismatch for {} <- {}: expected {}, got {}'.format(
                pth_key, ckpt_key, tuple(target.shape), tuple(value.shape)))
        model_state[pth_key] = torch.from_numpy(np.ascontiguousarray(value)).to(target.dtype)

    torch.save(model_state, save_path)
    return model_state

if __name__ == '__main__':
    # Index 0 is the conv stem and indices 1..13 are the depthwise blocks -> 14.
    table = get_model_dict(14)
    converted_state = load_parameter(table)
    # Sanity check: shape of the first depthwise kernel after conversion.
    print(converted_state['mobelnet.1.convDepthwise.0.weight'].shape)

其中核心就在最后兩個函數,可能代碼看起來很簡單,但是我想了好久要怎么做,第一次做很不熟練,但是通過這次鞏固了很多numpy,tensor還有字典的基本知識,很充實。有問題可以在博客下面留言。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM