Tensorflow中的數據對象Dataset



###基礎概念

在tensorflow的官方文檔是這樣介紹Dataset數據對象的:

Dataset可以用來表示輸入管道元素集合(張量的嵌套結構)和“邏輯計划“對這些元素的轉換操作。在Dataset中元素可以是向量,元組或字典等形式。
另外,Dataset需要配合另外一個類Iterator進行使用,Iterator對象是一個迭代器,可以對Dataset中的元素進行迭代提取。

看個簡單的示例:

#創建一個Dataset對象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#創建一個迭代器
iterator = dataset.make_one_shot_iterator()

#get_next()函數可以幫助我們從迭代器中獲取元素
element = iterator.get_next()

#遍歷迭代器,獲取所有元素
with tf.Session() as sess:
   for i in range(9):
       print(sess.run(element))

以上打印結果為:1 2 3 4 5 6 7 8 9


###Dataset方法
####1.from_tensor_slices

from_tensor_slices用於創建dataset,其元素是給定張量的切片的元素。

函數形式:from_tensor_slices(tensors)

參數tensors:張量的嵌套結構,每個都在第0維中具有相同的大小。

具體例子

#創建切片形式的dataset
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#創建一個迭代器
iterator = dataset.make_one_shot_iterator()

#get_next()函數可以幫助我們從迭代器中獲取元素
element = iterator.get_next()

#遍歷迭代器,獲取所有元素
with tf.Session() as sess:
   for i in range(3):
       print(sess.run(element))

以上代碼運行結果:1 2 3


####2.from_tensors

創建一個Dataset包含給定張量的單個元素。

函數形式:from_tensors(tensors)

參數tensors:張量的嵌套結構。

具體例子

dataset = tf.data.Dataset.from_tensors([1,2,3,4,5,6,7,8,9])

iterator = concat_dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(1):
       print(sess.run(element))

以上代碼運行結果:[1,2,3,4,5,6,7,8,9]
即from_tensors是將tensors作為一個整體進行操縱,而from_tensor_slices可以操縱tensors里面的元素。


####3.from_generator

創建Dataset由其生成元素的元素generator。

函數形式:from_generator(generator,output_types,output_shapes=None,args=None)

參數generator:一個可調用對象,它返回支持該iter()協議的對象 。如果args未指定,generator則不得參數; 否則它必須采取與有值一樣多的參數args。
參數output_types:tf.DType對應於由元素生成的元素的每個組件的對象的嵌套結構generator。
參數output_shapes:tf.TensorShape 對應於由元素生成的元素的每個組件的對象 的嵌套結構generator
參數args:tf.Tensor將被計算並將generator作為NumPy數組參數傳遞的對象元組。

具體例子

#定義一個生成器
def data_generator():
    dataset = np.array(range(9))
    for i in dataset:
        yield i

#接收生成器,並生產dataset數據結構
dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32))

iterator = concat_dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(3):
       print(sess.run(element))

以上代碼運行結果:0 1 2


####4.batch

batch可以將數據集的連續元素合成批次。

函數形式:batch(batch_size,drop_remainder=False)

參數batch_size:表示要在單個批次中合並的此數據集的連續元素個數。
參數drop_remainder:表示在少於batch_size元素的情況下是否應刪除最后一批 ; 默認是不刪除。

具體例子:

#創建一個Dataset對象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

'''合成批次'''
dataset=dataset.batch(3)

#創建一個迭代器
iterator = dataset.make_one_shot_iterator()

#get_next()函數可以幫助我們從迭代器中獲取元素
element = iterator.get_next()

#遍歷迭代器,獲取所有元素
with tf.Session() as sess:
   for i in range(9):
       print(sess.run(element))

以上代碼運行結果為:
[1 2 3]
[4 5 6]
[7 8 9]

即把目標對象合成3個批次,返回的對象是傳入Dataset對象。


####5.concatenate

concatenate可以將兩個Dataset對象進行合並或連接.

函數形式:concatenate(dataset)

參數dataset:表示需要傳入的dataset對象。

具體例子:

#創建dataset對象
dataset_a=tf.data.Dataset.from_tensor_slices([1,2,3])
dataset_b=tf.data.Dataset.from_tensor_slices([4,5,6])

#合並dataset
concat_dataset=dataset_a.concatenate(dataset_b)

iterator = concat_dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(6):
       print(sess.run(element))

以上代碼運行結果:1 2 3 4 5 6


####6.filter

filter可以對傳入的dataset數據進行條件過濾.

函數形式:filter(predicate)

參數predicate:條件過濾函數

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#對dataset內的數據進行條件過濾
dataset=dataset.filter(lambda x:x>3)

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
    for i in range(6):
       print(sess.run(element))

以上代碼運行結果:4 5 6 7 8 9


####7.map

map可以將map_func函數映射到數據集

函數形式:flat_map(map_func,num_parallel_calls=None)

參數map_func:映射函數
參數num_parallel_calls:表示要並行處理的數字元素。如果未指定,將按順序處理元素。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#進行map操作
dataset=dataset.map(lambda x:x+1)

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(6):
       print(sess.run(element))

以上代碼運行結果:2 3 4 5 6 7


####8.flat_map

flat_map可以將map_func函數映射到數據集(與map不同的是flat_map傳入的數據必須是一個dataset)。

函數形式:flat_map(map_func)

參數map_func:映射函數

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#進行flat_map操作
dataset=dataset.flat_map(lambda x:tf.data.Dataset.from_tensor_slices(x+[1]))

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(6):
       print(sess.run(element))

以上代碼運行結果:2 3 4 5 6 7


####9.make_one_shot_iterator

創建Iterator用於枚舉此數據集的元素。(可自動初始化)

函數形式:make_one_shot_iterator()

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(6):
       print(sess.run(element))

####10.make_initializable_iterator

創建Iterator用於枚舉此數據集的元素。(使用此函數前需先進行迭代器的初始化操作)

函數形式:make_initializable_iterator(shared_name=None)

參數shared_name:(可選)如果非空,則返回的迭代器將在給定名稱下共享同一設備的多個會話(例如,使用遠程服務器時)

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

iterator = dataset.make_initializable_iterator()

element = iterator.get_next()

with tf.Session() as sess:

   #對迭代器進行初始化操作
   sess.run(iterator.initializer)

   for i in range(5):
       print(sess.run(element))

####11.padded_batch

將數據集的連續元素組合到填充批次中,此轉換將輸入數據集的多個連續元素組合為單個元素。

函數形式:padded_batch(batch_size,padded_shapes,padding_values=None,drop_remainder=False)

參數batch_size:表示要在單個批次中合並的此數據集的連續元素數。
參數padded_shapes:嵌套結構tf.TensorShape或 tf.int64類似矢量張量的對象,表示在批處理之前應填充每個輸入元素的相應組件的形狀。任何未知的尺寸(例如,tf.Dimension(None)在一個tf.TensorShape或-1類似張量的物體中)將被填充到每個批次中該尺寸的最大尺寸。
參數padding_values:(可選)標量形狀的嵌套結構 tf.Tensor,表示用於各個組件的填充值。默認值0用於數字類型,空字符串用於字符串類型。
參數drop_remainder:(可選)一個tf.bool標量tf.Tensor,表示在少於batch_size元素的情況下是否應刪除最后一批 ; 默認行為是不刪除較小的批處理。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

dataset=dataset.padded_batch(2,padded_shapes=[])

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(6):
       print(sess.run(element))

以上代碼運行結果:
[1 2]
[3 4]


####12.repeat

重復此數據集count次數

函數形式:repeat(count=None)

參數count:(可選)表示數據集應重復的次數。默認行為(如果count是None或-1)是無限期重復的數據集。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#無限次重復dataset數據集
dataset=dataset.repeat()

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果: 1 2 3 4 5


####13.shard

將Dataset分割成num_shards個子數據集。這個函數在分布式訓練中非常有用,它允許每個設備讀取唯一子集。

函數形式:shard( num_shards,index)

參數num_shards:表示並行運行的分片數。
參數index:表示工人索引。


####14.shuffle

隨機混洗數據集的元素。

函數形式:shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

參數buffer_size:表示新數據集將從中采樣的數據集中的元素數。
參數seed:(可選)表示將用於創建分布的隨機種子。
參數reshuffle_each_iteration:(可選)一個布爾值,如果為true,則表示每次迭代時都應對數據集進行偽隨機重組。(默認為True。)

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#隨機混洗數據
dataset=dataset.shuffle(3)

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果:3 2 4


####15.skip

生成一個跳過count元素的數據集。

函數形式:skip(count)

參數count:表示應跳過以形成新數據集的此數據集的元素數。如果count大於此數據集的大小,則新數據集將不包含任何元素。如果count 為-1,則跳過整個數據集。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9])

#跳過前5個元素
dataset=dataset.skip(5)

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果: 6 7 8


####16.take

提取前count個元素形成性數據集

函數形式:take(count)

參數count:表示應該用於形成新數據集的此數據集的元素數。如果count為-1,或者count大於此數據集的大小,則新數據集將包含此數據集的所有元素。

具體例子

dataset = tf.data.Dataset.from_tensor_slices([1,2,2,3,4,5,6,7,8,9])

#提取前5個元素形成新數據
dataset=dataset.take(5)

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果: 1 2 2


####17.zip

將給定數據集壓縮在一起

函數形式:zip(datasets)

參數datesets:數據集的嵌套結構。

具體例子

dataset_a=tf.data.Dataset.from_tensor_slices([1,2,3])

dataset_b=tf.data.Dataset.from_tensor_slices([2,6,8])

zip_dataset=tf.data.Dataset.zip((dataset_a,dataset_b))

iterator = dataset.make_one_shot_iterator()

element = iterator.get_next()

with tf.Session() as sess:
   for i in range(30,35):
       print(sess.run(element))

以上代碼運行結果:
(1, 2)
(2, 6)
(3, 8)

到這里Dataset中大部分方法 都在這里做了初步的解釋,當然這些方法的配合使用才能夠在建模過程中發揮大作用。


更多信息可查看tensorflow官方文檔 https://www.tensorflow.org/api_docs/python/tf/data/Dataset


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM