1. Tensorflow高效流水線Pipeline
2. Tensorflow的數據處理中的Dataset和Iterator
3. Tensorflow生成TFRecord
4. Tensorflow的Estimator實踐原理
1. 前言
我們在訓練模型的時候,必須經過的第一個步驟是數據處理。在機器學習領域有一個說法,數據處理的好壞直接影響了模型結果的好壞。數據處理是至關重要的一步。
我們今天關注數據處理的另一個問題:假設我們做深度學習,數據的量隨隨便便就到GB的級別,那數據處理的速度對於模型的訓練也很重要。經常遇到的一個情況是,數據處理的時間占了訓練整個模型的大部分。
今天介紹的是Tensorflow官方推薦的數據處理方式是用Dataset API同時支持從內存和硬盤的讀取,相比之前的兩種方法在語法上更加簡潔易懂
2. Dataset原理
Google官方給出的Dataset API中的類圖如下所示:

2.1 Dataset創建方法
Dataset API還提供了四種創建Dataset的方式:
- tf.data.Dataset.from_tensor_slices():這個函數直接從內存中讀取數據,數據的形式可以是數組、矩陣、dict等。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
#實例化make_one_shot_iterator對象,該對象只能讀取一次
iterator = dataset.make_one_shot_iterator()
# 從iterator里取出一個元素
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))
- tf.data.TFRecordDataset():顧名思義,這個函數是用來讀TFRecord文件的,dataset中的每一個元素就是一個TFExample。
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
- tf.data.TextLineDataset():這個函數的輸入是一個文件的列表,輸出是一個dataset。dataset中的每一個元素就對應了文件中的一行。可以使用這個函數來讀入CSV文件。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
- tf.data.FixedLengthRecordDataset():這個函數的輸入是一個文件的列表和一個record_bytes,之后dataset的每一個元素就是文件中固定字節數record_bytes的內容。通常用來讀取以二進制形式保存的文件,如CIFAR10數據集就是這種形式。
2.2 Dataset數據進行轉換(Transformation)
一個Dataset通過Transformation變成一個新的Dataset。通常我們可以通過Transformation完成數據變換,打亂,組成batch,生成epoch等一系列操作,常用的Transformation有:
- map:接收一個函數對象,Dataset中的每個元素都會被當作這個函數的輸入,並將函數返回值作為新的Dataset,如我們可以對dataset中每個元素的值加1。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
- apply:應用一個轉換函數到dataset。
dataset = dataset.apply(group_by_window(key_func, reduce_func, window_size))
- batch:根據接收的整數值將該數個元素組合成batch,如下面的程序將dataset中的元素組成了大小為32的batch。
dataset = dataset.batch(32)
- shuffle:打亂dataset中的元素,它有一個參數buffersize,表示打亂時使用的buffer的大小。
dataset = dataset.shuffle(buffer_size=10000)
- repeat:整個序列重復多次,主要用來處理機器學習中的epoch,假設原先的數據是一個epoch,使用repeat(5)就可以將之變成5個epoch。
dataset = dataset.repeat(5)
# 如果repeat沒有參數,則一直重復循環數據
dataset = dataset.repeat()
- padded_batch:對dataset中的數據進行padding到一定的長度。
dataset.padded_batch(
batch_size,
padded_shapes=(
tf.TensorShape([None]), # src
tf.TensorShape([]), # tgt_output
tf.TensorShape([]),
tf.TensorShape([src_max_len])), # src_len
padding_values=(
src_eos_id, # src
0, # tgt_len -- unused
0, # src_len -- unused
0)) # mask
- shard:根據多GPU進行分片操作。
dataset.shard(num_shards, shard_index)
比較完整的生成dataset的代碼。
def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])
image = _augment_helper(image) # augments image using slice, reshape, resize_bilinear
return image, parsed["label"]
#簡單的生成input_fn
def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
dataset = dataset.map(map_func=parse_fn)
dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset
3. Iterator原理
3.1 Iterator Init初始化
生成Iterator一共有4種,復雜程度遞增,個人覺得掌握前兩種應該夠用了,Iterator還有一個優勢,目前,單次迭代器是唯一易於與 Estimator 搭配使用的類型。
- one shot Iterator:one shot Iterator是最簡單的一種Iterator,僅支持對整個數據集訪問一遍,不需要顯式的初始化。one-shot Iterator不支參數化。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
- initializable Iterator:Initializable Iterator 要求在使用之前顯式的通過調用Iterator.initializer操作初始化,這使得在定義數據集時可以結合tf.placeholder傳入參數。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
- reinitializable Iterator:可以被不同的dataset對象初始化,比如對於訓練集進行了shuffle的操作,對於驗證集則沒有處理,通常這種情況會使用兩個具有相同結構的dataset對象。
- feedable Iterator:可以通過和tf.placeholder結合在一起,同通過feed_dict機制來選擇在每次調用tf.Session.run的時候選擇哪種Iterator。
3.2 Iterator get_next遍歷數據
Iterator.get_next() 方法tf.Tensor 對象,每次tf.Session.run(Iterator.get_next())都會獲取底層數據集中下一個元素的值。
如果迭代器到達數據集的末尾,則執行 Iterator.get_next() 操作會產生 tf.errors.OutOfRangeError。在此之后,迭代器將處於不可用狀態;如果需要繼續使用,則必須對其重新初始化。
sess.run(iterator.initializer)
while True:
try:
sess.run(getNextTensor)
except tf.errors.OutOfRangeError:
sess.run(iterator.initializer)
3.3 Iterator Save保存
tf.contrib.data.make_saveable_from_iterator 函數通過迭代器創建一個 SaveableObject,該對象可用於保存和恢復迭代器(實際上是整個輸入管道)的當前狀態。
# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
if should_checkpoint:
saver.save(path_to_checkpoint)
# Restore the iterator state.
with tf.Session() as sess:
saver.restore(sess, path_to_checkpoint)
4. 總結
本文介紹了創建不同種類的Dataset和Iterator對象的基礎知識,熟悉這個數據處理的步驟后,不僅復用性比較強,而且效率也能成倍的提升。
