模塊作用
tf.data api用於創建訓練前導入數據和數據處理的pipeline,使得處理大規模數據,不同數據格式和復雜數據處理變的容易。
基本抽象
提供了兩種基本抽象:Dataset和Iterator
Dataset
表示元素序列集合,每個元素包含一個或者多個Tensor對象,每個元素是一個樣本。有兩種方式可以創建Dataset。
- 從源數據創建,比如:Dataset.from_tensor_slices()
- 通過數據處理轉換創建,比如 Dataset.map()/batch()
Iterator
用於從Dataset獲取數據給訓練作為輸入,作為輸入管道和訓練代碼的接口。最簡單的例子是“one-shot iterator”,和一個dataset綁定並返回它的元素一次。通過調整初始化參數,可以實現一些復雜場景的需求,比如循環迭代訓練集的元素。
基本機制
基本使用的流程如下:
1. 創建一個dataset
2. 做一些數據處理(map, batch)
3. 創建一個iterator提供給訓練使用
數據結構
每個元素都有相同的數據結構,每個元素包含一個或者多個tensor,可以元組表示或者嵌套表示。每個tensor包含類型tf.DType和維度形狀tf.TensorShape。每個tensor可以有名稱,通過collections.namedtuple或者字典實現。
Dataset的數據處理接口能夠支持任意數據結構。
創建Iterator
有了dataset以后,下一步就是創建一個iterator提供訪問數據元素的接口,現在tf.data api支持4中iterator來實現不同程度的復雜場景。
- one-shot,
- initializable,
- reinitializable
- feedable.
one-shot
最簡單,最常用,就是一次遍歷所有元素,目前(2019/08)也是estimator唯一支持的iterator。
initializable
需要顯示運行初始化操作,可以通過tf.placeholder參數化dataset。
reinitializalbe
可以從多個dataset多次初始化,當然需要相同數據結構。每次切換數據集,使用前需要初始化,
feedable
可以從多個dataset多次初始化,初始化完畢后,可以通過tf.placeholder隨時切換數據集。
獲取iterator的值
通過session.run(iterator.get_next())完成,若沒有數據了則報異常:tf.errors.OutOfRangeError。
iterator.get_next()返回的是tf.Tensor對象,需要在session中執行才會有值。
保存iterator狀態
save and restore the current state of the iterator (and, effectively, the whole input pipeline). 能將整個輸入pipeline保存並恢復(原理是什么?)
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
讀取輸入數據
直接讀取numpy的數據得到array,調用Dataset.from_tensor_slices
讀取TFRecord data: dataset = tf.data.TFRecordDataset(filenames)
通過文件的方式,可以支持不能夠全部導入內存中的場景。tf.data.TFRcordDataset將讀取數據流作為整個輸入流的一部分。
同時,文件名可以通過tf.placeholder傳入。
這里所有的dataset的返回,也都是tensorflow中的operation,需要在session中計算得到值。
解析tf.Example
推薦情形下,輸入需要從TFRecord文件格式讀取TF Example的protocol buffer messages,每個TF Example包含一個或者多個features,輸入管道需要將其轉換為tensor。
典型列子
# Transforms a scalar string `example_proto` into a pair of a scalar string and# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""), "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features) return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
對於圖像,可以通過tf.image.decode_jpeg等方式提取和resize。
任意的python邏輯處理數據情形
在文本當中,比如有時候需要調用其他python庫,比如jieba,這時候輸入處理需要通過t'f.py_func()opeation完成。
Batching dataset elements
簡單batch
調用Dataset.batch()實現,所有元素包含同樣的數據結構。
帶Padding的batch
當所有元素包含不同長度的數據結構時,可以通過padding使得單個batch數據長度一致。序列模型情形下。制定padding的shape為可變長度,比如(None,)來完成,會自動padding到單個batch的最大長度。當然,可以重寫padding的值或者padding的長度。
訓練流程
多批次遍歷
tf.data提供了兩種遍歷數據的方式
想要遍歷數據多次,可以使用Dataset.repeat(n)操作完成。遍歷n次對應多少次epochs。若不提供參數,則將無限循環。同時,將不會給出一個批次結束的信號。
若想在每個批次結束后做處理,則需要在外圍加循環,然后通過重復初始化迭代器完成,捕捉tf.errors.OutOfRangeError異常。
隨機shuffle
dataset.shuffle()維持一個緩存區,隨機取下一個數據。
高階API
tf.train.MonitoredTrainingSession 接口簡化了分布式Tensorflow的很多方面。MonitoredTrainingSession 用tf.error.OutOfRangeError來獲取訓練是否結束,所以推薦采用one-shot 迭代器。
用dataset作為estimator的輸入函數時,直接將dataset返回,estimator會自動創建迭代器並初始化。