Tensorflow讀取數據的一般方式有下面3種:
- preloaded直接創建變量:在tensorflow定義圖的過程中,創建常量或變量來存儲數據
- feed:在運行程序時,通過feed_dict傳入數據
- reader從文件中讀取數據:在tensorflow圖開始時,通過一個輸入管線從文件中讀取數據
Preloaded方法的簡單例子
1 import tensorflow as tf 2 3 """定義常量""" 4 const_var = tf.constant([1, 2, 3]) 5 """定義變量""" 6 var = tf.Variable([1, 2, 3]) 7 8 with tf.Session() as sess: 9 sess.run(tf.global_variables_initializer()) 10 print(sess.run(var)) 11 print(sess.run(const_var))
Feed方法
可以在tensorflow運算圖的過程中,將數據傳遞到事先定義好的placeholder中。方法是在調用session.run函數時,通過feed_dict參數傳入。簡單例子:
1 import tensorflow as tf 2 """定義placeholder""" 3 x1 = tf.placeholder(tf.int16) 4 x2 = tf.placeholder(tf.int16) 5 result = x1 + x2 6 """定義feed_dict""" 7 feed_dict = { 8 x1: [10], 9 x2: [20] 10 } 11 """運行圖""" 12 with tf.Session() as sess: 13 print(sess.run(result, feed_dict=feed_dict))
上面的兩個方法在面對大量數據時,都存在性能問題。這時候就需要使用到第3種方法,文件讀取,讓tensorflow自己從文件中讀取數據
從文件中讀取數據
圖引用自 https://zhuanlan.zhihu.com/p/27238630
步驟:
- 獲取文件名列表list
- 創建文件名隊列,調用tf.train.string_input_producer,參數包含:文件名列表,num_epochs【定義重復次數】,shuffle【定義是否打亂文件的順序】
- 定義對應文件的閱讀器>* tf.ReaderBase >* tf.TFRecordReader >* tf.TextLineReader >* tf.WholeFileReader >* tf.IdentityReader >* tf.FixedLengthRecordReader
- 解析器 >* tf.decode_csv >* tf.decode_raw >* tf.image.decode_image >* …
- 預處理,對原始數據進行處理,以適應network輸入所需
- 生成batch,調用tf.train.batch() 或者 tf.train.shuffle_batch()
- prefetch【可選】使用預加載隊列slim.prefetch_queue.prefetch_queue()
- 啟動填充隊列的線程,調用tf.train.start_queue_runners

圖引用自http://www.yyliu.cn/post/89458415.html
讀取文件格式舉例
tensorflow支持讀取的文件格式包括:CSV文件,二進制文件,TFRecords文件,圖像文件,文本文件等等。具體使用時,需要根據文件的不同格式,選擇對應的文件格式閱讀器,再將文件名隊列傳為參數,傳入閱讀器的read方法中。方法會返回key與對應的record value。將value交給解析器進行解析,轉換成網絡能進行處理的tensor。
CSV文件讀取:
閱讀器:tf.TextLineReader
解析器:tf.decode_csv
1 filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"]) 2 """閱讀器""" 3 reader = tf.TextLineReader() 4 key, value = reader.read(filename_queue) 5 """解析器""" 6 record_defaults = [[1], [1], [1], [1]] 7 col1, col2, col3, col4 = tf.decode_csv(value, record_defaults=record_defaults) 8 features = tf.concat([col1, col2, col3, col4], axis=0) 9 10 with tf.Session() as sess: 11 coord = tf.train.Coordinator() 12 threads = tf.train.start_queue_runners(coord=coord) 13 for i in range(100): 14 example = sess.run(features) 15 coord.request_stop() 16 coord.join(threads)
二進制文件讀取:
閱讀器:tf.FixedLengthRecordReader
解析器:tf.decode_raw
圖像文件讀取:
閱讀器:tf.WholeFileReader
解析器:tf.image.decode_image, tf.image.decode_gif, tf.image.decode_jpeg, tf.image.decode_png
TFRecords文件讀取
TFRecords文件是tensorflow的標准格式。要使用TFRecords文件讀取,事先需要將數據轉換成TFRecords文件,具體可察看:convert_to_records.py 在這個腳本中,先將數據填充到tf.train.Example協議內存塊(protocol buffer),將協議內存塊序列化為字符串,再通過tf.python_io.TFRecordWriter寫入到TFRecords文件中去。
閱讀器:tf.TFRecordReader
解析器:tf.parse_single_example
又或者使用slim提供的簡便方法:slim.dataset.Data以及slim.dataset_data_provider.DatasetDataProvider方法
1 def get_split(record_file_name, num_sampels, size): 2 reader = tf.TFRecordReader 3 4 keys_to_features = { 5 "image/encoded": tf.FixedLenFeature((), tf.string, ''), 6 "image/format": tf.FixedLenFeature((), tf.string, 'jpeg'), 7 "image/height": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)), 8 "image/width": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)), 9 } 10 11 items_to_handlers = { 12 "image": slim.tfexample_decoder.Image(shape=[size, size, 3]), 13 "height": slim.tfexample_decoder.Tensor("image/height"), 14 "width": slim.tfexample_decoder.Tensor("image/width"), 15 } 16 17 decoder = slim.tfexample_decoder.TFExampleDecoder( 18 keys_to_features, items_to_handlers 19 ) 20 return slim.dataset.Dataset( 21 data_sources=record_file_name, 22 reader=reader, 23 decoder=decoder, 24 items_to_descriptions={}, 25 num_samples=num_sampels 26 ) 27 28 29 def get_image(num_samples, resize, record_file="image.tfrecord", shuffle=False): 30 provider = slim.dataset_data_provider.DatasetDataProvider( 31 get_split(record_file, num_samples, resize), 32 shuffle=shuffle 33 ) 34 [data_image] = provider.get(["image"]) 35 return data_image
參考資料:
tensorflow 1.0 學習:十圖詳解tensorflow數據讀取機制
