tensorflow在保存權重模型時多使用tf.train.Saver().save 函數進行權重保存,保存的ckpt文件無法直接打開,不利於將模型權重導入到其他框架使用(如Caffe、Keras等)。
好在tensorflow提供了相關函數 tf.train.NewCheckpointReader 可以對ckpt文件進行權重查看,因此可以通過該函數進行數據導出。
1 import tensorflow as tf 2 import h5py 3 4 cpktLogFileName = r'./checkpoint/checkpoint' #cpkt 文件路徑 5 with open(cpktLogFileName, 'r') as f: 6 #權重節點往往會保留多個epoch的數據,此處獲取最后的權重數據 7 cpktFileName = f.readline().split('"')[1] 8 9 h5FileName = r'./model/net_classification.h5' 10 11 reader = tf.train.NewCheckpointReader(cpktFileName) 12 f = h5py.File(h5FileName, 'w') 13 t_g = None 14 for key in sorted(reader.get_variable_to_shape_map()): 15 # 權重名稱需根據自己網絡名稱自行修改 16 if key.endswith('w') or key.endswith('biases'): 17 keySplits = key.split(r'/') 18 keyDict = keySplits[1] + '/' + keySplits[1] + '/' + keySplits[2] 19 f[keyDict] = reader.get_tensor(key)