tensorflow读取数据集生成batch——tf.data.Dataset.from_tensor_slices


import tensorflow as tf
x1 = tf.constant([[1.0, 2., 3.], [4., 5., 6.],[7., 8.,9.], [10., 11.,12.]])
y1 = tf.constant([[0.5, 1.5, 2.5], [3.5, 4.5, 5.5],[6.5, 7.5, 8.5], [9.5, 10.5, 11.5]])
# 创建dataset
dataset = tf.data.Dataset.from_tensor_slices((x1, y1))
dataset = dataset.shuffle(100).batch(3).repeat()
# iterator = dataset.make_one_shot_iterator()#对应不需要初始化,不能更改数据源
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(4):
        value = sess.run(next_element)
        sess.run(next_element)
        print(value)

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM