tensorflow讀取數據集生成batch——tf.data.Dataset.from_tensor_slices


import tensorflow as tf
x1 = tf.constant([[1.0, 2., 3.], [4., 5., 6.],[7., 8.,9.], [10., 11.,12.]])
y1 = tf.constant([[0.5, 1.5, 2.5], [3.5, 4.5, 5.5],[6.5, 7.5, 8.5], [9.5, 10.5, 11.5]])
# 創建dataset
dataset = tf.data.Dataset.from_tensor_slices((x1, y1))
dataset = dataset.shuffle(100).batch(3).repeat()
# iterator = dataset.make_one_shot_iterator()#對應不需要初始化,不能更改數據源
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(4):
        value = sess.run(next_element)
        sess.run(next_element)
        print(value)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM