標准TensorFlow格式
TensorFlow的訓練過程其實就是大量的數據在網絡中不斷流動的過程,而數據的來源在官方文檔[^1](API r1.2)中介紹了三種方式,分別是:
- Feeding。通過Python直接注入數據。
- Reading from files。從文件讀取數據,本文中的TFRecord屬於此類方式。
- Preloaded data。將數據以constant或者variable的方式直接存儲在運算圖中。
當數據量較大時,官方推薦采用標准TensorFlow格式[^2](Standard TensorFlow format)來存儲訓練與驗證數據,該格式的后綴名為tfrecord
。官方介紹如下:
A TFRecords file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.
從介紹不難看出,TFRecord文件適用於大量數據的順序讀取。而這正好是神經網絡在訓練過程中發生的事情。
如何使用TFRecord文件
對於TFRecord文件的使用,官方給出了兩份示例代碼,分別展示了如何生成與讀取該格式的文件。
生成TFRecord文件
第一份代碼convert_to_records.py
[^3]將MNIST里的圖像數據轉換為了TFRecord格式 。仔細研讀代碼,可以發現TFRecord文件中的圖像數據存儲在Feature
下的image_raw
里。image_raw
來自於data_set.images
,而后者又來自mnist.read_data_sets()
。因此images
的真身藏在mnist.py
這個文件里。
mnist.py
並不難找,在Pycharm里按下ctrl
后單擊鼠標左鍵即可打開源代碼。
繼續追蹤,可以在mnist里發現圖像來自extract_images()
函數。該函數的說明里清晰的寫明:
Extract the images into a 4D uint8 numpy array [index, y, x, depth].
Args:
f: A file object that can be passed into a gzip reader.
Returns:
data: A 4D uint8 numpy array [index, y, x, depth].
Raises:
ValueError: If the bytestream does not start with 2051.
很明顯,返回值變量名為data
,是一個4D Numpy矩陣,存儲值為uint8
類型,即圖像像素的灰度值(MNIST全部為灰度圖像)。四個維度分別代表了:圖像的個數,每個圖像行數,每個圖像列數,每個圖像通道數。
在獲得這個存儲着像素灰度值的Numpy矩陣后,使用numpy的tostring()
函數將其轉換為Python bytes格式[^4],再使用tf.train.BytesList()
函數封裝為tf.train.BytesList
類,名字為image_raw
。最后使用tf.train.Example()
將image_raw
和其它屬性一遍打包,並調用tf.python_io.TFRecordWriter
將其寫入到文件中。
至此,TFRecord文件生成完畢。
可見,將自定義圖像轉換為TFRecord的過程本質上是將大量圖像的像素灰度值轉換為Python bytes,並與其它Feature
組合在一起,最終拼接成一個文件的過程。
需要注意的是其它Feature
的類型不一定必須是BytesList,還可以是Int64List或者FloatList。
讀取TFRecord文件
第二份代碼fully_connected_reader.py
[1]展示了如何從TFRecord文件中讀取數據。
讀取數據的函數名為input()
。函數內部首先通過tf.train.string_input_producer()
函數讀取TFRecord文件,並返回一個queue
;然后使用read_and_decode()
讀取一份數據,函數內部用tf.decode_raw()
解析出圖像的灰度值,用tf.cast()
解析出label
的值。之后通過tf.train.shuffle_batch()
的方法生成一批用來訓練的數據。並最終返回可供訓練的images
和labels
,並送入inference
部分進行計算。
在這個過程中,有以下幾點需要留意:
tf.decode_raw()
解析出的數據是沒有shape
的,因此需要調用set_shape()
函數來給出tensor的維度。read_and_decode()
函數返回的是單個的數據,但是后邊的tf.train.shuffle_batch()
卻能夠生成批量數據。- 如果需要對圖像進行處理的話,需要放在第二項提到的兩個函數中間。
其中第2點的原理我暫時沒有弄懂。從代碼上看read_and_decode()
返回的是單個數據,shuffle_batch
接收到的也是單個數據,不知道是如何生成批量數據的,猜測與queue
有關系。
所以,讀取TFRecord文件的本質,就是通過隊列的方式依次將數據解碼,並按需要進行數據隨機化、圖像隨機化的過程。