使用tensorflow中的Dataset來讀取制作好的tfrecords文件


上一篇我寫了如何給自己的圖像集制作tfrecords文件,現在我們就來講講如何讀取已經創建好的文件,我們使用的是Tensorflow中的Dataset來讀取我們的tfrecords,網上很多帖子應該是很久之前的了,絕大多數的做法是,先將tfrecords序列化成一個隊列,然后使用TFRecordReader這個函數進行解析,解析出來的每一行都是一個record,然后再將每一個record進行還原,但是這個函數你在使用的時候會報出異常,原因就是它已經被dataset中新的讀取方式所替代,下個版本中可能就無法使用了,因此不建議大家使用這個函數,好了,下面就來看看是如何進行讀取的吧。

 1 import tensorflow as tf  2 import matplotlib.pyplot as plt  3 
 4 #定義可以一次獲得多張圖像的函數
 5 def show_image(image_dir):  6  plt.imshow(image_dir)  7     plt.axis('on')  8  plt.show()  9 
10 #單個record的解析函數
11 def decode_example(example):#,resize_height,resize_width,labels_nums):
12     features=tf.io.parse_single_example(example,features={ 13         'image_raw':tf.io.FixedLenFeature([],tf.string), 14         'label':tf.io.FixedLenFeature([],tf.int64) 15  }) 16     tf_image=tf.decode_raw(features['image_raw'],tf.uint8)#這個其實就是圖像的像素模式,之前我們使用矩陣來表示圖像
17     tf_image=tf.reshape(tf_image,shape=[224,224,3])#對圖像的尺寸進行調整,調整成三通道圖像
18     tf_image=tf.cast(tf_image,tf.float32)*(1./255)#對圖像進行歸一化以便保持和原圖像有相同的精度
19     tf_label=tf.cast(features['label'],tf.int32) 20     tf_label=tf.one_hot(tf_label,5,on_value=1,off_value=0)#將label轉化成用one_hot編碼的格式
21     return tf_image,tf_label 22 
23 def batch_test(tfrecords_file): 24     dataset=tf.data.TFRecordDataset(tfrecords_file) 25     dataset=dataset.map(decode_example) 26     dataset=dataset.shuffle(100).batch(4) 27     iterator=tf.compat.v1.data.make_one_shot_iterator(dataset) 28     batch_images,batch_labels=iterator.get_next() 29 
30     init_op=tf.compat.v1.global_variables_initializer() 31  with tf.compat.v1.Session() as sess: 32  sess.run(init_op) 33         coord=tf.train.Coordinator() 34         threads=tf.train.start_queue_runners(coord=coord) 35         for i in range(4): 36             images,labels=sess.run([batch_images,batch_labels]) 37             show_image(images[1,:,:,:]) 38             print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels)) 39 
40  coord.request_stop() 41  coord.join(threads) 42 
43 if __name__=='__main__': 44     tfrecords_file='D:/軟件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords'
45     resize_height=224
46     resize_width=224
47     batch_test(tfrecords_file)

我為了測試,寫了batch_test這個函數,因為我想試一試看我做的tfrecords能不能被解析成功,如果你不想測試只想訓練,那你直接把images_batch,和labels_batch放到網絡中進行訓練就可以了,還有一點要注意的,tf.global_variables_initializer()已經被tf.compat.v1.global_variables_initializer()所取代了,我做的時候不知道所以報了一個warning提示,同時tf.Sesssion()已經被tf.compat.v1.Session() 所替代,iterator=dataset.make_one_shot_iterator()已經被tf.compat.v1.data.make_one_shot_iterator(dataset)  所代替,這些異常要注意,然后我只是將每個batch的第二張圖片顯示出來了,你也可以顯示其他的,但是意義不大,反正只是測試一下解析成功與否,成功了我們就不需要糾結別的了。好啦,就是這樣,接下來我會把這些東西放到網絡中進行訓練,再更新我的學習,就醬。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM