3. Tensorflow生成TFRecord


1. Tensorflow高效流水線Pipeline

2. Tensorflow的數據處理中的Dataset和Iterator

3. Tensorflow生成TFRecord

4. Tensorflow的Estimator實踐原理

1. 前言

TFRecord是TensorFlow官方推薦使用的數據格式化存儲工具,它不僅規范了數據的讀寫方式,還大大地提高了IO效率。

2. TFRecord原理步驟

TFRecord內部使用了“Protocol Buffer”二進制數據編碼方案,只要生成一次TFRecord,之后的數據讀取和加工處理的效率都會得到提高。

而且,使用TFRecord可以直接作為Cloud ML Engine的輸入數據。

一般來說,我們使用TensorFlow進行數據讀取的方式有以下4種:

  1. 預先把所有數據加載進內存
  2. 在每輪訓練中使用原生Python代碼讀取一部分數據,然后使用feed_dict輸入到計算圖
  3. 利用Threading和Queues從TFRecord中分批次讀取數據
  4. 使用Dataset API

(1)方案對於數據量不大的場景來說是足夠簡單而高效的,但是隨着數據量的增長,勢必會對有限的內存空間帶來極大的壓力,還有長時間的數據預加載,甚至導致我們十分熟悉的OutOfMemoryError;

(2)方案可以一定程度上緩解了方案(1)的內存壓力問題,但是由於在單線程環境下我們的IO操作一般都是同步阻塞的,勢必會在一定程度上導致學習時間的增加,尤其是相同的數據需要重復多次讀取的情況下;

而方案(3)和方案(4)都利用了我們的TFRecord,由於使用了多線程使得IO操作不再阻塞我們的模型訓練,同時為了實現線程間的數據傳輸引入了Queues。

2.1 生成TFRecord數據

example = tf.train.Example()這句將數據賦給了變量example(可以看到里面是通過字典結構實現的賦值),然后用writer.write(example.SerializeToString()) 這句實現寫入。

tfrecord_filename = './tfrecord/train.tfrecord'
# 創建.tfrecord文件,准備寫入
writer = tf.python_io.TFRecordWriter(tfrecord_filename)
for i in range(100):
    img_raw = np.random.random_integers(0,255,size=(7,30)) # 創建7*30,取值在0-255之間隨機數組
    img_raw = img_raw.tostring()
    example = tf.train.Example(features=tf.train.Features(
            feature={
            # Int64List儲存int數據
            'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])), 
            # 儲存byte二進制數據
            'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
            }))
    # 序列化過程
    writer.write(example.SerializeToString()) 
writer.close()

值得注意的是賦值給example的數據格式。從前面tf.train.Example的定義可知,tfrecord支持整型、浮點數和二進制三種格式,分別如下:

tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
tf.train.Feature(float_list = tf.train.FloatList(value=[float_scalar]))

例如圖片等數組形式(array)的數據,可以保存為numpy array的格式,轉換為string,然后保存到二進制格式的feature中。對於單個的數值(scalar),可以直接賦值。這里value=[×]的[]非常重要,也就是說輸入的必須是列表(list)。當然,對於輸入數據是向量形式的,可以根據數據類型(float還是int)分別保存。並且在保存的時候還可以指定數據的維數。

2.2 讀取TFRecord數據

從TFRecord文件中讀取數據, 首先需要用tf.train.string_input_producer生成一個解析隊列。之后調用tf.TFRecordReader的tf.parse_single_example解析器。如下圖:

image

具體代碼如下:

def read_tfrecord(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'sentence': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        })

    sentence, label = tf.train.batch([features['sentence'], features['label']],
            batch_size=16,
            capacity=64)

    return sentence, label

3. 總結

TFRecord的生成效率可能不是很快(可以使用多進程),但是一旦TFRecord數據處理好了,對以后每次的讀取,解析都有速度上的提升。而且TFRecord也可以和Tensorflow自帶的數據處理方式Dataset搭配使用,基本可以解決大數據量的訓練操作。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM