前言
一直以來都是用 tensorflow 框架實現深度學習算法和實驗,在網絡訓練時有一個重要的問題就是訓練數據的讀取。tensorflow 支持流水線並行讀取數據,這種方式將數據的讀取和網絡訓練並行,數據讀取效率和將所有數據載入內存后進行存取相當,卻又不會增加內存開銷,是很值得推薦的一種方式。這篇筆記就是總結一下自己在實際應用中的並行數據讀取,留個備份,隨時學習。
主要參考了 Google HDRnet 代碼:https://github.com/mgharbi/hdrnet,CycleGAN 代碼:https://github.com/vanhuyz/CycleGAN-TensorFlow
數據讀取
HDRnet工程里的 data_pipeline.py 文件提供了非常清晰的流水線讀取數據示例,在官方代碼的基礎上,可以很輕松地針對自己的應用實現一套數據讀取接口,假設我們的訓練數據存儲在目錄 training_data/input 和 training_data/output,input 存儲網絡訓練輸入,output 存儲網絡目標輸出,一對訓練樣本的輸入和目標輸出名稱相同,均為二進制文件 *.dat,以下面代碼為示例展示如何實現流水線並行數據讀取:
def data_generator(params, data_path): filelist = os.listdir(data_path) # 獲取訓練目錄下的文件名列表 if params.shuffle: random.shuffle(filelist) # 隨機打亂訓練數據 input_files = [os.path.join(data_path, 'input', f) for f in filelist if f.endswith('.dat')] # 生成輸入數據文件名列表 output_files = [os.path.join(data_path, 'output', f) for f in filelist if f.endswith('.dat')] # 生成目標輸出文件名列表
# 基於給定的文件名列表,創建先入先出的文件名隊列,輸入可以是多個文件名列表,輸出對應的對個文件名隊列 input_queue, output_queue = tf.train.slice_input_producer( [input_files, output_files], shuffle=params.shuffle, seed=params.seed, num_epochs=params.num_epochs) input_reader = tf.read_file(input_queue) # 創建 reader,讀取輸入數據 output_reader = tf.read_file(output_queue) # 創建 reader,讀取目標輸出
# 根據文件類型的不同解析數據,如果文件是圖像,可以使用 tf.image.decode_jpeg 等函數解析 if os.path.splitext(input_files[0])[-1] == '.jpg': input = tf.image.decode_jpeg(input_reader, channels=3) else: input = tf.decode_raw(input_reader, data_type=tf.uint16) # 如果是二進制信息存儲,則可以使用 tf.decode_raw 函數解析 input = tf.reshape(input, [params.height, params.width, params.channel]) # 將數據 reshape 為正確的形狀,此處以圖像 (height, width, channel) 為例 if os.path.splitext(output_files[0])[-1] == '.jpg': output = tf.image.decode_jpeg(output_reader, channels=3) else: output = tf.decode_raw(output_reader, data_type=tf.uint16) input = tf.reshape(input, [params.height, params.width, params.channel])
# 上面讀取了單個輸入和對應的目標輸出,網絡訓練時如需數據增廣,可以在讀取單個訓練對之后,使用函數對數據進行處理,擴大訓練集 input, output = augment_data(input, output) samples = {} # 將增廣后的一對訓練數據組織為字典的形式,便於后面組織成 batch samples['input'] = input samples['output'] = output if param.shuffle: # 創建批樣例訓練數據 samples = tf.train.shuffle_batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity, min_after_dequeue=params.min_after_dequeue) else: samples = tf.train.batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity) return samples # 返回一個 batch 的訓練數據
代碼中具體函數的接口可以通過 tensorflow 的文檔查清。以上,只是聲明了多線程的文件讀取操作,並不會真正的讀取數據,為了在會話執行時順利地獲取輸入數據,需要使用 tf.train.start_queue_runners 來啟動執行入隊列操作的所有線程,具體過程包括:文件名入隊到文件名隊列,樣例入隊到樣例隊列。示例代碼如下:
params.shuffle = true params.seed = 1234 params.height = 224 params.width = 224 params.channel = 3 training_path = 'dir/to/training/data' training_samples = data_generator(params, training_path) batch_inputs = training_samples['input'] batch_outputs = training_sample['output']
# 網絡計算圖創建
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)
loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess = sess)
sess.run(train_op)
...
上面的代碼中,輸入輸出各只有一張圖像,展示了如何實現流水線讀取,以及如何使用讀取出的數據。當輸入或者輸出包含多個文件時,例如,輸入是圖像和其語義分割圖,可以在 data_generator 函數中,增加對語義分割圖的讀取,相對應的,多了 seg_files、seg_queue、seg_reader、seg_map 以及最后的 samples['seg_map'] = seg_map;
同樣,當輸入數據是其它格式時,只需要根據對應的格式修改數據讀取的代碼接,例如 CycleGAN 中,訓練數據存儲為 tfrecord 格式,需要修改的其實就是對文件的讀取部分。
我們都知道,tensorflow 在創建網絡計算圖時,通常需要為網絡輸入和目標輸出先聲明 placeholder,但是上面的第二段示例代碼則是直接使用數據讀取的輸出構建網絡計算圖,是不是說采用這種方式就不能采用常見方法那樣,先定義 placeholder,再在網絡訓練中使用 feed_dict 填充數據呢?答案是可以的,方法也和通常的做法沒有太大區別,示例如下:
x = tf.placeholder(...)
y = tf.palceholder(...) conv_1 = Conv2D(y, ...) ...
loss = tf.reduce_mean(tf.squared_difference(net_y, y))
train_op = tf.minimize(loss, ...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess)
samples = data_generator(params, training_path)
sess.run(train_op, feed_dict={x: samples['input'], y: samples['output']})
和第一種方法的區別是 data_generator 是在會話 sess 中調用,而不是在構建網絡計算圖時調用;
需要注意的是,上面的方式容錯性比較差,主要是因為采用多線程方式讀取數據,隊列操作后台線程的生命周期無管理機制,線程出現異常會導致程序崩潰,比較常見的異常是文件名隊列或者樣例隊列越界拋出的 tf.errors.OutOfRangeError。為了處理這種異常,HDRnet、CycleGAN 工程代碼中都使用 tf.train.Coordinator 創建了管理多線程聲明周期的協調器,其工作原理是通過監控 tensorflow 所有后台線程,當有線程出現異常時,協調器的 should_stop 成員方法返回 True,循環結束,然后會話執行協調器的 request_stop 方法,請求所有線程安全退出。一套完整的示例代碼如下:
params.shuffle = true
params.seed = 1234
params.height = 224
params.width = 224
params.channel = 3
training_path = 'dir/to/training/data'
training_samples = data_generator(params, training_path)
batch_inputs = training_samples['input']
batch_outputs = training_sample['output']
# 網絡計算圖創建
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)
loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
sess.run(train_op)
except KeyboardInterrupt: # 響應 Ctrl+C 停止訓練
coord.request_stop()
except Exception as e: # 后台線程出現異常
coord.request_stop(e)
finally: # 這一步總會執行
save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) # 保存 checkpoint
coord.request_stop()
coord.join(threads)
總結
以上,介紹 tensorflow 中如何使用多線程並行讀取數據,如何在訓練中使用讀取的數據,以及如何對多線程進行監視,提升網絡訓練的容錯性。分享給大家,也給自己學習。