Problem
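After Keras has saved a model as an h5 or hdf5 file, you would normally restore the model structure and weights with keras.models.load_model(), or load only the weights with model.load_weights().
However, when you are doing transfer learning, or when the model outputs nan or inf, you need to load some of the weights by hand to inspect or modify them. For that you have to know how to work with HDF5 files directly.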
Solution
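In Python, the h5py library is the usual choice for reading and writing HDF5 files. It is convenient because the weights it reads come back as numpy arrays that can be used directly.
The HDF5 format itself is similar to XML or JSON in that it is a general-purpose way of representing tree-structured data. Elements are located by a path of the form /group/subgroup, much like directories in a file system.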
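To make the path-based layout concrete, here is a minimal sketch that writes a tiny HDF5 file and then lists every dataset in it together with its shape. The file name demo.h5, the layer name dense_1 and the helper names list_datasets and make_demo_file are made up for illustration; the /model_weights/<layer>/<layer>/<variable> nesting simply mirrors the layout that the script in this post walks through.

```python
import os
import tempfile

import h5py
import numpy as np


def list_datasets(path):
    """Return {hdf5_path: shape} for every dataset stored in the file."""
    found = {}

    def visit(name, obj):
        # visititems passes the path relative to the group it was called on.
        if isinstance(obj, h5py.Dataset):
            found['/' + name] = obj.shape

    with h5py.File(path, 'r') as f:
        f.visititems(visit)
    return found


def make_demo_file(path):
    """Write a small file; h5py creates the intermediate groups itself."""
    with h5py.File(path, 'w') as f:
        f.create_dataset('model_weights/dense_1/dense_1/kernel:0',
                         data=np.ones((3, 2), dtype=np.float32))
        f.create_dataset('model_weights/dense_1/dense_1/bias:0',
                         data=np.zeros(2, dtype=np.float32))


# Hypothetical demo file in a temporary directory.
demo_path = os.path.join(tempfile.mkdtemp(), 'demo.h5')
make_demo_file(demo_path)
for name, shape in sorted(list_datasets(demo_path).items()):
    print(name, shape)
```

Running the same list_datasets helper on a real weights file is a quick way to see which group and dataset names it contains before writing any parsing loop.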
Without further ado, here is the example code:
# test.py
# Note: this script uses TensorFlow 1.x APIs (tf.Session, tf.trainable_variables).
import os
# Select the GPU before TensorFlow is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import cv2 as cv
import numpy as np
import glob
from model import build_encoder_decoder, build_refinement
from data_generator import normalize_input, denormalize_output, depth_random_scale_shift
import h5py as h5
import tensorflow as tf

if __name__ == '__main__':
    img_rows, img_cols = 288, 384
    channel = 4
    model_path = '../Models/DIM/final.00000000-0.0607.hdf5'

    # Build the network structure only; the weights are loaded by hand below.
    coarse = build_encoder_decoder(img_rows, img_cols, train=False)
    fine = build_refinement(coarse, train=False)
    #fine.summary()

    # Parse the HDF5 file into a {variable name: numpy array} dictionary.
    f = h5.File(model_path, 'r')
    w = f['/model_weights']
    params = dict()
    for k in w.keys():
        # Layers without weights have an empty group; skip them.
        if len(w[k].keys()) > 0:
            sub_keys = w[k].keys()
            for k2 in sub_keys:
                # Each layer group is expected to hold one subgroup with the same name.
                assert k == k2
                for k3 in w[k][k2].keys():
                    name_ = k + '/' + k3
                    # [()] reads the whole dataset into a numpy array.
                    val_ = w[k][k2][k3][()]
                    params[name_] = val_
                    if np.any(np.isnan(val_)):
                        print('NAN VAR FOUND: ')
                        print(name_)
                        #val_[np.where(np.isnan(val_))] = 0.0
                    if np.any(np.isinf(val_)):
                        print('INF VAR FOUND: ')
                        print(name_)
                        #val_[np.where(np.isinf(val_))] = 0.0
    # All values are already in memory, so the file can be closed.
    f.close()
    #print(params)

    # Assign the parsed weights to the graph variables by name.
    tf_vars = tf.trainable_variables()
    assert len(params) == len(tf_vars)
    sess = tf.Session()
    t_deconv6 = sess.graph.get_tensor_by_name('deconv6/Relu:0')
    for i in range(len(tf_vars)):
        sess.run(tf_vars[i].assign(params[tf_vars[i].name]))

    # check all uninitialized variables
    var_unset = tf.report_uninitialized_variables(tf.global_variables())
    print(sess.run(var_unset))
    #fine.load_weights(model_path)

    t_out = fine.outputs[0]
    t_in = fine.inputs[0]
    files_rgbd = glob.glob('../Datasets/DIM/test/rgbd-1/*.PNG')
    #files_gt = glob.glob('../Datasets/DIM/test/gt/*.PNG')
    #assert len(files_gt) == len(files_rgbd)
    x_test = np.zeros([1, img_rows, img_cols, channel], dtype=np.float32)
    for i in range(len(files_rgbd)):
        print(files_rgbd[i])
        # Flag -1 keeps the image unchanged, so all four channels are read.
        rgbd = cv.imread(files_rgbd[i], -1)
        rgbd = cv.resize(rgbd, (img_cols, img_rows))
        # alpha = cv.imread(files_gt[i], 0)
        # alpha = cv.resize(alpha, (img_cols, img_rows))
        #rgbd[:,:,3], alpha = depth_random_scale_shift(rgbd[:,:,3], alpha, 255)
        rgb = rgbd[:,:,:3]
        disp = np.stack([rgbd[:,:,3]]*3, axis=-1)
        #gt = np.stack([alpha]*3, axis=-1)
        x_test[0, :, :, :channel] = normalize_input(rgbd)
        out, deconv6 = sess.run([t_out, t_deconv6], feed_dict={t_in: x_test})
        #print(deconv6)
        #print(out)
        #exit(0)
        out = denormalize_output(out[0,:,:,0])
        out = np.stack([out] * 3, axis=-1)
        #merged = np.concatenate((rgb, disp, out, gt), axis=1)
        # Save input RGB, depth and network output side by side.
        merged = np.concatenate((rgb, disp, out), axis=1)
        cv.imwrite(files_rgbd[i].replace('rgbd-1', 'out'), merged)
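The test.py script above reads every weight from the weights file and checks each one for validity, that is, whether it contains invalid values such as INF or NAN. When it finds one, it prints the name of the offending variable.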
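The NaN/Inf check can also be pulled out into small helpers that work on any name-to-array dictionary. The following is only a sketch using numpy; the function names find_invalid and zero_invalid and the demo values are made up for illustration, and replacing invalid entries with 0.0 is just one possible repair, the same one hinted at in the commented-out lines of the script.

```python
import numpy as np


def find_invalid(params):
    """Return {name: (nan_count, inf_count)} for arrays with NaN or Inf."""
    report = {}
    for name, value in params.items():
        arr = np.asarray(value)
        n_nan = int(np.isnan(arr).sum())
        n_inf = int(np.isinf(arr).sum())
        if n_nan or n_inf:
            report[name] = (n_nan, n_inf)
    return report


def zero_invalid(value):
    """Return a float32 copy with NaN and +/-Inf entries set to 0.0."""
    arr = np.array(value, dtype=np.float32, copy=True)
    arr[~np.isfinite(arr)] = 0.0
    return arr


# Hypothetical weights: one damaged kernel and one healthy bias.
demo = {
    'conv1/kernel:0': np.array([[1.0, np.nan], [np.inf, -np.inf]],
                               dtype=np.float32),
    'conv1/bias:0': np.zeros(2, dtype=np.float32),
}
report = find_invalid(demo)
cleaned = zero_invalid(demo['conv1/kernel:0'])
print(report)
print(cleaned)
```

Counting the invalid entries, rather than only flagging the variable, makes it easier to tell a single bad value from a weight tensor that has diverged entirely.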
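test.py also shows how, together with TensorFlow, the parsed dictionary of variable names and weights can be used to initialize the variables of the computation graph by hand.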
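Since the script only asserts that the dictionary and the variable list have the same length, it can help to compare names and shapes before assigning anything. The sketch below is framework-free and works on plain (name, shape) pairs; check_params and all demo names are hypothetical and not part of the original script.

```python
import numpy as np


def check_params(var_specs, params):
    """Compare graph variable specs with a name -> array dictionary.

    var_specs: iterable of (name, shape) pairs describing the variables.
    params:    dict of numpy arrays parsed from the weights file.
    Returns (missing, mismatched, unused).
    """
    missing, mismatched = [], []
    seen = set()
    for name, shape in var_specs:
        if name not in params:
            missing.append(name)
            continue
        seen.add(name)
        found_shape = tuple(params[name].shape)
        if found_shape != tuple(shape):
            mismatched.append((name, tuple(shape), found_shape))
    # Entries in the file that no variable asked for.
    unused = sorted(set(params) - seen)
    return missing, mismatched, unused


# Hypothetical data: one good match, one wrong shape, one missing, one unused.
demo_params = {
    'conv1/kernel:0': np.zeros((3, 3, 4, 8), dtype=np.float32),
    'conv1/bias:0': np.zeros(8, dtype=np.float32),
    'bn1/moving_mean:0': np.zeros(8, dtype=np.float32),
}
demo_specs = [
    ('conv1/kernel:0', (3, 3, 4, 8)),
    ('conv1/bias:0', (16,)),
    ('conv2/kernel:0', (3, 3, 8, 8)),
]
missing, mismatched, unused = check_params(demo_specs, demo_params)
print('missing:', missing)
print('mismatched:', mismatched)
print('unused:', unused)
```

In a TensorFlow 1.x script the specs could presumably be built from the variables themselves, for example as [(v.name, v.shape.as_list()) for v in tf.trainable_variables()]. For transfer learning, the same report shows which layers can be copied over and which have to keep their fresh initialization.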