什么是 TFRecord
PS:這段內容摘自 http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
一種保存記錄的方法可以允許你講任意的數據轉換為TensorFlow所支持的格式, 這種方法可以使TensorFlow的數據集更容易與網絡應用架構相匹配。這種建議的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 協議內存塊(protocol buffer)(協議內存塊包含了字段 Features)。你可以寫一段代碼獲取你的數據, 將數據填入到Example協議內存塊(protocolbuffer),將協議內存塊序列化為一個字符串, 並且通過tf.python_io.TFRecordWriterclass寫入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是這樣的一個例子。
從TFRecords文件中讀取數據, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。這個parse_single_example操作可以將Example協議內存塊(protocolbuffer)解析為張量。 MNIST的例子就使用了convert_to_records 所構建的數據。 請參看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py,
代碼
adjust_pic.py
單純的轉換圖片大小
- # -*- coding: utf-8 -*-
- import tensorflow as tf
- def resize(img_data, width, high, method=0):
- return tf.image.resize_images(img_data,[width, high], method)
pic2tfrecords.py
將圖片保存成TFRecord
- # -*- coding: utf-8 -*-
- # 將圖片保存成 TFRecord
- import os.path
- import matplotlib.image as mpimg
- import tensorflow as tf
- import adjust_pic as ap
- from PIL import Image
- SAVE_PATH = 'data/dataset.tfrecords'
- def _int64_feature(value):
- return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
- def _bytes_feature(value):
- return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
- def load_data(datafile, width, high, method=0, save=False):
- train_list = open(datafile,'r')
- # 准備一個 writer 用來寫 TFRecord 文件
- writer = tf.python_io.TFRecordWriter(SAVE_PATH)
- with tf.Session() as sess:
- for line in train_list:
- # 獲得圖片的路徑和類型
- tmp = line.strip().split(' ')
- img_path = tmp[0]
- label = int(tmp[1])
- # 讀取圖片
- image = tf.gfile.FastGFile(img_path, 'r').read()
- # 解碼圖片(如果是 png 格式就使用 decode_png)
- image = tf.image.decode_jpeg(image)
- # 轉換數據類型
- # 因為為了將圖片數據能夠保存到 TFRecord 結構體中,所以需要將其圖片矩陣轉換成 string,所以為了在使用時能夠轉換回來,這里確定下數據格式為 tf.float32
- image = tf.image.convert_image_dtype(image, dtype=tf.float32)
- # 既然都將圖片保存成 TFRecord 了,那就先把圖片轉換成希望的大小吧
- image = ap.resize(image, width, high)
- # 執行 op: image
- image = sess.run(image)
- # 將其圖片矩陣轉換成 string
- image_raw = image.tostring()
- # 將數據整理成 TFRecord 需要的數據結構
- example = tf.train.Example(features=tf.train.Features(feature={
- 'image_raw': _bytes_feature(image_raw),
- 'label': _int64_feature(label),
- }))
- # 寫 TFRecord
- writer.write(example.SerializeToString())
- writer.close()
- load_data('train_list.txt_bak', 224, 224)
tfrecords2data.py
從TFRecord中讀取並保存成圖片
- # -*- coding: utf-8 -*-
- # 從 TFRecord 中讀取並保存圖片
- import tensorflow as tf
- import numpy as np
- SAVE_PATH = 'data/dataset.tfrecords'
- def load_data(width, high):
- reader = tf.TFRecordReader()
- filename_queue = tf.train.string_input_producer([SAVE_PATH])
- # 從 TFRecord 讀取內容並保存到 serialized_example 中
- _, serialized_example = reader.read(filename_queue)
- # 讀取 serialized_example 的格式
- features = tf.parse_single_example(
- serialized_example,
- features={
- 'image_raw': tf.FixedLenFeature([], tf.string),
- 'label': tf.FixedLenFeature([], tf.int64),
- })
- # 解析從 serialized_example 讀取到的內容
- images = tf.decode_raw(features['image_raw'], tf.uint8)
- labels = tf.cast(features['label'], tf.int64)
- with tf.Session() as sess:
- # 啟動多線程
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
- # 因為我這里只有 2 張圖片,所以下面循環 2 次
- for i in range(2):
- # 獲取一張圖片和其對應的類型
- label, image = sess.run([labels, images])
- # 這里特別說明下:
- # 因為要想把圖片保存成 TFRecord,那就必須先將圖片矩陣轉換成 string,即:
- # pic2tfrecords.py 中 image_raw = image.tostring() 這行
- # 所以這里需要執行下面這行將 string 轉換回來,否則會無法 reshape 成圖片矩陣,請看下面的小例子:
- # a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩陣
- # b = a.tostring()
- # # 下面這行的輸出是 32,即: 2*2 之后還要再乘 8
- # # 如果 tostring 之后的長度是 2*2=4 的話,那可以將 b 直接 reshape([2, 2]),但現在的長度是 2*2*8 = 32,所以無法直接 reshape
- # # 同理如果你的圖片是 500*500*3 的話,那 tostring() 之后的長度是 500*500*3 后再乘上一個數
- # print len(b)
- #
- # 但在網上有很多提供的代碼里都沒有下面這一行,你們那真的能 reshape ?
- image = np.fromstring(image, dtype=np.float32)
- # reshape 成圖片矩陣
- image = tf.reshape(image, [224, 224, 3])
- # 因為要保存圖片,所以將其轉換成 uint8
- image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
- # 按照 jpeg 格式編碼
- image = tf.image.encode_jpeg(image)
- # 保存圖片
- with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f:
- f.write(sess.run(image))
- load_data(224, 224)
train_list.txt_bak 中的內容如下:
image_1093.jpg 13
image_0805.jpg 10