1. Tensorflow高效流水線Pipeline
2. Tensorflow的數據處理中的Dataset和Iterator
3. Tensorflow生成TFRecord
4. Tensorflow的Estimator實踐原理
1. 前言
TFRecord是TensorFlow官方推薦使用的數據格式化存儲工具,它不僅規范了數據的讀寫方式,還大大地提高了IO效率。
2. TFRecord原理步驟
TFRecord內部使用了“Protocol Buffer”二進制數據編碼方案,只要生成一次TFRecord,之后的數據讀取和加工處理的效率都會得到提高。
而且,使用TFRecord可以直接作為Cloud ML Engine的輸入數據。
一般來說,我們使用TensorFlow進行數據讀取的方式有以下4種:
- 預先把所有數據加載進內存
- 在每輪訓練中使用原生Python代碼讀取一部分數據,然后使用feed_dict輸入到計算圖
- 利用Threading和Queues從TFRecord中分批次讀取數據
- 使用Dataset API
(1)方案對於數據量不大的場景來說是足夠簡單而高效的,但是隨着數據量的增長,勢必會對有限的內存空間帶來極大的壓力,還有長時間的數據預加載,甚至導致我們十分熟悉的OutOfMemoryError;
(2)方案可以一定程度上緩解了方案(1)的內存壓力問題,但是由於在單線程環境下我們的IO操作一般都是同步阻塞的,勢必會在一定程度上導致學習時間的增加,尤其是相同的數據需要重復多次讀取的情況下;
而方案(3)和方案(4)都利用了我們的TFRecord,由於使用了多線程使得IO操作不再阻塞我們的模型訓練,同時為了實現線程間的數據傳輸引入了Queues。
2.1 生成TFRecord數據
example = tf.train.Example()這句將數據賦給了變量example(可以看到里面是通過字典結構實現的賦值),然后用writer.write(example.SerializeToString()) 這句實現寫入。
tfrecord_filename = './tfrecord/train.tfrecord'
# 創建.tfrecord文件,准備寫入
writer = tf.python_io.TFRecordWriter(tfrecord_filename)
for i in range(100):
img_raw = np.random.random_integers(0,255,size=(7,30)) # 創建7*30,取值在0-255之間隨機數組
img_raw = img_raw.tostring()
example = tf.train.Example(features=tf.train.Features(
feature={
# Int64List儲存int數據
'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),
# 儲存byte二進制數據
'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
}))
# 序列化過程
writer.write(example.SerializeToString())
writer.close()
值得注意的是賦值給example的數據格式。從前面tf.train.Example的定義可知,tfrecord支持整型、浮點數和二進制三種格式,分別如下:
tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
tf.train.Feature(float_list = tf.train.FloatList(value=[float_scalar]))
例如圖片等數組形式(array)的數據,可以保存為numpy array的格式,轉換為string,然后保存到二進制格式的feature中。對於單個的數值(scalar),可以直接賦值。這里value=[×]的[]非常重要,也就是說輸入的必須是列表(list)。當然,對於輸入數據是向量形式的,可以根據數據類型(float還是int)分別保存。並且在保存的時候還可以指定數據的維數。
2.2 讀取TFRecord數據
從TFRecord文件中讀取數據, 首先需要用tf.train.string_input_producer生成一個解析隊列。之后調用tf.TFRecordReader的tf.parse_single_example解析器。如下圖:
具體代碼如下:
def read_tfrecord(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'sentence': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
sentence, label = tf.train.batch([features['sentence'], features['label']],
batch_size=16,
capacity=64)
return sentence, label
3. 總結
TFRecord的生成效率可能不是很快(可以使用多進程),但是一旦TFRecord數據處理好了,對以后每次的讀取,解析都有速度上的提升。而且TFRecord也可以和Tensorflow自帶的數據處理方式Dataset搭配使用,基本可以解決大數據量的訓練操作。