How to manually parse HDF5-format weight files saved by Keras and similar frameworks


Question

After a Keras model's weights have been saved to an h5 or hdf5 file, you normally restore both the architecture and the weights with keras.models.load_model(), or load just the weights into an already-built model with model.load_weights().
However, when doing transfer learning, or when the model starts producing nan or inf outputs and you need to load individual weights by hand to inspect or patch them, there is no way around working with the HDF5 format directly.
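
For reference, the normal path is a one-liner. A minimal sketch, assuming a model saved with model.save() (the file name below is only a placeholder):

from keras.models import load_model

model = load_model('final.hdf5')    # restores architecture + weights from a full-model file
# model.load_weights('final.hdf5')  # or: weights only, into a model already built in code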

Answer

In Python, the h5py library is the usual way to read and write HDF5 files, and it is very convenient: every weight it extracts comes out as a numpy array that can be used directly.
The HDF5 format itself is, like XML or JSON, a general-purpose representation of tree-structured documents; elements are located by filesystem-style paths of the form /group/subgroup/dataset.
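
Before pulling out individual weights it helps to print a file's tree first. A minimal sketch with h5py (the file name is again a placeholder); visititems() walks every group and dataset and reports its path:

import h5py

with h5py.File('final.hdf5', 'r') as f:
    def show(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype)   # a leaf, e.g. model_weights/conv1/conv1/kernel:0
        else:
            print(name + '/')                   # a group, i.e. an inner node of the tree
    f.visititems(show)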
Without further ado, here is the full example code:

# test.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import cv2 as cv
import numpy as np
import glob

# project-local modules: network definition and pre/post-processing helpers
from model import build_encoder_decoder, build_refinement
from data_generator import normalize_input, denormalize_output, depth_random_scale_shift

import h5py as h5
import tensorflow as tf

if __name__ == '__main__':
    img_rows, img_cols = 288, 384
    channel = 4

    model_path = '../Models/DIM/final.00000000-0.0607.hdf5'
    coarse = build_encoder_decoder(img_rows, img_cols, train=False)
    fine = build_refinement(coarse, train=False)
    #fine.summary()

    # Keras stores weights under /model_weights/<layer_name>/<layer_name>/<weight_name>
    f = h5.File(model_path, 'r')
    w = f['/model_weights']
    params = dict()
    for k in w.keys():                            # one group per layer
        if len(w[k].keys()) > 0:                  # skip layers that hold no weights
            sub_keys = w[k].keys()
            for k2 in sub_keys:
                assert k == k2                    # Keras nests a same-named group inside each layer group
                for k3 in w[k][k2].keys():        # weight names, e.g. 'kernel:0', 'bias:0'
                    name_ = k + '/' + k3          # matches the TF variable name, e.g. 'conv1/kernel:0'
                    val_ = w[k][k2][k3][()]       # read the dataset as a numpy array
                    params[name_] = val_
                    if np.any(np.isnan(val_)):
                        print('NAN VAR FOUND: ')
                        print(name_)
                        #val_[np.where(np.isnan(val_))] = 0.0
                    if np.any(np.isinf(val_)):
                        print('INF VAR FOUND: ')
                        print(name_)
                        #val_[np.where(np.isinf(val_))] = 0.0

    #print(params)
    # Match the parsed weights against the graph's trainable variables by name
    tf_vars = tf.trainable_variables()
    assert len(params) == len(tf_vars)

    sess = tf.Session()

    # Grab an intermediate activation so it can be inspected alongside the output
    t_deconv6 = sess.graph.get_tensor_by_name('deconv6/Relu:0')

    # Initialize every variable in this session with the weight of the same name
    for i in range(len(tf_vars)):
        sess.run(tf_vars[i].assign(params[tf_vars[i].name]))

    # check all uninitialized variables
    var_unset = tf.report_uninitialized_variables(tf.global_variables())
    print(sess.run(var_unset))
    #fine.load_weights(model_path)
    t_out = fine.outputs[0]
    t_in = fine.inputs[0]

    files_rgbd = glob.glob('../Datasets/DIM/test/rgbd-1/*.PNG')
    #files_gt = glob.glob('../Datasets/DIM/test/gt/*.PNG')
    #assert len(files_gt) == len(files_rgbd)
    x_test = np.zeros([1, img_rows, img_cols, 4], dtype=np.float32)

    for i in range(len(files_rgbd)):
        print(files_rgbd[i])
        rgbd = cv.imread(files_rgbd[i], -1)
        rgbd = cv.resize(rgbd, (img_cols, img_rows))

        # alpha = cv.imread(files_gt[i], 0)
        # alpha = cv.resize(alpha, (img_cols, img_rows))
        #rgbd[:,:,3], alpha = depth_random_scale_shift(rgbd[:,:,3], alpha, 255)

        rgb = rgbd[:,:,:3]
        disp = np.stack([rgbd[:,:,3]]*3, axis=-1)
        #gt = np.stack([alpha]*3, axis=-1)

        x_test[0, :, :, :4] = normalize_input(rgbd)
        out, deconv6 = sess.run([t_out, t_deconv6], feed_dict={t_in: x_test})
        #print(deconv6)
        #print(out)
        #exit(0)

        out = denormalize_output(out[0,:,:,0])
        out = np.stack([out] * 3, axis=-1)
        #merged = np.concatenate((rgb, disp, out, gt), axis=1)
        merged = np.concatenate((rgb, disp, out), axis=1)
        cv.imwrite(files_rgbd[i].replace('rgbd-1', 'out'), merged)

The code above loads every weight from the weight file and checks each one for validity, i.e. whether it contains invalid values such as INF or NAN; whenever one is found, the name of the offending weight is printed (the commented-out lines show how it could be zeroed out).
It also shows how to combine the parsed name-to-weight dictionary with TensorFlow to manually initialize the variables of the computation graph.
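
If a weight really does contain NaN or Inf, it can also be patched directly in the HDF5 file by reopening it in read-write mode; Keras will then load the corrected values as usual. A minimal sketch (the dataset path is a hypothetical example):

import h5py
import numpy as np

with h5py.File('final.hdf5', 'r+') as f:              # 'r+' = open for in-place modification
    ds = f['/model_weights/conv1/conv1/kernel:0']     # hypothetical layer/weight name
    val = ds[()]                                      # read the dataset as a numpy array
    val[np.isnan(val) | np.isinf(val)] = 0.0          # zero out invalid entries
    ds[...] = val                                     # write the cleaned array back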

