1. Introduction
To split data into batches, you can use the methods in tf.train or in the tf.data.Dataset class.
1.1 tf.train
tf.train.slice_input_producer(tensor_list, shuffle=True, seed=None, capacity=32)
tf.train.batch(tensors, batch_size, num_threads=1, capacity=32, allow_smaller_final_batch=False)
Parameter notes:
shuffle: when True, the input data is shuffled
allow_smaller_final_batch: when True, a final batch smaller than batch_size is still returned
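Both calls only build graph nodes; at run time the input queue must be driven by queue-runner threads. The following minimal sketch (the sample arrays and values are placeholders, not from the original) shows the typical wiring; section 2.1 below walks through a full runnable example.

import tensorflow as tf
import numpy as np

# Placeholder data, for illustration only
images = np.random.random([15, 5, 5, 3])
labels = np.asarray(range(15))

# Build the input pipeline: slice the tensors into a queue, then batch them
input_queue = tf.train.slice_input_producer([images, labels], shuffle=True)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=6)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The queue is only filled once the queue-runner threads are started
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    print(sess.run(label_batch))   # one batch of 6 labels
    coord.request_stop()
    coord.join(threads)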
-------------------------------------------------------------------------------------------------------------------------
1.2 tf.data.Dataset
tf.data.Dataset is a class; the following methods are commonly used:
from_tensor_slices(tensors)
batch(batch_size, drop_remainder=False)
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)
repeat(count=None)
make_one_shot_iterator() / get_next()
Note: make_one_shot_iterator() / get_next() provide the iterator over the Dataset.
Parameter notes:
tensors: can be a list, dictionary, tuple, etc.
drop_remainder: when False, a final batch smaller than batch_size is kept; when True, it is dropped
buffer_size: size of the buffer used for shuffling
count: number of epochs; None means the data sequence repeats indefinitely
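As a quick illustration of drop_remainder (the values below are chosen for demonstration and are not part of the original examples), 15 samples batched in groups of 6 either keep or discard the final batch of 3:

import tensorflow as tf
import numpy as np

data = np.arange(15)
kept = tf.data.Dataset.from_tensor_slices(data).batch(6)                        # default drop_remainder=False
dropped = tf.data.Dataset.from_tensor_slices(data).batch(6, drop_remainder=True)

with tf.Session() as sess:
    for name, ds in [('drop_remainder=False', kept), ('drop_remainder=True', dropped)]:
        next_batch = ds.make_one_shot_iterator().get_next()
        print(name)
        try:
            while True:
                print(sess.run(next_batch))   # False keeps [12 13 14]; True stops after two full batches
        except tf.errors.OutOfRangeError:
            pass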
2. Examples
2.1 Using tf.train.slice_input_producer and tf.train.batch
import tensorflow as tf
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels

# Print the sample shapes
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Define the number of epochs, the batch size, the total number of samples,
# and the number of iterations needed to traverse all the data once
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size)

# Use tf.train.slice_input_producer to put all the data into a queue,
# and tf.train.batch to split the queued data into batches
input_queue = tf.train.slice_input_producer([images, labels], shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=32)
print('image_batch.shape={0}, label_batch.shape={1}'.format(image_batch.shape, label_batch.shape))


with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue-runner threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    # Print the batches
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run([image_batch, label_batch])
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1, iteration+1, cu_label_batch))
    # Stop and join the threads
    coord.request_stop()
    coord.join(threads)


# The output is as follows
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14 0 1 2]
The 2 epoch, the 1 iteration, current batch is [3 4 5 6 7 8]
The 2 epoch, the 2 iteration, current batch is [ 9 10 11 12 13 14]
The 2 epoch, the 3 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 1 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 2 iteration, current batch is [12 13 14 0 1 2]
The 3 epoch, the 3 iteration, current batch is [3 4 5 6 7 8]
If tf.train.slice_input_producer(shuffle=True) is used, the batches come out in shuffled order; the result is:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [ 2 5 8 11 3 10]
The 1 epoch, the 2 iteration, current batch is [ 9 12 7 1 14 13]
The 1 epoch, the 3 iteration, current batch is [0 6 4 2 3 6]
The 2 epoch, the 1 iteration, current batch is [11 10 12 14 13 5]
The 2 epoch, the 2 iteration, current batch is [8 1 0 9 4 7]
The 2 epoch, the 3 iteration, current batch is [10 13 1 4 12 3]
The 3 epoch, the 1 iteration, current batch is [ 2 8 5 9 14 7]
The 3 epoch, the 2 iteration, current batch is [ 0 11 6 1 14 9]
The 3 epoch, the 3 iteration, current batch is [11 6 12 7 0 13]
If tf.train.batch(allow_smaller_final_batch=True) is used, the final batch with fewer than batch_size elements is also returned; the result is:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
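In the loop above the number of epochs is controlled manually. A possible variant (not part of the original example; it reuses images, labels and batch_size from the code above) is to bound the queue itself with the num_epochs argument of tf.train.slice_input_producer, which raises tf.errors.OutOfRangeError once the data is exhausted; note that num_epochs is tracked by a local variable, so local variables must be initialized as well:

input_queue = tf.train.slice_input_producer([images, labels], num_epochs=3, shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, allow_smaller_final_batch=True)

with tf.Session() as sess:
    # num_epochs is counted with a local variable, so initialize local variables too
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        while not coord.should_stop():
            _, cu_label_batch = sess.run([image_batch, label_batch])
            print(cu_label_batch)
    except tf.errors.OutOfRangeError:
        pass  # raised after the queue has delivered 3 full passes over the data
    finally:
        coord.request_stop()
        coord.join(threads)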
2.2 Using the tf.data.Dataset class
import tensorflow as tf
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels

# Print the sample shapes
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Define the number of epochs, the batch size, the total number of samples,
# and the number of iterations needed to traverse all the data once
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size)

# Use from_tensor_slices to build the dataset, then batch and repeat to split it
# into batches and let the data sequence repeat indefinitely
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(batch_size).repeat()

# Use make_one_shot_iterator and get_next to fetch the data
iterator = dataset.make_one_shot_iterator()
next_iterator = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run(next_iterator)
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1, iteration+1, cu_label_batch))


# The output is as follows:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
To use shuffle(), change the line dataset = dataset.batch(batch_size).repeat() to dataset = dataset.shuffle(100).batch(batch_size).repeat(); the result is:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [ 7 4 10 8 3 11]
The 1 epoch, the 2 iteration, current batch is [ 0 2 12 13 14 5]
The 1 epoch, the 3 iteration, current batch is [6 9 1]
The 2 epoch, the 1 iteration, current batch is [ 6 14 7 9 3 8]
The 2 epoch, the 2 iteration, current batch is [13 5 12 1 11 2]
The 2 epoch, the 3 iteration, current batch is [ 0 4 10]
The 3 epoch, the 1 iteration, current batch is [10 8 13 12 3 14]
The 3 epoch, the 2 iteration, current batch is [ 6 9 2 5 1 11]
The 3 epoch, the 3 iteration, current batch is [0 4 7]
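Instead of counting epochs and iterations by hand, repeat(count) can bound the dataset itself, and the iterator then raises tf.errors.OutOfRangeError when the data is exhausted. A minimal sketch of this pattern (it reuses images and labels from above and is not part of the original example):

dataset = (tf.data.Dataset.from_tensor_slices((images, labels))
           .shuffle(100)
           .batch(6)
           .repeat(3))          # count=3 -> exactly three epochs
next_iterator = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    try:
        while True:
            _, cu_label_batch = sess.run(next_iterator)
            print(cu_label_batch)
    except tf.errors.OutOfRangeError:
        pass  # the iterator is exhausted after three passes over the data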