https://www.jb51.net/article/178976.htm
直接看代碼例子,有詳細注釋!!
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
import
tensorflow as tf
import
numpy as np
d
=
np.arange(
0
,
60
).reshape([
6
,
10
])
# 將array轉化為tensor
data
=
tf.data.Dataset.from_tensor_slices(d)
# 從data數據集中按順序抽取buffer_size個樣本放在buffer中,然后打亂buffer中的樣本
# buffer中樣本個數不足buffer_size,繼續從data數據集中安順序填充至buffer_size,
# 此時會再次打亂
data
=
data.shuffle(buffer_size
=
3
)
# 每次從buffer中抽取4個樣本
data
=
data.batch(
4
)
# 將data數據集重復,其實就是2個epoch數據集
data
=
data.repeat(
2
)
# 構造獲取數據的迭代器
iters
=
data.make_one_shot_iterator()
# 每次從迭代器中獲取一批數據
batch
=
iters.get_next()
sess
=
tf.Session()
sess.run(batch)
# 數據集完成遍歷完之后,繼續抽取的話會報錯:OutOfRangeError
|
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
In [
21
]: d
Out[
21
]:
array([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
],
[
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
],
[
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
],
[
50
,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
]])
In [
22
]: sess.run(batch)
Out[
22
]:
array([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
],
[
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
]])
In [
23
]: sess.run(batch)
Out[
23
]:
array([[
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
],
[
50
,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
]])
|
從輸出結果可以看出:
shuffle是按順序將數據放入buffer里面的;
當repeat函數在shuffle之后的話,是將一個epoch的數據集抽取完畢,再進行下一個epoch的。
那么,當repeat函數在shuffle之前會怎么樣呢?如下:
|
1
2
3
4
5
|
data
=
data.repeat(
2
)
data
=
data.shuffle(buffer_size
=
3
)
data
=
data.batch(
4
)
|
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
In [
25
]: sess.run(batch)
Out[
25
]:
array([[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
],
[
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
]])
In [
26
]: sess.run(batch)
Out[
26
]:
array([[
50
,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
],
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
],
[
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
]])
In [
27
]: sess.run(batch)
Out[
27
]:
array([[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
],
[
50
,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
],
[
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
],
[
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
]])
|
可以看出,其實它就是先將數據集復制一遍,然后把兩個epoch當成同一個新的數據集,一直shuffle和batch下去。
