dataset中shuffle()、repeat()、batch()用法


import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)

dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(2) # 將數據打亂,數值越大,混亂程度越大
dataset = dataset.batch(4) # 按照順序取出4行數據,最后一次輸出可能小於batch
dataset = dataset.repeat() # 數據集重復了指定次數
# repeat()在batch操作輸出完畢后再執行,若在之前,相當於先把整個數據集復制兩次
#為了配合輸出次數,一般默認repeat()空

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
for i in range(6):
value = sess.run(el)
print(value)

 

 

更多的不同和進階可以參考這個博客

 https://blog.csdn.net/qq_16234613/article/details/81703228


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM