Tensorflow在處理數據時,經常加載圖像數據,有的時候是直接讀取文件,有的則是讀取二進制文件,為了更好的理解Tensorflow數據處理模式,先簡單講解顯示圖片機制,就能更好掌握是否讀取正確了。
一、結合opencv讀取顯示圖片
1、變量
使用tf.Variable初始化為tensor,加載到tensorflow對圖片進行顯示
def showimage_variable_opencv(filename): image = cv2.imread(filename) # Create a Tensorflow variable image_tensor = tf.Variable(image, name = 'image') with tf.Session() as sess: # image_flap = tf.transpose(image_tensor, perm = [1,0,2]) sess.run(tf.global_variables_initializer()) result = sess.run(image_tensor) cv2.imshow('result', result) cv2.waitKey(0)
2、placeholder
OpenCV讀入圖片,使用tf.placeholder符號變量加載到tensorflow里,然后tensorflow對圖片進行顯示
def showimage_placeholder_opencv(filename): image = cv2.imread(filename) # Create a Tensorflow variable image_tensor = tf.placeholder('uint8', [None, None, 3]) with tf.Session() as sess: # image_flap = tf.transpose(image_tensor, perm = [1,0,2]) # sess.run(tf.global_variables_initializer()) result = sess.run(image_tensor, feed_dict = {image_tensor:image}) cv2.imshow('result', result) cv2.waitKey(0)
上面兩個內容非常簡單,opencv讀取圖像數據,成為圖像矩陣,然后直接轉化成tensorflow的tensor形式,之后通過會話窗口輸出圖像,進而顯示。
參考文獻:TensorFlow與OpenCV,讀取圖片,進行簡單操作並顯示
二、tensorflow內部讀取文件顯示
1、tensorflow的gfile
直接使用tensorflow提供的函數image = tf.gfile.FastGFile('PATH')來讀取一副圖片:
def showimage_gfile(filename):
# 讀物文件
image = tf.gfile.FastGFile(filename, 'r').read()
# #圖像解碼
image_data = tf.image.decode_jpeg(image)
#改變圖像數據的類型
image_show = tf.image.convert_image_dtype(image_data, dtype = tf.uint8)
plt.figure(1)
with tf.Session() as sess:
plt.imshow(image_show.eval())
sess.run(image_show)
2.string_input_producer
將圖像加載到創建好的隊列中使用tf.train.string_input_producer(),然后再加載到變量當中:
def showimage_string_input(filename): # 函數接受文件列表,如果是文件名需要加[] file_queue = tf.train.string_input_producer([filename]) # 定義讀入器,並讀入文件緩存器 image_reader = tf.WholeFileReader() _, image = image_reader.read(file_queue) image = tf.image.decode_jpeg(image) with tf.Session() as sess: # 初始化協同線程 coord = tf.train.Coordinator() # 啟動線程 threads = tf.train.start_queue_runners(sess = sess, coord = coord) result = sess.run(image) coord.request_stop() coord.join(threads) image_uint8 = tf.image.convert_image_dtype(image, dtype = tf.uint8) plt.imshow(image_uint8.eval())
cv2.imshow('result', result)
cv2.waitKey(0)
string_input_producer來生成一個先入先出的隊列, 文件閱讀器會需要它來讀取數據。
string_input_producer 提供的可配置參數來設置文件名亂序和最大的訓練迭代數, QueueRunner會為每次迭代(epoch)將所有的文件名加入文件名隊列中, 如果shuffle=True的話, 會對文件名進行亂序處理。這一過程是比較均勻的,因此它可以產生均衡的文件名隊列。
這個QueueRunner的工作線程是獨立於文件閱讀器的線程, 因此亂序和將文件名推入到文件名隊列這些過程不會阻塞文件閱讀器運行。
這個函數在tensorflow應用非常重要,用函數目的是為了將文件列表預先加載文件內出表中,方便文件讀取,減少讀取數據的時間,具體簡介可以參考這篇文章:
文章中對文件保存和讀取講解很清晰,此處重點引用過來
我們用一個具體的例子感受TensorFlow中的數據讀取。如圖,假設我們在當前文件夾中已經有A.jpg、B.jpg、C.jpg三張圖片,我們希望讀取這三張圖片5個epoch並且把讀取的結果重新存到read文件夾中。
# 導入TensorFlow import TensorFlow as tf # 新建一個Session with tf.Session() as sess: # 我們要讀三幅圖片A.jpg, B.jpg, C.jpg filename = ['A.jpg', 'B.jpg', 'C.jpg'] # string_input_producer會產生一個文件名隊列 filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5) # reader從文件名隊列中讀數據。對應的方法是reader.read reader = tf.WholeFileReader() key, value = reader.read(filename_queue) # tf.train.string_input_producer定義了一個epoch變量,要對它進行初始化 tf.local_variables_initializer().run() # 使用start_queue_runners之后,才會開始填充隊列 threads = tf.train.start_queue_runners(sess=sess) i = 0 while True: i += 1 # 獲取圖片數據並保存 image_data = sess.run(value) with open('read/test_%d.jpg' % i, 'wb') as f: f.write(image_data)
運行代碼后,我們得到就可以看到read文件夾中的圖片,正好是按順序的5個epoch:

如果我們設置filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)中的shuffle=True,那么在每個epoch內圖像就會被打亂,如圖所示:

我們這里只是用三張圖片舉例,實際應用中一個數據集肯定不止3張圖片,不過涉及到的原理都是共通的。
3、slice_input_producer
從眾多tensorlist里面,隨機選取一個tensor
def saveimages_slice_input(filenames): labels = [1,2,3] images_tensor = tf.convert_to_tensor(filenames, dtype=tf.string) labels_tensor = tf.convert_to_tensor(labels, dtype=tf.uint8) file = tf.train.slice_input_producer([images_tensor, labels_tensor]) image_content = tf.read_file(file[0]) index = file[1] image_data = tf.image.convert_image_dtype(tf.image.decode_jpeg(image_content), tf.uint8) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess, coord = coord) sess.run(image_data) print(sess.run(index)) coord.request_stop() coord.join(threads) plt.imshow(image_data.eval())
其實這個是每次只能抽取一個圖片,如果需要讀取多個圖片還是需要構建batch,后面講對其進行詳細的講解。
其中image_content = tf.read_file(file[0]),必須添加為file[0]才能正確圖像數據,進而顯示圖像。file[1]則是輸出數字。
