TensorFlow dataset.shuffle、batch、repeat的使用詳解


https://www.jb51.net/article/178976.htm

直接看代碼例子,有詳細注釋!!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
import numpy as np
 
 
d = np.arange( 0 , 60 ).reshape([ 6 , 10 ])
 
# 將array轉化為tensor
data = tf.data.Dataset.from_tensor_slices(d)
 
# 從data數據集中按順序抽取buffer_size個樣本放在buffer中,然后打亂buffer中的樣本
# buffer中樣本個數不足buffer_size,繼續從data數據集中安順序填充至buffer_size,
# 此時會再次打亂
data = data.shuffle(buffer_size = 3 )
 
# 每次從buffer中抽取4個樣本
data = data.batch( 4 )
 
# 將data數據集重復,其實就是2個epoch數據集
data = data.repeat( 2 )
 
# 構造獲取數據的迭代器
iters = data.make_one_shot_iterator()
 
# 每次從迭代器中獲取一批數據
batch = iters.get_next()
 
sess = tf.Session()
 
sess.run(batch)
# 數據集完成遍歷完之后,繼續抽取的話會報錯:OutOfRangeError
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
In [ 21 ]: d
Out[ 21 ]:
array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ],
   [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ],
   [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ],
   [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ],
   [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ],
   [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]])
In [ 22 ]: sess.run(batch)
Out[ 22 ]:
array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ],
   [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ],
   [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ],
   [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ]])
 
In [ 23 ]: sess.run(batch)
Out[ 23 ]:
array([[ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ],
   [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]])

從輸出結果可以看出:

shuffle是按順序將數據放入buffer里面的;

當repeat函數在shuffle之后的話,是將一個epoch的數據集抽取完畢,再進行下一個epoch的。

那么,當repeat函數在shuffle之前會怎么樣呢?如下:

1
2
3
4
5
data = data.repeat( 2 )
 
data = data.shuffle(buffer_size = 3 )
 
data = data.batch( 4 )
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
In [ 25 ]: sess.run(batch)
Out[ 25 ]:
array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ],
   [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ],
   [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ],
   [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]])
 
In [ 26 ]: sess.run(batch)
Out[ 26 ]:
array([[ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ],
   [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ],
   [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ],
   [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ]])
 
In [ 27 ]: sess.run(batch)
Out[ 27 ]:
array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ],
   [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ],
   [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ],
   [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]])

可以看出,其實它就是先將數據集復制一遍,然后把兩個epoch當成同一個新的數據集,一直shuffle和batch下去。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM