先看下文件讀取以及讀取數據處理成張量結果的過程:
一般數據文件格式有文本、excel和圖片數據。那么TensorFlow都有對應的解析函數,除了這幾種。還有TensorFlow指定的文件格式。
TensorFlow還提供了一種內置文件格式TFRecord,二進制數據和訓練類別標簽數據存儲在同一文件。模型訓練前圖像等文本信息轉換為TFRecord格式。TFRecord文件是protobuf格式。數據不壓縮,可快速加載到內存。TFRecords文件包含 tf.train.Example protobuf,需要將Example填充到協議緩沖區,將協議緩沖區序列化為字符串,然后使用該文件將該字符串寫入TFRecords文件。在圖像操作我們會介紹整個過程以及詳細參數。
文件讀取
文件隊列構造
tf.train.string_input_producer(string_tensor, ,shuffle=True):將輸出字符串(例如文件名)輸入到管道隊列
- string_tensor:含有文件名的1階張量
- num_epochs:過幾遍數據,默認無限過數據
- return:具有輸出字符串的隊列
將文件名列表交給tf.train.string_input_producer函數。string_input_producer來生成一個先入先出的隊列,文件閱讀器會需要它們來取數據。string_input_producer提供的可配置參數來設置文件名亂序和最大的訓練迭代數,QueueRunner會為每次迭代(epoch)將所有的文件名加入文件名隊列中,如果shuffle=True的話,會對文件名進行亂序處理。一過程是比較均勻的,因此它可以產生均衡的文件名隊列。
這個QueueRunner工作線程是獨立於文件閱讀器的線程,因此亂序和將文件名推入到文件名隊列這些過程不會阻塞文件閱讀器運行。根據你的文件格式,選擇對應的文件閱讀器,然后將文件名隊列提供給閱讀器的read方法。閱讀器的read方法會輸出一個鍵來表征輸入的文件和其中紀錄(對於調試非常有用),同時得到一個字符串標量,這個字符串標量可以被一個或多個解析器,或者轉換操作將其解碼為張量並且構造成為樣本。
文件閱讀器
根據文件格式,選擇對應的文件閱讀器
class tf.TextLineReader:閱讀文本文件逗號分隔值(CSV)格式,默認按行讀取
class tf.FixedLengthRecordReader(record_bytes):要讀取每個記錄是固定數量字節的二進制文件
- record_bytes:整型,指定每次讀取的字節數
tf.TFRecordReader:讀取TfRecords文件
文件內容解碼器
由於從文件中讀取的是字符串,需要函數去解析這些字符串到張量
tf.decode_csv(records,record_defaults=None,field_delim = None,name = None):將CSV轉換為張量,與tf.TextLineReader搭配使用
- records:tensor型字符串,每個字符串是csv中的記錄行
- field_delim:默認分割符”,”
- record_defaults:參數決定了所得張量的類型,並設置一個值在輸入字符串中缺少使用默認值
tf.decode_raw(bytes,out_type,little_endian = None,name = None) :將字節轉換為一個數字向量表示,字節為一字符串類型的張量,與函數tf.FixedLengthRecordReader搭配使用,二進制讀取為uint8格式
開啟線程操作
tf.train.start_queue_runners(sess=None,coord=None):收集所有圖中的隊列線程,並啟動線程
- sess:所在的會話中
- coord:線程協調器
- return:返回所有線程隊列
管道讀端批處理
tf.train.batch(tensors,batch_size,num_threads = 1,capacity = 32,name=None):讀取指定大小(個數)的張量
- tensors:可以是包含張量的列表
- batch_size:從隊列中讀取的批處理大小
- num_threads:進入隊列的線程數
- capacity:整數,隊列中元素的最大數量
- return:tensors
tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue, num_threads=1,) :亂序讀取指定大小(個數)的張量
- min_after_dequeue:留下隊列里的張量個數,能夠保持隨機打亂
CSV文件讀取案例
import tensorflow as tf
import os
def readcsv(filelist):
"""
讀取csv文件
"""
# 構造文件隊列
file_queue = tf.train.string_input_producer(filelist)
# 構建閱讀器
reader = tf.TextLineReader()
key, value = reader.read(file_queue)
# 對每行內容進行解碼
records = [["None"], ["None"]]
example, label = tf.decode_csv(value, record_defaults=records)
# 批處理
example_batch, label_batch = tf.train.batch([example, label], batch_size=10, num_threads=1, capacity=10)
return example_batch, label_batch
if __name__ == '__main__':
filelist = os.listdir("./data/csvdata")
filelist = ["./data/csvdata/{}".format(i) for i in filelist]
example_batch, label_batch = readcsv(filelist)
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
# 線程協調器
coord = tf.train.Coordinator()
# 開啟讀取文件線程
threads = tf.train.start_queue_runners(sess, coord=coord)
# 打印數據
print(sess.run([example_batch, label_batch]))
coord.request_stop()
coord.join()