神經網絡 參數計算--直接解析CKPT文件讀取


1、獲取 TensorFlow 模型文件（ckpt）中的參數名稱與形狀

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Directory expected to contain the checkpoint state file and ckpt data files.
model_dir = "./ckpt/"

# Look up the latest checkpoint recorded for this directory.
# get_checkpoint_state returns None when no checkpoint state is found, so
# fail with a clear message instead of an AttributeError on the next line.
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("No checkpoint found in %r" % model_dir)
ckpt_path = ckpt.model_checkpoint_path

# NOTE(review): NewCheckpointReader is a TF 1.x internal entry point; on newer
# TensorFlow releases tf.train.load_checkpoint is presumably the replacement
# -- confirm against the installed version.
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)

# Maps each variable name stored in the checkpoint to its shape.
param_dict = reader.get_variable_to_shape_map()

# Print every variable name together with its shape.
for key, val in param_dict.items():
    print(key, val)

 

2、參數量計算（求網絡模型的參數總數）

from tensorflow.python import pywrap_tensorflow
import os
import numpy as np

# Location of the checkpoint whose parameters are counted.
model_dir = "models_pretrained/"
checkpoint_path = os.path.join(model_dir, "model.ckpt-82798")

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)

# Maps each variable name stored in the checkpoint to its shape.
var_to_shape_map = reader.get_variable_to_shape_map()

# Sum the element count of every stored variable. The shapes come straight
# from the shape map, so no tensor data has to be loaded just to size it.
# A scalar (empty shape) counts as one parameter because the product of an
# empty shape is 1.
# NOTE(review): this counts every variable in the checkpoint; whether that
# includes non-weight entries (e.g. step counters) depends on what was saved.
total_parameters = 0
for key in var_to_shape_map:
    shape = var_to_shape_map[key]
    total_parameters += int(np.prod(shape, dtype=np.int64))

print(total_parameters)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM