tf.data 數據處理模塊


tf.data.Dataset API非常豐富,主要包括創建數據集、應用transform、數據迭代等。

一、Dataset類初覽

最簡單的方法是根據python列表來創建:

處理文件數據,利用tf.data.TextLineDataset:

對於TFRecord格式可以利用TFRecordDataset:

對於匹配所有文件格式的數據,可以利用tf.data.Dataset.list_files:

 

Transformations

有了數據可以利用map函數來transform數據:

Dataset支持哪些類型:

包括嵌套的元組、具名元組、字典等。元素可以為任何類型:

tf.Tensortf.data.Datasettf.SparseTensor,tf.RaggedTensor, 和 tf.TensorArray. 

從上面可以看到,Dataset有一個參數:variant_tensor, 具有一個表示元素類型的屬性:element_spec 

下面詳細介紹Dataset類方法 

 

二、Dataset類的方法(共26個)

1.  __iter__

顧名思義,返回該數據集的迭代器。並可以在eager模式下使用。

 

2. apply

apply(
    transformation_func
)

對數據應用transformation

 

 

3.  as_numpy_iterator(貌似2.0.0版本沒有該方法)

返回一個將數據元素轉換為numpy的迭代器,方便只查看元素。這個操作比直接打印print少了元素類型和類型:

這個方法需要在eager模式下才行, 只顯示數據本身:

as_numpy_iterator() 將保留數據元素的原始嵌套格式:

如果數據中含有非Tensor值報錯TypeError,若在非eager模式下用會報錯RuntimeError。 

 

 

4.  batch

batch(
    batch_size, drop_remainder=False
)

該方法將數據組成批量。

參數drop_remainder類似於pytorch中的drop_last:

 

 

5. cache

cache(
    filename=''
)

緩存數據,當前據迭代完成,元素會在特定地方實現緩存,后續迭代會利用緩存的數據。

當緩存到文件時,在整個運行過程緩存數據將保持,首次迭代 也將從緩存文件中讀取數據。如果在.cache()調用之前改變了數據源,將不會有任何影響。除非cache文件被移除或者文件名更換:

第二次雖然改變了源數據,仍打印出原始數據的內容。 如果調用該函數時沒有提供文件名,則數據將緩存到memory中。

 

6. concatenate

concatenate(
    dataset
)

通過連接給定的數據集得到新數據集,注意類型要一致。

 

 

7. enumerate

enumerate(
    start=0
)

按要求枚舉數據,和python的enumerate類似。

 

 

8. filter

filter(
    predicate
)

過濾數據集,輸入為函數(映射數據為布爾類型)

 

 

9. flat_map

flat_map(
    map_func
)

拉伸數據。如果要確保數據集的順序保持不變可以用該函數,例如將批量數據拉伸至元素級別:

 

 

10.  from_generator

@staticmethod
from_generator(
    generator, output_types, output_shapes=None, args=None
)

建立一個數據集,其中的元素由生成器generator產生。generator的參數必須是可callable的類,返回支持iter()的類。產生的元素必須與output_types一致,output_shapes參數可選。

 

 

11. from_tensor_slices

@staticmethod
from_tensor_slices(
    tensors
)

這個方法早在前面許多例子中用到了,從給定tensor切片中創建數據集。從第一維度進行slice,保留了輸入tensor的結構,移除每個tensor的第一維度並作為數據集的維度。所有的輸入tensor必須有相同的第一維度。

利用zip將不同dataset打包到一起:

輸出:

兩個tensor只要第一維一樣就可以結合到一個dataset中:

 

 

12. from_tensor

@staticmethod
from_tensors(
    tensors
)

與上面不同的是不含切片,只是將整個tensor作為一個dataset。例如:

和上一個方法的一個共同點:如果輸入tensors中包含numpy數組,並且eager模型未開啟,則將會被嵌入到graphs中作為一個或多個tf.constant.對於大型數據集(>1GB),這可能會浪費存儲。如果tensors中包含一個或多個大型numpy數組,可以考慮利用這里this guide.的操作。

 

13. interleave

interleave(
    map_func, cycle_length=-1, block_length=1, num_parallel_calls=None
)

將map_func映射到整個數據集。並分發結果。

 

 

14. list_files

@staticmethod
list_files(
    file_pattern, shuffle=None, seed=None
)

匹配一個或更多的glob模式,file_pattern參數應當小於glob patterns,否則可以用Dataset.from_tensor_slices(filenames) 就好。

 

 

15.  map

map(
    map_func, num_parallel_calls=None
)

這個函數也已經用了多次,將map_func 應用到整個數據集中。

 

16.  padded_batch

padded_batch(
    batch_size, padded_shapes, padding_values=None, drop_remainder=False
)

此轉換將輸入數據集的多個連續元素合並為一個元素。類似於tf.data.Dataset.batch,將會有一個新增的batch維度,不同的是此時輸入的元素可能shape不同,該轉換將會pad每個元素來得到應有的padding_shapes。這個參數決定了最后的輸出批量維度。如果維度是一個常數e.g. tf.compat.v1.Dimension(37),元素將會在該維度被pad到該長度,如果維度是未知的e.g. tf.compat.v1.Dimension(None),將會被pad到所有元素的最大長度。

 

 

17. prefetch

prefetch(
    buffer_size
)

從數據集中建立預讀取元素。大多數數據集輸入結構都應該以預讀取prefetch結束。這允許在處理當前元素時准備后面的元素。這通常會提高延遲和吞吐量,代價是使用額外的內存來存儲預取的元素。

和batch方法一起使用:

examples.prefetch(2) will prefetch two elements (2 examples), while examples.batch(20).prefetch(2) will prefetch 2 elements (2 batches, of 20 examples each). 

利用prefetch和num_parallel_calls 參數,模型訓練的時間可縮減至原來的一半甚至更低:

1     train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
2     train_dataset = train_dataset.map(
3         map_func=_decode_and_resize, 
4         num_parallel_calls=tf.data.experimental.AUTOTUNE)
5     # 取出前buffer_size個數據放入buffer,並從其中隨機采樣,采樣后的數據用后續數據替換
6     train_dataset = train_dataset.shuffle(buffer_size=23000)    
7     train_dataset = train_dataset.batch(batch_size)
8     train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
View Code

                    tf.data 的並行化策略性能測試(縱軸為每 epoch 訓練所需時間,單位:秒)

 

 

18. range

@staticmethod
range(
    *args
)

也已經用過多次了:建立一定范圍內的元素數據集

 

 

19. reduce

reduce(
    initial_state, reduce_func
)

將輸入元素整合成單一元素。該轉換將會已知在每個元素上調用reduce_func函數,直到遍歷數據集結束。initial_state參數用於初始狀態。

注意reduce_func參數需要兩個參數為 (old_state, input_element),這兩個茶樹會被映射到new_state,當然最開始的old_state就是initial_state,所以這些state的格式應當一致。最終返回的就是final state。這樣就好理解上圖中的例子了。

 

20. repeat

repeat(
    count=None
)

就是按照重復次數來重復輸入元素。

 

 

21. shard

shard(
    num_shards, index
)

創建一個僅包含1/num_shards原有數據集大小的數據集。 index實現開始索引。

在分布式訓練的時候很有用,因為者可以划分給每個設備一個子集。當讀取到一個單一的輸入文件時,可以這樣做:

重要注意事項:在使用任何隨機化操作符(如shuffle)之前,一定要切分。通常,最好在數據集管道的早期使用shard操作符。例如,從一組TFRecord文件中讀取時,在將數據集轉換為輸入樣本之前切分。這樣就避免了讀取每個工人的每個文件。下面是一個完整管道內高效分片策略的示例:

 

 

21. shuffle

shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None
)

隨即打散輸入數據。數據填buffer_size大小的元素到buffer中,然后在該buffer中進行隨機采樣。對於完美的打散計划,buffer尺寸應大於等於所需的數據集尺寸。例如你的數據集有10000個元素,但是buffer_size設置為1000,然后僅會從這1000個元素中進行隨機選擇。一旦某個元素被選定,其位置就會被下個(額外的)元素取代從而保持buffer大小為1000。參數reshuffle_each_iteration 控制是否不同epoch保持相同的shuffle順序。在TF1.X版本中,慣用的方法是通過repeat轉換:

在TF2.0版本中,tf.data.Dataset是python可迭代的,所以通過python迭代也可以創建批量:

 

22. skip

 

skip(
    count
)

創建一個數據集:跳過count參數之前的元素:

如果count參數大於當前數據集的大小,新的數據集將不包含任何數據。如果將其設為-1,則包含整個數據。

 

23.  take

take(
    count
)

創建一個數據集:最多包含count數目大小的數據集:

如果count=-1 或者count大於整個數據集尺寸,新的數據集將包含整個數據集。 

 

24. unbatch

將數據集划分到多個元素。就是batch的反向操作,最后結果是分解掉了batch的維度:

 

 

25. window

window(
    size, shift=None, stride=1, drop_remainder=False
)

結合輸入元素到windows,windows指的是一個有限的數據集,尺寸為size或更小:如果沒有足夠輸入元素來填充這個window或者drop_remainder參數為False。stride參數決定了輸入元素的步長,shift參數決定window的偏移。后三個參數都是可選。size表示形成window所需要結合的數據元素數目(窗口大小)。shift表示每次迭代的滑動數目。stride表示每個窗口中元素步長。最后一個參數表示是否丟棄當前窗口,如果其尺寸小於指定的size。

 

 

26. zip

@staticmethod
zip(
    datasets
)

打包多個數據集,用到多次了。和python基本一樣,差別在於datasets參數可以實任意嵌套的Dataset類。

 

整理編輯自:https://www.tensorflow.org/api_docs/python/tf/data/Dataset


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM