TensorFlow Data模塊


模塊作用

tf.data api用於創建訓練前導入數據和數據處理的pipeline,使得處理大規模數據,不同數據格式和復雜數據處理變的容易。

基本抽象

提供了兩種基本抽象:DatasetIterator

Dataset

表示元素序列集合,每個元素包含一個或者多個Tensor對象,每個元素是一個樣本。有兩種方式可以創建Dataset。

  1. 從源數據創建,比如:Dataset.from_tensor_slices()
  2. 通過數據處理轉換創建,比如 Dataset.map()/batch()

Iterator

用於從Dataset獲取數據給訓練作為輸入,作為輸入管道和訓練代碼的接口。最簡單的例子是“one-shot iterator”,和一個dataset綁定並返回它的元素一次。通過調整初始化參數,可以實現一些復雜場景的需求,比如循環迭代訓練集的元素。

基本機制

基本使用的流程如下:

1. 創建一個dataset
2. 做一些數據處理(map, batch)
3. 創建一個iterator提供給訓練使用

數據結構

每個元素都有相同的數據結構,每個元素包含一個或者多個tensor,可以元組表示或者嵌套表示。每個tensor包含類型tf.DType和維度形狀tf.TensorShape。每個tensor可以有名稱,通過collections.namedtuple或者字典實現。

Dataset的數據處理接口能夠支持任意數據結構。

創建Iterator

有了dataset以后,下一步就是創建一個iterator提供訪問數據元素的接口,現在tf.data api支持4中iterator來實現不同程度的復雜場景。

  • one-shot,
  • initializable,
  • reinitializable
  • feedable.

one-shot

最簡單,最常用,就是一次遍歷所有元素,目前(2019/08)也是estimator唯一支持的iterator。

initializable

需要顯示運行初始化操作,可以通過tf.placeholder參數化dataset。

reinitializalbe

可以從多個dataset多次初始化,當然需要相同數據結構。每次切換數據集,使用前需要初始化,

feedable

可以從多個dataset多次初始化,初始化完畢后,可以通過tf.placeholder隨時切換數據集。

獲取iterator的值

通過session.run(iterator.get_next())完成,若沒有數據了則報異常:tf.errors.OutOfRangeError。

iterator.get_next()返回的是tf.Tensor對象,需要在session中執行才會有值。

保存iterator狀態

save and restore the current state of the iterator (and, effectively, the whole input pipeline). 能將整個輸入pipeline保存並恢復(原理是什么?

saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

讀取輸入數據

直接讀取numpy的數據得到array,調用Dataset.from_tensor_slices

讀取TFRecord data: dataset = tf.data.TFRecordDataset(filenames)

通過文件的方式,可以支持不能夠全部導入內存中的場景。tf.data.TFRcordDataset將讀取數據流作為整個輸入流的一部分。

同時,文件名可以通過tf.placeholder傳入。

這里所有的dataset的返回,也都是tensorflow中的operation,需要在session中計算得到值。

解析tf.Example

推薦情形下,輸入需要從TFRecord文件格式讀取TF Example的protocol buffer messages,每個TF Example包含一個或者多個features,輸入管道需要將其轉換為tensor。

典型列子

# Transforms a scalar string `example_proto` into a pair of a scalar string and# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto): 
    features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),              "label": tf.FixedLenFeature((), tf.int64, default_value=0)} 
    parsed_features = tf.parse_single_example(example_proto, features)  return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

對於圖像,可以通過tf.image.decode_jpeg等方式提取和resize。

任意的python邏輯處理數據情形

在文本當中,比如有時候需要調用其他python庫,比如jieba,這時候輸入處理需要通過t'f.py_func()opeation完成。

Batching dataset elements

簡單batch

調用Dataset.batch()實現,所有元素包含同樣的數據結構。

帶Padding的batch

當所有元素包含不同長度的數據結構時,可以通過padding使得單個batch數據長度一致。序列模型情形下。制定padding的shape為可變長度,比如(None,)來完成,會自動padding到單個batch的最大長度。當然,可以重寫padding的值或者padding的長度。

訓練流程

多批次遍歷

tf.data提供了兩種遍歷數據的方式

想要遍歷數據多次,可以使用Dataset.repeat(n)操作完成。遍歷n次對應多少次epochs。若不提供參數,則將無限循環。同時,將不會給出一個批次結束的信號。

若想在每個批次結束后做處理,則需要在外圍加循環,然后通過重復初始化迭代器完成,捕捉tf.errors.OutOfRangeError異常。

隨機shuffle

dataset.shuffle()維持一個緩存區,隨機取下一個數據。

高階API

tf.train.MonitoredTrainingSession 接口簡化了分布式Tensorflow的很多方面。MonitoredTrainingSession 用tf.error.OutOfRangeError來獲取訓練是否結束,所以推薦采用one-shot 迭代器。

用dataset作為estimator的輸入函數時,直接將dataset返回,estimator會自動創建迭代器並初始化。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM