How to manually parse HDF5 weight files saved by frameworks such as Keras


Problem

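After saving a model's weights to an h5 or hdf5 file with Keras, you normally restore the architecture and the weights together with keras.models.load_model(), or load only the weights with model.load_weights().
However, for transfer learning, or when a model outputs NaN or Inf, you often need to load some of the weights by hand to inspect or modify them, and that means learning to work with the HDF5 file directly.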

Answer

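In Python, HDF5 files are usually read and written with the h5py library, which is very convenient: weights read from it come back as NumPy arrays that can be used directly.
HDF5 is a general-purpose hierarchical format, organized as a tree somewhat like XML or JSON: groups play the role of directories, datasets play the role of files, and any element is located by a path of the form /root_group/sub_group.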
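To see this tree structure for yourself, the sketch below walks every group and dataset in a file with h5py's `visititems`. The helper name `list_hdf5_tree` and the demo layout (`model_weights/dense/dense/kernel:0`) are illustrative assumptions meant to resemble a Keras checkpoint, not something read from the original model.

```python
import os
import tempfile

import h5py
import numpy as np


def list_hdf5_tree(path):
    """Return (path, kind, shape) for every group and dataset in an HDF5 file."""
    entries = []

    def visitor(name, obj):
        # visititems passes the path relative to the root and the object itself
        if isinstance(obj, h5py.Dataset):
            entries.append(('/' + name, 'dataset', obj.shape))
        elif isinstance(obj, h5py.Group):
            entries.append(('/' + name, 'group', None))

    with h5py.File(path, 'r') as f:
        f.visititems(visitor)
    return entries


if __name__ == '__main__':
    # Hypothetical demo file mimicking the Keras layout;
    # h5py creates the intermediate groups automatically
    tmp = os.path.join(tempfile.mkdtemp(), 'demo.h5')
    with h5py.File(tmp, 'w') as f:
        f.create_dataset('model_weights/dense/dense/kernel:0', data=np.ones((4, 2)))
        f.create_dataset('model_weights/dense/dense/bias:0', data=np.zeros(2))
    for entry in list_hdf5_tree(tmp):
        print(entry)
```

Running it on a real checkpoint is a quick way to confirm the actual group names before writing nested loops by hand.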
Without further ado, here is the example code:

# test.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import cv2 as cv
import numpy as np
import glob

from model import build_encoder_decoder, build_refinement
from data_generator import normalize_input, denormalize_output, depth_random_scale_shift

import h5py as h5
import tensorflow as tf

if __name__ == '__main__':
    img_rows, img_cols = 288, 384
    channel = 4

    model_path = '../Models/DIM/final.00000000-0.0607.hdf5'
    coarse = build_encoder_decoder(img_rows, img_cols, train=False)
    fine = build_refinement(coarse, train=False)
    #fine.summary()

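    f = h5.File(model_path, 'r')
    # Files written by model.save() keep the weights under /model_weights;
    # each layer group typically holds a subgroup of the same name whose
    # datasets are named like 'kernel:0' and 'bias:0'.
    w = f['/model_weights']
    params = dict()
    for k in w.keys():
        if len(w[k].keys()) > 0:  # skip layers without weights (e.g. activations)
            sub_keys = w[k].keys()
            for k2 in sub_keys:
                assert k == k2
                for k3 in w[k][k2].keys():
                    # e.g. 'conv1/kernel:0', matching the TF variable name
                    name_ = k + '/' + k3
                    # [()] reads the whole dataset into a NumPy array
                    val_ = w[k][k2][k3][()]
                    params[name_] = val_
                    if np.any(np.isnan(val_)):
                        print('NAN VAR FOUND: ')
                        print(name_)
                        #val_[np.isnan(val_)] = 0.0
                    if np.any(np.isinf(val_)):
                        print('INF VAR FOUND: ')
                        print(name_)
                        #val_[np.isinf(val_)] = 0.0
    f.close()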

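    #print(params)
    # TF 1.x graph-mode API. params also holds non-trainable weights such as
    # BatchNorm moving statistics, so this assert only passes for models
    # without them; otherwise compare against tf.global_variables().
    vars = tf.trainable_variables()
    assert len(params) == len(vars)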

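    sess = tf.Session()

    # intermediate tensor fetched for debugging; the name is model-specific
    t_deconv6 = sess.graph.get_tensor_by_name('deconv6/Relu:0')

    # assign each graph variable the array whose key matches its name
    for i in range(len(vars)):
        sess.run(vars[i].assign(params[vars[i].name]))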

    # check all uninitialized variables
    var_unset = tf.report_uninitialized_variables(tf.global_variables())
    print(sess.run(var_unset))
    #fine.load_weights(model_path)
    t_out = fine.outputs[0]
    t_in = fine.inputs[0]

    files_rgbd = glob.glob('../Datasets/DIM/test/rgbd-1/*.PNG')
    #files_gt = glob.glob('../Datasets/DIM/test/gt/*.PNG')
    #assert len(files_gt) == len(files_rgbd)
    x_test = np.zeros([1, img_rows, img_cols, 4], dtype=np.float32)

    for i in range(len(files_rgbd)):
        print(files_rgbd[i])
        rgbd = cv.imread(files_rgbd[i], -1)
        rgbd = cv.resize(rgbd, (img_cols, img_rows))

        # alpha = cv.imread(files_gt[i], 0)
        # alpha = cv.resize(alpha, (img_cols, img_rows))
        #rgbd[:,:,3], alpha = depth_random_scale_shift(rgbd[:,:,3], alpha, 255)

        rgb = rgbd[:,:,:3]
        disp = np.stack([rgbd[:,:,3]]*3, axis=-1)
        #gt = np.stack([alpha]*3, axis=-1)

        x_test[0, :, :, :4] = normalize_input(rgbd)
        out, deconv6 = sess.run([t_out, t_deconv6], feed_dict={t_in: x_test})
        #print(deconv6)
        #print(out)
        #exit(0)

        out = denormalize_output(out[0,:,:,0])
        out = np.stack([out] * 3, axis=-1)
        #merged = np.concatenate((rgb, disp, out, gt), axis=1)
        merged = np.concatenate((rgb, disp, out), axis=1)
        cv.imwrite(files_rgbd[i].replace('rgbd-1', 'out'), merged)

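The code above loads every weight from the weight file and checks each one for validity, that is, whether it contains invalid values such as Inf or NaN; when it finds one, it prints the name of that weight.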
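The NaN/Inf check and the commented-out repair in test.py can be factored into two small helpers, sketched below with NumPy only. The function names are hypothetical, and replacing bad values with 0.0 is just one possible policy assumed here, not necessarily the right fix for a model that has diverged.

```python
import numpy as np


def find_non_finite(params):
    """Return {name: (n_nan, n_inf)} for every array that holds NaN or Inf."""
    report = {}
    for name, val in params.items():
        n_nan = int(np.isnan(val).sum())
        n_inf = int(np.isinf(val).sum())  # counts both +Inf and -Inf
        if n_nan or n_inf:
            report[name] = (n_nan, n_inf)
    return report


def zero_non_finite(params):
    """Replace NaN and +/-Inf with 0.0, modifying the arrays in place."""
    for val in params.values():
        val[~np.isfinite(val)] = 0.0


if __name__ == '__main__':
    # Toy dict in the same 'layer/weight:0' -> ndarray form as params in test.py
    params = {'conv1/kernel:0': np.array([1.0, np.nan, np.inf]),
              'conv1/bias:0': np.zeros(3)}
    print(find_non_finite(params))  # {'conv1/kernel:0': (1, 1)}
    zero_non_finite(params)
    print(find_non_finite(params))  # {}
```

Because the dict holds the arrays themselves, `zero_non_finite` changes them in place, so the repaired dict can go straight into the variable-assignment loop.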
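The test.py script also shows how to combine this with TensorFlow: the parsed name-to-weight dictionary is used to initialize the variables of the computation graph by hand.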
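If you would rather fix the file itself, so that `model.load_weights()` or `keras.models.load_model()` reads clean values, h5py can overwrite datasets in place when the file is opened in `'r+'` mode. This is a minimal sketch under stated assumptions: the weights live under `/model_weights` as in test.py (files written by `save_weights` keep the layer groups at the root, so pass `root='/'` for those), zeroing is an acceptable repair, and `repair_hdf5_weights` is a hypothetical name.

```python
import os
import tempfile

import h5py
import numpy as np


def repair_hdf5_weights(path, root='/model_weights'):
    """Zero out NaN/Inf in every float dataset under `root`, in place on disk.

    Returns the absolute paths of the datasets that were changed.
    """
    fixed = []

    def visitor(name, obj):
        # Only floating-point datasets can hold NaN or Inf
        if isinstance(obj, h5py.Dataset) and obj.dtype.kind == 'f':
            val = np.array(obj[()])  # writable in-memory copy
            bad = ~np.isfinite(val)
            if bad.any():
                val[bad] = 0.0
                obj[...] = val  # overwrite the stored values; shape is unchanged
                fixed.append(obj.name)

    # 'r+' opens an existing file for reading and writing
    with h5py.File(path, 'r+') as f:
        f[root].visititems(visitor)
    return fixed


if __name__ == '__main__':
    # Hypothetical stand-in for a real Keras checkpoint
    tmp = os.path.join(tempfile.mkdtemp(), 'demo.hdf5')
    with h5py.File(tmp, 'w') as f:
        f.create_dataset('model_weights/dense/dense/kernel:0',
                         data=np.array([[1.0, np.nan], [np.inf, 2.0]], dtype=np.float32))
        f.create_dataset('model_weights/dense/dense/bias:0',
                         data=np.zeros(2, dtype=np.float32))
    print(repair_hdf5_weights(tmp))  # ['/model_weights/dense/dense/kernel:0']
```

The changes are written straight back to disk, so it is safer to run this on a copy of the checkpoint.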

