Tensorflow datasets.shuffle repeat batch方法


機器學習中數據讀取是很重要的一個環節,TensorFlow也提供了很多實用的方法,為了避免以后時間久了又忘記,所以寫下筆記以備日后查看。

最普通的正常情況

首先我們看看最普通的情況:

# 創建0-10的數據集,每個batch取個數。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

輸出結果

[0 1 2 3 4 5]
[6 7 8 9]

由結果我們可以知道TensorFlow能很好地幫我們自動處理最后一個batch的數據。

datasets.batch(batch_size)與迭代次數的關系

但是如果上面for循環次數超過2會怎么樣呢?也就是說如果 循環次數*批數量 > 數據集數量 會怎么樣?我們試試看:

dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    >>==for i in range(3):==<<
        value = sess.run(next_element)
        print(value)

輸出結果

[0 1 2 3 4 5]
[6 7 8 9]
---------------------------------------------------------------------------
OutOfRangeError                           Traceback (most recent call last)
D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1277     try:
   
  ...
  ...省略若干信息...
  ...
  
OutOfRangeError (see above for traceback): End of sequence
	 [[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]

可以知道超過范圍了,所以報錯了。

datasets.repeat()

為了解決上述問題,repeat方法登場。還是直接看例子吧:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

輸出結果

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

可以知道repeat其實就是將數據集重復了指定次數,上面代碼將數據集重復了2次,所以這次即使for循環次數是4也依舊能正常讀取數據,並且都能完整把數據讀取出來。同理,如果把for循環次數設置為大於4,那么也還是會報錯,這么一來,我每次還得算repeat的次數,豈不是很心累?所以更簡便的辦法就是對repeat方法不設置重復次數,效果見如下:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(6):
        value = sess.run(next_element)
        print(value)

輸出結果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

此時無論for循環多少次都不怕啦~~

datasets.shuffle(buffer_size)

仔細看可以知道上面所有輸出結果都是有序的,這在機器學習中用來訓練模型是浪費資源且沒有意義的,所以我們需要將數據打亂,這樣每批次訓練的時候所用到的數據集是不一樣的,這樣啊可以提高模型訓練效果。

另外shuffle前需要設置buffer_size:

  • 不設置會報錯,
  • buffer_size=1:不打亂順序,既保持原序
  • buffer_size越大,打亂程度越大,演示效果見如下代碼:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

輸出結果:

[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]

注意:shuffle的順序很重要,一般建議是最開始執行shuffle操作,因為如果是先執行batch操作的話,那么此時就只是對batch進行shuffle,而batch里面的數據順序依舊是有序的,那么隨機程度會減弱。不信你看:

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

輸出結果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]



MARSGGBO原創





2018-8-5




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM