TFRecords轉化和讀取


標准TensorFlow格式

TensorFlow的訓練過程其實就是大量的數據在網絡中不斷流動的過程,而數據的來源在官方文檔[^1](API r1.2)中介紹了三種方式,分別是:

  • Feeding。通過Python直接注入數據。
  • Reading from files。從文件讀取數據,本文中的TFRecord屬於此類方式。
  • Preloaded data。將數據以constant或者variable的方式直接存儲在運算圖中。

當數據量較大時,官方推薦采用標准TensorFlow格式[^2](Standard TensorFlow format)來存儲訓練與驗證數據,該格式的后綴名為tfrecord。官方介紹如下:

A TFRecords file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.

從介紹不難看出,TFRecord文件適用於大量數據的順序讀取。而這正好是神經網絡在訓練過程中發生的事情。


如何使用TFRecord文件

對於TFRecord文件的使用,官方給出了兩份示例代碼,分別展示了如何生成與讀取該格式的文件。

生成TFRecord文件

第一份代碼convert_to_records.py [^3]將MNIST里的圖像數據轉換為了TFRecord格式 。仔細研讀代碼,可以發現TFRecord文件中的圖像數據存儲在Feature下的image_raw里。image_raw來自於data_set.images,而后者又來自mnist.read_data_sets()。因此images的真身藏在mnist.py這個文件里。

mnist.py並不難找,在Pycharm里按下ctrl后單擊鼠標左鍵即可打開源代碼。

繼續追蹤,可以在mnist里發現圖像來自extract_images()函數。該函數的說明里清晰的寫明:

Extract the images into a 4D uint8 numpy array [index, y, x, depth].
  Args:
    f: A file object that can be passed into a gzip reader.
  Returns:
    data: A 4D uint8 numpy array [index, y, x, depth].
  Raises:
    ValueError: If the bytestream does not start with 2051.

很明顯,返回值變量名為data,是一個4D Numpy矩陣,存儲值為uint8類型,即圖像像素的灰度值(MNIST全部為灰度圖像)。四個維度分別代表了:圖像的個數,每個圖像行數,每個圖像列數,每個圖像通道數。

在獲得這個存儲着像素灰度值的Numpy矩陣后,使用numpy的tostring()函數將其轉換為Python bytes格式[^4],再使用tf.train.BytesList()函數封裝為tf.train.BytesList類,名字為image_raw。最后使用tf.train.Example()image_raw和其它屬性一遍打包,並調用tf.python_io.TFRecordWriter將其寫入到文件中。

至此,TFRecord文件生成完畢。

可見,將自定義圖像轉換為TFRecord的過程本質上是將大量圖像的像素灰度值轉換為Python bytes,並與其它Feature組合在一起,最終拼接成一個文件的過程。

需要注意的是其它Feature的類型不一定必須是BytesList,還可以是Int64List或者FloatList。

讀取TFRecord文件

第二份代碼fully_connected_reader.py [1]展示了如何從TFRecord文件中讀取數據。

讀取數據的函數名為input()。函數內部首先通過tf.train.string_input_producer()函數讀取TFRecord文件,並返回一個queue;然后使用read_and_decode()讀取一份數據,函數內部用tf.decode_raw()解析出圖像的灰度值,用tf.cast()解析出label的值。之后通過tf.train.shuffle_batch()的方法生成一批用來訓練的數據。並最終返回可供訓練的imageslabels,並送入inference部分進行計算。

在這個過程中,有以下幾點需要留意:

  1. tf.decode_raw()解析出的數據是沒有shape的,因此需要調用set_shape()函數來給出tensor的維度。
  2. read_and_decode()函數返回的是單個的數據,但是后邊的tf.train.shuffle_batch()卻能夠生成批量數據。
  3. 如果需要對圖像進行處理的話,需要放在第二項提到的兩個函數中間。

其中第2點的原理我暫時沒有弄懂。從代碼上看read_and_decode()返回的是單個數據,shuffle_batch接收到的也是單個數據,不知道是如何生成批量數據的,猜測與queue有關系。

所以,讀取TFRecord文件的本質,就是通過隊列的方式依次將數據解碼,並按需要進行數據隨機化、圖像隨機化的過程。


參考


  1. Github: fully_connected_reader.py ↩︎


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM