機器學習中數據讀取是很重要的一個環節,TensorFlow也提供了很多實用的方法,為了避免以后時間久了又忘記,所以寫下筆記以備日后查看。
最普通的正常情況
首先我們看看最普通的情況:
# 創建0-10的數據集,每個batch取個數。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
輸出結果
[0 1 2 3 4 5]
[6 7 8 9]
由結果我們可以知道TensorFlow能很好地幫我們自動處理最后一個batch的數據。
datasets.batch(batch_size)與迭代次數的關系
但是如果上面for循環次數超過2會怎么樣呢?也就是說如果 循環次數*批數量 > 數據集數量 會怎么樣?我們試試看:
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
>>==for i in range(3):==<<
value = sess.run(next_element)
print(value)
輸出結果
[0 1 2 3 4 5]
[6 7 8 9]
---------------------------------------------------------------------------
OutOfRangeError Traceback (most recent call last)
D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1277 try:
...
...省略若干信息...
...
OutOfRangeError (see above for traceback): End of sequence
[[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]
可以知道超過范圍了,所以報錯了。
datasets.repeat()
為了解決上述問題,repeat方法登場。還是直接看例子吧:
dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
輸出結果
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
可以知道repeat其實就是將數據集重復了指定次數,上面代碼將數據集重復了2次,所以這次即使for循環次數是4也依舊能正常讀取數據,並且都能完整把數據讀取出來。同理,如果把for循環次數設置為大於4,那么也還是會報錯,這么一來,我每次還得算repeat的次數,豈不是很心累?所以更簡便的辦法就是對repeat方法不設置重復次數,效果見如下:
dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(6):
value = sess.run(next_element)
print(value)
輸出結果:
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
此時無論for循環多少次都不怕啦~~
datasets.shuffle(buffer_size)
仔細看可以知道上面所有輸出結果都是有序的,這在機器學習中用來訓練模型是浪費資源且沒有意義的,所以我們需要將數據打亂,這樣每批次訓練的時候所用到的數據集是不一樣的,這樣啊可以提高模型訓練效果。
另外shuffle前需要設置buffer_size:
- 不設置會報錯,
- buffer_size=1:不打亂順序,既保持原序
- buffer_size越大,打亂程度越大,演示效果見如下代碼:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
輸出結果:
[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]
注意:shuffle的順序很重要,一般建議是最開始執行shuffle操作,因為如果是先執行batch操作的話,那么此時就只是對batch進行shuffle,而batch里面的數據順序依舊是有序的,那么隨機程度會減弱。不信你看:
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
輸出結果:
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
