TensorFlow走過的坑之---數據讀取和tf中batch的使用方法


首先介紹數據讀取問題,現在TensorFlow官方推薦的數據讀取方法是使用tf.data.Dataset,具體的細節不在這里贅述,看官方文檔更清楚,這里主要記錄一下官方文檔沒有提到的坑,以示"后人"。因為是記錄踩過的坑,所以行文混亂,見諒。

I 問題背景

不感興趣的可跳過此節。

最近在研究ENAS的代碼,這個網絡的作用是基於增強學習,能夠自動生成合適的網絡結構。原作者使用TensorFlow在cifar10上成功自動生成了網絡結構,並取得了不錯的效果。

但問題來了,此時我需要將代碼轉移到自己的數據集上,都知道cifar10圖像大小是32*32,並不是特別大,所以原作者"喪心病狂"地采用了一次性將數據讀進顯存的操作,絲毫不考慮我等渣渣的感受。我的數據集原圖基本在500*800或以上,經過反復試驗,如果采用源代碼我必須將圖像通過縮放和中心裁剪到160*160才能正常運行,而且運行結果並不是很理想,十分類跑了一天左右最好的結果才30%左右。

我在想如果把圖片放大后是否會提高准確度,所以第一個坑是修改數據讀取方式,適應大數據集讀取

再仔細閱讀源代碼后我還發現作者使用了tf.train.shuffle_batch這個函數用來批量讀取,這個函數也讓我頭疼了很久,因為一直不知道它和tf.data.Dataset.batch.shuffle()有什么區別,所以第二個坑時tf.train.shuffle_batchtf.data.Dataset.batch.shuffle()到底什么關系(區別)

II tf.train.batchtf.data.Dataset.batch.shuffle()什么區別

其實這兩個談不上什么區別,因為后者是前者的升級版,233333。

官方文檔對tf.train.batch的描述是這樣的:

THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.batch(batch_size) (or padded_batch(...) if dynamic_pad=True).

在這里我也推薦大家用tf.data,因為他相比於原來的tf.train.batch好用太多。

III TensorFlow如何讀取大數據集?

這里的大數據集指的是稍微比較大的,像ImageNet這樣的數據集還沒嘗試過。所以下面的方法不敢肯定是否使用於ImageNet。

要想讀取大數據集,我找到的官方給出的方案有兩種:

  • 使用TFRecord格式進行數據讀取。
  • 使用tf.placeholder,本文將主要介紹這種方法。

我的數據集是以已經分好類的文件夾進行存儲的,大致結構是這樣的

├───test
│   ├───Acne_Vulgaris
│   ├───Actinic_solar_Damage__Actinic_Keratosis
│   ├───Basal_Cell_Carcinoma
│   ├───Rosacea
│   └───Seborrheic_Keratosis
├───train
│   ├───Acne_Vulgaris
│   ├───Actinic_solar_Damage__Actinic_Keratosis
│   ├───Basal_Cell_Carcinoma
│   ├───Rosacea
│   └───Seborrheic_Keratosis
└───valid
    ├───Acne_Vulgaris
    ├───Actinic_solar_Damage__Actinic_Keratosis
    ├───Basal_Cell_Carcinoma
    ├───Rosacea
    └───Seborrheic_Keratosis

我的方法非常適合懶人,具體流程如下:

1.torchvision讀取數據

pytorch提供了torchvision這個庫,這個庫堪稱瑰寶,torchvision.datasets里有個函數是ImageFolder,你只需要指明路徑即可把圖片數據都讀進來,不用再苦逼地手寫for循環遍歷了。其他的細節,比如data augmentation等等就不介紹了,具體代碼可參看官方文檔以及如下鏈接: https://github.com/marsggbo/enas/blob/master/src/skin5_placeholder/data_utils.py

2.創建tf.placeholder

假設上一步已經圖像數據讀取完畢,並保存成numpy文件,下面參看官方文檔例子

# 讀取numpy數據
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# 查看圖像和標簽維度是否保持一致
assert features.shape[0] == labels.shape[0]

# 創建placeholder
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

# 創建dataset
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))

# 批量讀取,打散數據,repeat()
dataset = dataset.shuffle(20).batch(5).repeat()

# [Other transformations on `dataset`...]
dataset_other = ...

iterator = dataset.make_initializable_iterator()
data_element = iterator.get_nex()

sess = tf.Session()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})

for e in range(EPOCHS):
	for step in range(num_batches):
		x_batch, y_batch = sess.run(data_element)
		y_pred = model(x_batch)
		...
...

sess.close()

插播一條廣告:上面代碼中batch(), shuffle(), repeat()的具體用法參見Tensorflow datasets.shuffle repeat batch方法

上面邏輯很清楚:

  • 創建placeholder
  • 創建dataset
  • 然后數據打亂,批量讀取
  • 創建迭代器,使用get_next()迭代獲取下一個batch數據,這里返回的是以個tuple,即(feature_batch, label_batch)
  • 初始化迭代器,並將數據喂給placeholder,注意迭代器要在循環語句之前初始化,否則無法完整把數據集遍歷讀取一遍。
  • 進入循環語句,批量讀取數據,開始進行運算了。

注意,每次一運行sess.run(data_element)這個語句,TensorFlow會自動的調取下一個批次的數據。不僅如此,只要sess.run一個把data_element作為輸入的節點,也都會自動調取下一個批次的數據。說的有點繞,看例子就明白了

可以看到如果在讀取數據的時候還sess.run與數據有關的操作,那么有的數據就根本沒遍歷到,所以這個問題要特別注意。

那我為什么會連這種坑都能踩到呢,因為原作者的代碼寫的太“好”了,對於我這種剛入門的人來說太難理解和修改了。

原作者的代碼結構並沒有寫for循環遍歷讀取數據,然后傳入到模型。相反他把數據操作寫到了另一個類(文件)中,比如說在model.py中他定義了

class Model():
	def __init__():
		...
	
	def _model(self, img, label):
		y_pred = other_function(img)
		acc = calculate_acc(y_pred, label)
	...

然后在main.py中他只是sess.run(model.acc),即

with tf.Session() as sess:
	...
	while epoch < EPOCHS:
		global_step = sess.run(model.global_step)
		if global_step % 50:
			acc = sess.run(model.acc)
		...
	...

抱怨一下: 它這代碼結構寫得和官方文檔不一樣,所以一直不知道怎么修改。你如果從最開始看到這,你應該覺得很好改啊,但是你看着官方文檔真不知道怎么修改,因為最開始我並不知道每次sess.run之后都會自動調用下一個batch的數據,而且也還沒有習慣TensorFlow數據流的思維。在這里特別感謝這個問題幫助我解答了困惑:Tensorflow: create minibatch from numpy array > 2 GB

所以這種情況怎么讀取數據呢?很簡單,只需要在循環語句之前初始化迭代器即可。

ops = {
	"global_step": model.global_step,
	"acc": model.acc

}
with tf.Session() as sess:
	...
	sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) 
	while epoch < EPOCHS:
		global_step = sess.run(ops['global_step'])
		if global_step % 50:
			acc = sess.run(ops['acc'])
		...
	...

如果你想要查看數據是否正確讀取,千萬不要在上面的while循環中加入這么一行代碼x_batch, y_batch=sess.run([model.x_batch, model.y_batch]),這樣就會導致上面所說的數據無法完整遍歷的問題。那怎么辦呢?

我們可以考慮修改ops來獲取數據,代碼如下:

ops = {
	"global_step": model.global_step,
	"acc": model.acc,
	"x_batch": model.x_batch,
	"y_batch": model.y_batch

}
with tf.Session() as sess:
	...
	sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) 
	while epoch < EPOCHS:
		global_step = sess.run(ops['global_step'])
		if global_step % 50:
			
			acc = sess.run([ops["acc"], ops["x_batch"], ops["y_batch"]])
		...

這樣之所以能完整遍歷,是因為我們將x_batch和acc放在一起啦~,所以這可以看成只是一個運算。




微信公眾號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯系~
郵箱:marsggbo@foxmail.com

2018-11-29




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM