tf.train.slice_input_producer()


tf.train.slice_input_producer處理的是來源tensor的數據

轉載自:https://blog.csdn.net/dcrmg/article/details/79776876 里面有詳細參數解釋

官方說明

簡單使用

import tensorflow as tf
 
images = ['img1', 'img2', 'img3', 'img4', 'img5']
labels= [1,2,3,4,5]
 
epoch_num=8
 
f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=True)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(epoch_num):
        k = sess.run(f)
        print (i,k)
 
    coord.request_stop()
    coord.join(threads)

運行結果:

用tf.data.Dataset.from_tensor_slices調用,之前的會被拋棄,用法:https://blog.csdn.net/qq_32458499/article/details/78856530

結合批處理

import tensorflow as tf
import numpy as np
 
# 樣本個數
sample_num=5
# 設置迭代次數
epoch_num = 2
# 設置一個批次中包含樣本個數
batch_size = 3
# 計算每一輪epoch中含有的batch個數
batch_total = int(sample_num/batch_size)+1
 
# 生成4個數據和標簽
def generate_data(sample_num=sample_num):
    labels = np.asarray(range(0, sample_num))
    images = np.random.random([sample_num, 224, 224, 3])
    print('image size {},label size :{}'.format(images.shape, labels.shape))
 
    return images,labels
 
def get_batch_data(batch_size=batch_size):
    images, label = generate_data()
    # 數據類型轉換為tf.float32
    images = tf.cast(images, tf.float32)
    label = tf.cast(label, tf.int32)
 
    #從tensor列表中按順序或隨機抽取一個tensor
    input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
 
    image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=64)
    return image_batch, label_batch
 
image_batch, label_batch = get_batch_data(batch_size=batch_size)
 
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        for i in range(epoch_num):  # 每一輪迭代
            print ('************')
            for j in range(batch_total): #每一個batch
                print ('--------')
                # 獲取每一個batch中batch_size個樣本和標簽
                image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
                # for k in
                print(image_batch_v.shape, label_batch_v)
    except tf.errors.OutOfRangeError:
        print("done")
    finally:
        coord.request_stop()
    coord.join(threads)

運行結果:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM