tf.train.batch的偶爾亂序問題
tf.train.batch的偶爾亂序問題
- 我們在通過tf.Reader讀取文件后,都需要用batch函數將讀取的數據根據預先設定的batch_size打包為一個個獨立的batch方便我們進行學習。
- 常用的batch函數有tf.train.batch和tf.train.shuffle_batch函數。前者是將數據從前往后讀取並順序打包,后者則要進行亂序處理————即將讀取的數據進行亂序后在組成批次。
- 訓練時我往往都是使用shuffle_batch函數,但是這次我在驗證集上預調好模型並freeze模型后我需要在測試集上進行測試。此時我需要將數據的標簽和inference后的結果進行一一對應。 此時數據出現的順序是十分重要的,這保證我們的產品在上線前的測試集中能准確get到每個數據和inference后結果的差距 而在驗證集中我們不太關心數據原有的標簽和inference后的真實值,我們往往只是需要讓這兩個數據一一對應,關於數據出現的順序我們並不關心。
- 此時我們一般使用tf.train.batch函數將tf.Reader讀取的值進行順序打包即可。
然而tf.train.batch函數往往會有偶爾亂序的情況
- 我們將csv文件中每個數據樣本從上往下依次進行標號,我們在使用tf.trian.batch函數依次進行讀取,如果我們讀取的數據編號亂序了,則表明tf.train.batch函數有偶爾亂序的狀況。
源程序文件下載
test_tf_train_batch.csv
import tensorflow as tf
BATCH_SIZE = 400
NUM_THREADS = 2
MAX_NUM = 500
def read_data(file_queue):
reader = tf.TextLineReader(skip_header_lines=1)
key, value = reader.read(file_queue)
defaults = [[0], [0.], [0.]]
NUM, C, Tensile = tf.decode_csv(value, defaults)
vertor_example = tf.stack([C])
vertor_label = tf.stack([Tensile])
vertor_num = tf.stack([NUM])
return vertor_example, vertor_label, vertor_num
def create_pipeline(filename, batch_size, num_threads):
file_queue = tf.train.string_input_producer([filename]) # 設置文件名隊列
example, label, no = read_data(file_queue) # 讀取數據和標簽
example_batch, label_batch, no_batch = tf.train.batch(
[example, label, no], batch_size=batch_size, num_threads=num_threads, capacity=MAX_NUM)
return example_batch, label_batch, no_batch
x_train_batch, y_train_batch, no_train_batch = create_pipeline('test_tf_train_batch.csv', batch_size=BATCH_SIZE,
num_threads=NUM_THREADS)
init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
with tf.Session() as sess:
sess.run(local_init_op)
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
example, label, num = sess.run([x_train_batch, y_train_batch, no_train_batch])
print(example)
print(label)
print(num)
coord.request_stop()
coord.join(threads)
實驗結果
我們將csv文件中的真實Tensile值放在第一列,將使用tf.train.batch函數得到的Tensile和no分別放在第二列和第三列
| TureTensile | FalseTensile | NO |
|---|---|---|
| 0.830357143 | [ 0.52678573] | [ 66] |
| 0.526785714 | [ 0.83035713] | [ 65] |
| 0.553571429 | [ 0.4375 ] | [ 68] |
| 0.4375 | [ 0.5535714 ] | [ 67] |
| 0.517857143 | [ 0.33035713] | [ 70] |
| 0.330357143 | [ 0.51785713] | [ 69] |
| 0.482142857 | [ 0.6785714 ] | [ 72] |
| 0.678571429 | [ 0.48214287] | [ 71] |
| 0.419642857 | [ 0.02678571] | [ 74] |
| 0.026785714 | [ 0.41964287] | [ 73] |
| 0.401785714 | [ 0.4017857 ] | [ 75] |
解決方案
- 將測試集中所有樣本數據加NO順序標簽列




