上一篇我寫了如何給自己的圖像集制作tfrecords文件,現在我們就來講講如何讀取已經創建好的文件,我們使用的是Tensorflow中的Dataset來讀取我們的tfrecords,網上很多帖子應該是很久之前的了,絕大多數的做法是,先將tfrecords序列化成一個隊列,然后使用TFRecordReader這個函數進行解析,解析出來的每一行都是一個record,然后再將每一個record進行還原,但是這個函數你在使用的時候會報出異常,原因就是它已經被dataset中新的讀取方式所替代,下個版本中可能就無法使用了,因此不建議大家使用這個函數,好了,下面就來看看是如何進行讀取的吧。
1 import tensorflow as tf 2 import matplotlib.pyplot as plt 3 4 #定義可以一次獲得多張圖像的函數 5 def show_image(image_dir): 6 plt.imshow(image_dir) 7 plt.axis('on') 8 plt.show() 9 10 #單個record的解析函數 11 def decode_example(example):#,resize_height,resize_width,labels_nums): 12 features=tf.io.parse_single_example(example,features={ 13 'image_raw':tf.io.FixedLenFeature([],tf.string), 14 'label':tf.io.FixedLenFeature([],tf.int64) 15 }) 16 tf_image=tf.decode_raw(features['image_raw'],tf.uint8)#這個其實就是圖像的像素模式,之前我們使用矩陣來表示圖像 17 tf_image=tf.reshape(tf_image,shape=[224,224,3])#對圖像的尺寸進行調整,調整成三通道圖像 18 tf_image=tf.cast(tf_image,tf.float32)*(1./255)#對圖像進行歸一化以便保持和原圖像有相同的精度 19 tf_label=tf.cast(features['label'],tf.int32) 20 tf_label=tf.one_hot(tf_label,5,on_value=1,off_value=0)#將label轉化成用one_hot編碼的格式 21 return tf_image,tf_label 22 23 def batch_test(tfrecords_file): 24 dataset=tf.data.TFRecordDataset(tfrecords_file) 25 dataset=dataset.map(decode_example) 26 dataset=dataset.shuffle(100).batch(4) 27 iterator=tf.compat.v1.data.make_one_shot_iterator(dataset) 28 batch_images,batch_labels=iterator.get_next() 29 30 init_op=tf.compat.v1.global_variables_initializer() 31 with tf.compat.v1.Session() as sess: 32 sess.run(init_op) 33 coord=tf.train.Coordinator() 34 threads=tf.train.start_queue_runners(coord=coord) 35 for i in range(4): 36 images,labels=sess.run([batch_images,batch_labels]) 37 show_image(images[1,:,:,:]) 38 print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels)) 39 40 coord.request_stop() 41 coord.join(threads) 42 43 if __name__=='__main__': 44 tfrecords_file='D:/軟件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords' 45 resize_height=224 46 resize_width=224 47 batch_test(tfrecords_file)
我為了測試,寫了batch_test這個函數,因為我想試一試看我做的tfrecords能不能被解析成功,如果你不想測試只想訓練,那你直接把images_batch,和labels_batch放到網絡中進行訓練就可以了,還有一點要注意的,tf.global_variables_initializer()已經被tf.compat.v1.global_variables_initializer()所取代了,我做的時候不知道所以報了一個warning提示,同時tf.Sesssion()已經被tf.compat.v1.Session() 所替代,iterator=dataset.make_one_shot_iterator()已經被tf.compat.v1.data.make_one_shot_iterator(dataset) 所代替,這些異常要注意,然后我只是將每個batch的第二張圖片顯示出來了,你也可以顯示其他的,但是意義不大,反正只是測試一下解析成功與否,成功了我們就不需要糾結別的了。好啦,就是這樣,接下來我會把這些東西放到網絡中進行訓練,再更新我的學習,就醬。