關於TensorFlow讀取數據,官網給出了三種方法:
- 供給數據(Feeding):在TensorFlow程序運行的每一步,讓python代碼來供給數據。
- 從文件讀取數據:在TensorFlow圖的起始,讓一個輸入管線從文件中讀取數據。
- 預加載數據:在TensorFlow圖中定義常量或變量來保存所有數據(僅適用於數據量比較小的情況)。
對於數據量較小而言,可能一般選擇直接將數據加載進內存,然后再分batch輸入網絡進行訓練(tip:使用這種方法時,結合yeild 使用更為簡潔)。但是如果數據量較大,這樣的方法就不適用了。因為太耗內存,所以這時最好使用TensorFlow提供的隊列queue,也就是第二種方法:從文件讀取數據。對於一些特定的讀取,比如csv文件格式,官網有相關的描述,在這里我們學習一種比較通用的,高效的讀取方法,即使用TensorFlow內定標准格式——TFRecords。
1,什么是TFRecords?
TensorFlow提供了一種統一的格式來存儲數據,這個格式就是TFRecords。
為了高效的讀取數據,可以將數據進行序列化存儲,這樣也便於網絡流式讀取數據,TFRecord就是一種保存記錄的方法可以允許你講任意的數據轉換為TensorFlow所支持的格式,這種方法可以使TensorFlow的數據集更容易與網絡應用架構相匹配。
TFRecord是谷歌推薦的一種常用的存儲二進制序列數據的文件格式,理論上它可以保存任何格式的信息。下面是Tensorflow的官網給出的文檔結構,整個文件由文件長度信息,長度校驗碼,數據,數據校驗碼組成。
uint64 length uint32 masked_crc32_of_length byte data[length] uint32 masked_crc32_of_data
但是對於我們普通開發者而言,我們並不需要關心這些,TensorFlow提供了豐富的API可以幫助我們輕松地讀寫TFRecord文件。
而 tf.Example 類就是一種將數據表示為{‘string’: value}形式的 message類型,TensorFlow經常使用 tf.Example 來寫入,讀取 TFRecord數據。
1.1 tf.Example 可以使用的數據格式
通常情況下,tf.Example中可以使用以下幾種格式:
- tf.train.BytesList: 可以使用的類型包括 string和byte
- tf.train.FloatList: 可以使用的類型包括 float和double
- tf.train.Int64List: 可以使用的類型包括 enum,bool, int32, uint32, int64
TFRecord支持寫入三種格式的數據:string,int64,float32,以列表的形式分別通過tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 寫入 tf.train.Feature,如下所示:
#feature一般是多維數組,要先轉為list tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()])) #tostring函數后feature的形狀信息會丟失,把shape也寫入 tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
如果寫成這樣,可能大家更熟悉一點:
def _bytes_feature(value):
"""Returns a bytes_list from a string/byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Return a float_list form a float/double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Return a int64_list from a bool/enum/int/uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
通過上述操作,我們以dict的形式把要寫入的數據匯總,並構建 tf.train.Features,然后構建 tf.train.Example。如下:
def get_tfrecords_example(feature, label):
tfrecords_features = {}
feat_shape = feature.shape
tfrecords_features['feature'] = tf.train.Feature(bytes_list=
tf.train.BytesList(value=[feature.tostring()]))
tfrecords_features['shape'] = tf.train.Feature(int64_list=
tf.train.Int64List(value=list(feat_shape)))
tfrecords_features['label'] = tf.train.Feature(float_list=
tf.train.FloatList(value=label))
return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
我們測試一下,來驗證不同的數據格式需要使用不同的函數:
# tf.train.BytesList
print(_bytes_feature(b'test_string'))
print(_bytes_feature('test_string'.encode('utf8')))
# tf.train.FloatList
print(_float_feature(np.exp(1)))
# tf.train.Int64List
print(_int64_feature(True))
print(_int64_feature(1))
結果:
bytes_list {
value: "test_string"
}
bytes_list {
value: "test_string"
}
float_list {
value: 2.7182817459106445
}
int64_list {
value: 1
}
int64_list {
value: 1
}
把創建的tf.train.Example序列化下,便可以通過 tf.python_io.TFRecordWriter 寫入 tfrecord文件中,如下:
#創建tfrecord的writer,文件名為xxx
tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')
#把數據寫入Example
exmp = get_tfrecords_example(feats[inx], labels[inx])
#Example序列化
exmp_serial = exmp.SerializeToString()
#寫入tfrecord文件
tfrecord_wrt.write(exmp_serial)
#寫完后關閉tfrecord的writer
tfrecord_wrt.close()
TFRecord 的核心內容在於內部有一系列的Example,Example 是protocolbuf 協議(protocolbuf 是通用的協議格式,對主流的編程語言都適用。所以這些 List對應到Python語言當中是列表。而對於Java 或者 C/C++來說他們就是數組)下的消息體。
一個Example消息體包含了一系列的feature屬性。每一個feature是一個map,也就是 key-value 的鍵值對。key 取值是String類型。而value是Feature類型的消息體。下面代碼給出了 tf.train.Example的定義:
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
從上面的代碼可以看出 tf.train.example 的數據結構是比較簡潔的。tf.train.example中包含了一個從屬性名稱到取值的字典。其中屬性名稱為一個字符串,屬性的取值為字符串(ByteList),實數列表(FloatList)或者整數列表(Int64List),舉個例子,比如將一張解碼前的圖像存為一個字符串,圖像所對應的類別編碼存為整數列表,所以可以說TFRecord 可以存儲幾乎任何格式的信息。
2,為什么要用TFRecord?
TFRerecord也不是非用不可,但確實是谷歌官網推薦的文件格式。
- 1,它特別適合於TensorFlow,或者說就是為TensorFlow量身打造的。
- 2,因為TensorFlow開發者眾多,統一訓練的數據文件格式是一件很有意義的事情,也有助於降低學習成本和遷移成本。
TFRecords 其實是一種二進制文件,雖然它不如其他格式好理解,但是它能更好的利用內存,更方便賦值和移動,並且不需要單獨的標簽文件,理論上,它能保存所有的信息。總而言之,這樣的文件格式好處多多,所以讓我們利用起來。
3,為什么要生成自己的圖片數據集TFrecords?
使用TensorFlow進行網格訓練時,為了提高讀取數據的效率,一般建議將訓練數據轉化為TFrecords格式。
使用tensorflow官網例子練習,我們會發現基本都是MNIST,CIFAR_10這種做好的數據集說事。所以對於我們這些初學者,完全不知道圖片該如何輸入。這時候學習自己制作數據集就非常有必要了。
4,如何將一張圖片和一個TFRecord 文件相互轉化
我們可以使用TFWriter輕松的完成這個任務。但是制作之前,我們要明確自己的目的。我們必須要想清楚,需要把什么信息存儲到TFRecord 文件當中,這其實是最重要的。
下面我們將一張圖片轉化為TFRecord,然后讀取一張TFRecord文件,並展示為圖片。
4.1 將一張圖片轉化成TFRecord 文件
下面舉例說明嘗試把圖片轉化成TFRecord 文件。
首先定義Example 消息體。
Example Message {
Features{
feature{
key:"name"
value:{
bytes_list:{
value:"cat"
}
}
}
feature{
key:"shape"
value:{
int64_list:{
value:689
value:720
value:3
}
}
}
feature{
key:"data"
value:{
bytes_list:{
value:0xbe
value:0xb2
...
value:0x3
}
}
}
}
}
上面的Example表示,要將一張 cat 圖片信息寫進了 TFRecord 當中。而圖片信息包含了圖片的名字,圖片的維度信息還有圖片的數據,分別對應了 name,shape,content 3個feature。
下面我們嘗試使用代碼實現:
# _*_coding:utf-8_*_
import tensorflow as tf
def write_test(input, output):
# 借助於TFRecordWriter 才能將信息寫入TFRecord 文件
writer = tf.python_io.TFRecordWriter(output)
# 讀取圖片並進行解碼
image = tf.read_file(input)
image = tf.image.decode_jpeg(image)
with tf.Session() as sess:
image = sess.run(image)
shape = image.shape
# 將圖片轉換成string
image_data = image.tostring()
print(type(image))
print(len(image_data))
name = bytes('cat', encoding='utf-8')
print(type(name))
# 創建Example對象,並將Feature一一對應填充進去
example = tf.train.Example(features=tf.train.Features(feature={
'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
}
))
# 將example序列化成string 類型,然后寫入。
writer.write(example.SerializeToString())
writer.close()
if __name__ == '__main__':
input_photo = 'cat.jpg'
output_file = 'cat.tfrecord'
write_test(input_photo, output_file)
上述代碼注釋比較詳細,所以我們就重點說一下下面三點:
- 1,將圖片解碼,然后轉化成string數據,然后填充進去。
- 2,Feature 的value 是列表,所以記得加上 []
- 3,example需要調用 SerializetoString() 進行序列化后才行
4.2 TFRecord 文件讀取為圖片
我們將圖片的信息寫入到一個tfrecord文件當中。現在我們需要檢驗它是否正確。這就需要用到如何讀取TFRecord 文件的知識點了。
代碼如下:
# _*_coding:utf-8_*_
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def _parse_record(example_photo):
features = {
'name': tf.FixedLenFeature((), tf.string),
'shape': tf.FixedLenFeature([3], tf.int64),
'data': tf.FixedLenFeature((), tf.string)
}
parsed_features = tf.parse_single_example(example_photo,features=features)
return parsed_features
def read_test(input_file):
# 用dataset讀取TFRecords文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
features = sess.run(iterator.get_next())
name = features['name']
name = name.decode()
img_data = features['data']
shape = features['shape']
print("==============")
print(type(shape))
print(len(img_data))
# 從bytes數組中加載圖片原始數據,並重新reshape,它的結果是 ndarray 數組
img_data = np.fromstring(img_data, dtype=np.uint8)
image_data = np.reshape(img_data, shape)
plt.figure()
# 顯示圖片
plt.imshow(image_data)
plt.show()
# 將數據重新編碼成jpg圖片並保存
img = tf.image.encode_jpeg(image_data)
tf.gfile.GFile('cat_encode.jpg', 'wb').write(img.eval())
if __name__ == '__main__':
read_test("cat.tfrecord")
下面解釋一下代碼:
1,首先使用dataset去讀取tfrecord文件
2,在解析example 的時候,用現成的API:tf.parse_single_example
3,用 np.fromstring() 方法就可以獲取解析后的string數據,記得把數據還原成 np.uint8
4,用 tf.image.encode_jepg() 方法可以將圖片數據編碼成 jpeg 格式
5,用 tf.gfile.GFile 對象可以把圖片數據保存到本地
6,因為將圖片 shape 寫入了example 中,所以解析的時候必須指定維度,在這里 [3],不然程序會報錯。
運行程序后,可以看到圖片顯示如下:

5,如何將一個文件夾下多張圖片和一個TFRecord 文件相互轉化
下面我們將一個文件夾的圖片轉化為TFRecord,然后再將TFRecord讀取為圖片。
5.1 將一個文件夾下多張圖片轉化為一個TFRecord文件
下面舉例說明嘗試把圖片轉化成TFRecord 文件。
# _*_coding:utf-8_*_
# 將圖片保存成TFRecords
import os
import tensorflow as tf
from PIL import Image
import random
import cv2
import numpy as np
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成字符串型的屬性
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成實數型的屬性
def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def read_image(filename, resize_height, resize_width, normalization=False):
'''
讀取圖片數據,默認返回的是uint8, [0, 255]
:param filename:
:param resize_height:
:param resize_width:
:param normalization: 是否歸一化到 [0.0, 1.0]
:return: 返回的圖片數據
'''
bgr_image = cv2.imread(filename)
# print(type(bgr_image))
# 若是灰度圖則轉化為三通道
if len(bgr_image.shape) == 2:
print("Warning:gray image", filename)
bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
# 將BGR轉化為RGB
rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
# show_image(filename, rgb_image)
# rgb_image=Image.open(filename)
if resize_width > 0 and resize_height > 0:
rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
rgb_image = np.asanyarray(rgb_image)
if normalization:
rgb_image = rgb_image / 255.0
return rgb_image
def load_labels_file(filename, labels_num=1, shuffle=False):
'''
載圖txt文件,文件中每行為一個圖片信息,且以空格隔開,圖像路徑 標簽1 標簽2
如 test_image/1.jpg 0 2
:param filename:
:param labels_num: labels個數
:param shuffle: 是否打亂順序
:return: images type-> list
:return:labels type->lis\t
'''
images = []
labels = []
with open(filename) as f:
lines_list = f.readlines()
# print(lines_list) # ['plane\\0499.jpg 4\n', 'plane\\0500.jpg 4\n']
if shuffle:
random.shuffle(lines_list)
for lines in lines_list:
line = lines.rstrip().split(" ") # rstrip 刪除 string 字符串末尾的空格. ['plane\\0006.jpg', '4']
label = []
for i in range(labels_num): # labels_num 1 0 1所以i只能取1
label.append(int(line[i + 1])) # 確保讀取的是列表的第二個元素
# print(label)
images.append(line[0])
# labels.append(line[1]) # ['0', '4']
labels.append(label)
# print(images)
# print(labels)
return images, labels
def create_records(image_dir, file, output_record_dir, resize_height, resize_width, shuffle, log=5):
'''
實現將圖像原始數據,label,長,寬等信息保存為record文件
注意:讀取的圖像數據默認是uint8,再轉為tf的字符串型BytesList保存,解析請需要根據需要轉換類型
:param image_dir:原始圖像的目錄
:param file:輸入保存圖片信息的txt文件(image_dir+file構成圖片的路徑)
:param output_record_dir:保存record文件的路徑
:param resize_height:
:param resize_width:
PS:當resize_height或者resize_width=0是,不執行resize
:param shuffle:是否打亂順序
:param log:log信息打印間隔
'''
# 加載文件,僅獲取一個label
images_list, labels_list = load_labels_file(file, 1, shuffle)
writer = tf.python_io.TFRecordWriter(output_record_dir)
for i, [image_name, labels] in enumerate(zip(images_list, labels_list)):
image_path = os.path.join(image_dir, images_list[i])
if not os.path.exists(image_path):
print("Error:no image", image_path)
continue
image = read_image(image_path, resize_height, resize_width)
image_raw = image.tostring()
if i % log == 0 or i == len(images_list) - 1:
print("-----------processing:%d--th------------" % (i))
print('current image_path=%s' % (image_path), 'shape:{}'.format(image.shape),
'labels:{}'.format(labels))
# 這里僅保存一個label,多label適當增加"'label': _int64_feature(label)"項
label = labels[0]
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw': _bytes_feature(image_raw),
'height': _int64_feature(image.shape[0]),
'width': _int64_feature(image.shape[1]),
'depth': _int64_feature(image.shape[2]),
'label': _int64_feature(label)
}))
writer.write(example.SerializeToString())
writer.close()
def get_example_nums(tf_records_filenames):
'''
統計tf_records圖像的個數(example)個數
:param tf_records_filenames: tf_records文件路徑
:return:
'''
nums = 0
for record in tf.python_io.tf_record_iterator(tf_records_filenames):
nums += 1
return nums
if __name__ == '__main__':
resize_height = 224 # 指定存儲圖片高度
resize_width = 224 # 指定存儲圖片寬度
shuffle = True
log = 5
image_dir = 'dataset/train'
train_labels = 'dataset/train.txt'
train_record_output = 'train.tfrecord'
create_records(image_dir, train_labels, train_record_output, resize_height, resize_width, shuffle, log)
train_nums = get_example_nums(train_record_output)
print("save train example nums={}".format(train_nums))
5.2 將一個TFRecord文件轉化為圖片顯示
因為圖片太多,所以我們這里只展示每個文件夾中第一張圖片即可。
代碼如下:
# _*_coding:utf-8_*_
# 將圖片保存成TFRecords
import os
import tensorflow as tf
from PIL import Image
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt
def read_records(filename,resize_height, resize_width,type=None):
'''
解析record文件:源文件的圖像數據是RGB,uint8,[0,255],一般作為訓練數據時,需要歸一化到[0,1]
:param filename:
:param resize_height:
:param resize_width:
:param type:選擇圖像數據的返回類型
None:默認將uint8-[0,255]轉為float32-[0,255]
normalization:歸一化float32-[0,1]
centralization:歸一化float32-[0,1],再減均值中心化
:return:
'''
# 創建文件隊列,不限讀取的數量
filename_queue = tf.train.string_input_producer([filename])
# 為文件隊列創建一個閱讀區
reader = tf.TFRecordReader()
# reader從文件隊列中讀入一個序列化的樣本
_, serialized_example = reader.read(filename_queue)
# 解析符號化的樣本
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64)
}
)
# 獲得圖像原始的數據
tf_image = tf.decode_raw(features["image_raw"], tf.uint8)
tf_height = features['height']
tf_width = features['width']
tf_depth = features['depth']
tf_label = tf.cast(features['label'], tf.int32)
#PS 回復原始圖像 reshpe的大小必須與保存之前的圖像shape一致,否則報錯
# 設置圖像的維度
tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])
# 恢復數據后,才可以對圖像進行resize_images:輸入 uint 輸出 float32
# tf_image = tf.image.resize_images(tf_image, [224, 224])
# 存儲的圖像類型為 uint8 tensorflow訓練數據必須是tf.float32
if type is None:
tf_image = tf.cast(tf_image, tf.float32)
# 【1】 若需要歸一化的話請使用
elif type == 'normalization':
# 僅當輸入數據是 uint8,才會歸一化 [0 , 255]
tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)
elif type=='centralization':
# 若需要歸一化,且中心化,假設均值為0.5 請使用
tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0) - 0.5
# 這里僅僅返回圖像和標簽
return tf_image, tf_label
def show_image(title, image):
'''
顯示圖片
:param title: 圖像標題
:param image: 圖像的數據
:return:
'''
plt.imshow(image)
plt.axis('on') # 關掉坐標軸 為 off
plt.title(title) # 圖像題目
plt.show()
def disp_records(record_file,resize_height, resize_width,show_nums=4):
'''
解析record文件,並顯示show_nums張圖片,主要用於驗證生成record文件是否成功
:param tfrecord_file: record文件路徑
:return:
'''
# 讀取record 函數
tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
# 顯示前4個圖片
init_op = tf.global_variables_initializer()
# init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(show_nums): # 在會話中取出image和label
image, label = sess.run([tf_image, tf_label])
# image = tf_image.eval()
# 直接從record解析的image是一個向量,需要reshape顯示
# image = image.reshape([height,width,depth])
print('shape:{},tpye:{},labels:{}'.format(image.shape, image.dtype, label))
# pilimg = Image.fromarray(np.asarray(image_eval_reshape))
# pilimg.show()
show_image("image:%d"%(label), image)
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
resize_height = 224 # 指定存儲圖片高度
resize_width = 224 # 指定存儲圖片寬度
shuffle = True
log = 5
image_dir = 'dataset/train'
train_labels = 'dataset/train.txt'
train_record_output = 'train.tfrecord'
# 測試顯示函數
disp_records(train_record_output, resize_height, resize_width)
部分代碼解析:
5.3,加入隊列
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
# 啟動隊列
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(show_nums): # 在會話中取出image和label
image, label = sess.run([tf_image, tf_label])
注意,啟動隊列那條code不能忘記,不然會卡死,這樣加入后,就可以做到和tensorflow官網一樣的二進制數據集了。
6,生成分割多個record文件
當圖片數據很多時候,會導致單個record文件超級巨大的情況,解決方法就是,將數據分成多個record文件保存,讀取時,只需要將多個record文件的路徑列表交給“tf.train.string_input_producer”,
完整代碼如下:(此處來自 此博客)
# -*-coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import os
import cv2
import math
import matplotlib.pyplot as plt
import random
from PIL import Image
##########################################################################
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成字符串型的屬性
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成實數型的屬性
def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def show_image(title,image):
'''
顯示圖片
:param title: 圖像標題
:param image: 圖像的數據
:return:
'''
# plt.figure("show_image")
# print(image.dtype)
plt.imshow(image)
plt.axis('on') # 關掉坐標軸為 off
plt.title(title) # 圖像題目
plt.show()
def load_labels_file(filename,labels_num=1):
'''
載圖txt文件,文件中每行為一個圖片信息,且以空格隔開:圖像路徑 標簽1 標簽2,如:test_image/1.jpg 0 2
:param filename:
:param labels_num :labels個數
:return:images type->list
:return:labels type->list
'''
images=[]
labels=[]
with open(filename) as f:
for lines in f.readlines():
line=lines.rstrip().split(' ')
label=[]
for i in range(labels_num):
label.append(int(line[i+1]))
images.append(line[0])
labels.append(label)
return images,labels
def read_image(filename, resize_height, resize_width):
'''
讀取圖片數據,默認返回的是uint8,[0,255]
:param filename:
:param resize_height:
:param resize_width:
:return: 返回的圖片數據是uint8,[0,255]
'''
bgr_image = cv2.imread(filename)
if len(bgr_image.shape)==2:#若是灰度圖則轉為三通道
print("Warning:gray image",filename)
bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)#將BGR轉為RGB
# show_image(filename,rgb_image)
# rgb_image=Image.open(filename)
if resize_height>0 and resize_width>0:
rgb_image=cv2.resize(rgb_image,(resize_width,resize_height))
rgb_image=np.asanyarray(rgb_image)
# show_image("src resize image",image)
return rgb_image
def create_records(image_dir,file, record_txt_path, batchSize,resize_height, resize_width):
'''
實現將圖像原始數據,label,長,寬等信息保存為record文件
注意:讀取的圖像數據默認是uint8,再轉為tf的字符串型BytesList保存,解析請需要根據需要轉換類型
:param image_dir:原始圖像的目錄
:param file:輸入保存圖片信息的txt文件(image_dir+file構成圖片的路徑)
:param output_record_txt_dir:保存record文件的路徑
:param batchSize: 每batchSize個圖片保存一個*.tfrecords,避免單個文件過大
:param resize_height:
:param resize_width:
PS:當resize_height或者resize_width=0是,不執行resize
'''
if os.path.exists(record_txt_path):
os.remove(record_txt_path)
setname, ext = record_txt_path.split('.')
# 加載文件,僅獲取一個label
images_list, labels_list=load_labels_file(file,1)
sample_num = len(images_list)
# 打亂樣本的數據
# random.shuffle(labels_list)
batchNum = int(math.ceil(1.0 * sample_num / batchSize))
for i in range(batchNum):
start = i * batchSize
end = min((i + 1) * batchSize, sample_num)
batch_images = images_list[start:end]
batch_labels = labels_list[start:end]
# 逐個保存*.tfrecords文件
filename = setname + '{0}.tfrecords'.format(i)
print('save:%s' % (filename))
writer = tf.python_io.TFRecordWriter(filename)
for i, [image_name, labels] in enumerate(zip(batch_images, batch_labels)):
image_path=os.path.join(image_dir,batch_images[i])
if not os.path.exists(image_path):
print('Err:no image',image_path)
continue
image = read_image(image_path, resize_height, resize_width)
image_raw = image.tostring()
print('image_path=%s,shape:( %d, %d, %d)' % (image_path,image.shape[0], image.shape[1], image.shape[2]),'labels:',labels)
# 這里僅保存一個label,多label適當增加"'label': _int64_feature(label)"項
label=labels[0]
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw': _bytes_feature(image_raw),
'height': _int64_feature(image.shape[0]),
'width': _int64_feature(image.shape[1]),
'depth': _int64_feature(image.shape[2]),
'label': _int64_feature(label)
}))
writer.write(example.SerializeToString())
writer.close()
# 用txt保存*.tfrecords文件列表
# record_list='{}.txt'.format(setname)
with open(record_txt_path, 'a') as f:
f.write(filename + '\n')
def read_records(filename,resize_height, resize_width):
'''
解析record文件
:param filename:保存*.tfrecords文件的txt文件路徑
:return:
'''
# 讀取txt中所有*.tfrecords文件
with open(filename, 'r') as f:
lines = f.readlines()
files_list=[]
for line in lines:
files_list.append(line.rstrip())
# 創建文件隊列,不限讀取的數量
filename_queue = tf.train.string_input_producer(files_list,shuffle=False)
# create a reader from file queue
reader = tf.TFRecordReader()
# reader從文件隊列中讀入一個序列化的樣本
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
# 解析符號化的樣本
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64)
}
)
tf_image = tf.decode_raw(features['image_raw'], tf.uint8)#獲得圖像原始的數據
tf_height = features['height']
tf_width = features['width']
tf_depth = features['depth']
tf_label = tf.cast(features['label'], tf.int32)
# tf_image=tf.reshape(tf_image, [-1]) # 轉換為行向量
tf_image=tf.reshape(tf_image, [resize_height, resize_width, 3]) # 設置圖像的維度
# 存儲的圖像類型為uint8,這里需要將類型轉為tf.float32
# tf_image = tf.cast(tf_image, tf.float32)
# [1]若需要歸一化請使用:
tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)# 歸一化
# tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) # 歸一化
# [2]若需要歸一化,且中心化,假設均值為0.5,請使用:
# tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 #中心化
return tf_image, tf_height,tf_width,tf_depth,tf_label
def disp_records(record_file,resize_height, resize_width,show_nums=4):
'''
解析record文件,並顯示show_nums張圖片,主要用於驗證生成record文件是否成功
:param tfrecord_file: record文件路徑
:param resize_height:
:param resize_width:
:param show_nums: 默認顯示前四張照片
:return:
'''
tf_image, tf_height, tf_width, tf_depth, tf_label = read_records(record_file,resize_height, resize_width) # 讀取函數
# 顯示前show_nums個圖片
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(show_nums):
image,height,width,depth,label = sess.run([tf_image,tf_height,tf_width,tf_depth,tf_label]) # 在會話中取出image和label
# image = tf_image.eval()
# 直接從record解析的image是一個向量,需要reshape顯示
# image = image.reshape([height,width,depth])
print('shape:',image.shape,'label:',label)
# pilimg = Image.fromarray(np.asarray(image_eval_reshape))
# pilimg.show()
show_image("image:%d"%(label),image)
coord.request_stop()
coord.join(threads)
def batch_test(record_file,resize_height, resize_width):
'''
:param record_file: record文件路徑
:param resize_height:
:param resize_width:
:return:
:PS:image_batch, label_batch一般作為網絡的輸入
'''
tf_image,tf_height,tf_width,tf_depth,tf_label = read_records(record_file,resize_height, resize_width) # 讀取函數
# 使用shuffle_batch可以隨機打亂輸入:
# shuffle_batch用法:https://blog.csdn.net/u013555719/article/details/77679964
min_after_dequeue = 100#該值越大,數據越亂,必須小於capacity
batch_size = 4
# capacity = (min_after_dequeue + (num_threads + a small safety margin∗batchsize)
capacity = min_after_dequeue + 3 * batch_size#容量:一個整數,隊列中的最大的元素數
image_batch, label_batch = tf.train.shuffle_batch([tf_image, tf_label],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
init = tf.global_variables_initializer()
with tf.Session() as sess: # 開始一個會話
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(4):
# 在會話中取出images和labels
images, labels = sess.run([image_batch, label_batch])
# 這里僅顯示每個batch里第一張圖片
show_image("image", images[0, :, :, :])
print(images.shape, labels)
# 停止所有線程
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
# 參數設置
image_dir='dataset/train'
train_file = 'dataset/train.txt' # 圖片路徑
output_record_txt = 'dataset/record/record.txt'#指定保存record的文件列表
resize_height = 224 # 指定存儲圖片高度
resize_width = 224 # 指定存儲圖片寬度
batchSize=8000 #batchSize一般設置為8000,即每batchSize張照片保存為一個record文件
# 產生record文件
create_records(image_dir=image_dir,
file=train_file,
record_txt_path=output_record_txt,
batchSize=batchSize,
resize_height=resize_height,
resize_width=resize_width)
# 測試顯示函數
disp_records(output_record_txt,resize_height, resize_width)
# batch_test(output_record_txt,resize_height, resize_width)
7,直接讀取文件的方式
之前,我們都是將數據轉存為tfrecord文件,訓練時候再去讀取,如果不想轉為record文件,想直接讀取圖像文件進行訓練,可以使用下面的方法:
filename.txt
0.jpg 0 1.jpg 0 2.jpg 0 3.jpg 0 4.jpg 0 5.jpg 1 6.jpg 1 7.jpg 1 8.jpg 1 9.jpg 1
代碼如下:
# -*-coding: utf-8 -*-
import tensorflow as tf
import glob
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
def show_image(title, image):
'''
顯示圖片
:param title: 圖像標題
:param image: 圖像的數據
:return:
'''
# plt.imshow(image, cmap='gray')
plt.imshow(image)
plt.axis('on') # 關掉坐標軸為 off
plt.title(title) # 圖像題目
plt.show()
def tf_read_image(filename, resize_height, resize_width):
'''
讀取圖片
:param filename:
:param resize_height:
:param resize_width:
:return:
'''
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
# tf_image = tf.cast(image_decoded, tf.float32)
tf_image = tf.cast(image_decoded, tf.float32) * (1. / 255.0) # 歸一化
if resize_width>0 and resize_height>0:
tf_image = tf.image.resize_images(tf_image, [resize_height, resize_width])
# tf_image = tf.image.per_image_standardization(tf_image) # 標准化[0,1](減均值除方差)
return tf_image
def get_batch_images(image_list, label_list, batch_size, labels_nums, resize_height, resize_width, one_hot=False, shuffle=False):
'''
:param image_list:圖像
:param label_list:標簽
:param batch_size:
:param labels_nums:標簽個數
:param one_hot:是否將labels轉為one_hot的形式
:param shuffle:是否打亂順序,一般train時shuffle=True,驗證時shuffle=False
:return:返回batch的images和labels
'''
# 生成隊列
image_que, tf_label = tf.train.slice_input_producer([image_list, label_list], shuffle=shuffle)
tf_image = tf_read_image(image_que, resize_height, resize_width)
min_after_dequeue = 200
capacity = min_after_dequeue + 3 * batch_size # 保證capacity必須大於min_after_dequeue參數值
if shuffle:
images_batch, labels_batch = tf.train.shuffle_batch([tf_image, tf_label],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
else:
images_batch, labels_batch = tf.train.batch([tf_image, tf_label],
batch_size=batch_size,
capacity=capacity)
if one_hot:
labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
return images_batch, labels_batch
def load_image_labels(filename):
'''
載圖txt文件,文件中每行為一個圖片信息,且以空格隔開:圖像路徑 標簽1,如:test_image/1.jpg 0
:param filename:
:return:
'''
images_list = []
labels_list = []
with open(filename) as f:
lines = f.readlines()
for line in lines:
# rstrip:用來去除結尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、制表符、空格)
content = line.rstrip().split(' ')
name = content[0]
labels = []
for value in content[1:]:
labels.append(int(value))
images_list.append(name)
labels_list.append(labels)
return images_list, labels_list
def batch_test(filename, image_dir):
labels_nums = 2
batch_size = 4
resize_height = 200
resize_width = 200
image_list, label_list = load_image_labels(filename)
image_list=[os.path.join(image_dir,image_name) for image_name in image_list]
image_batch, labels_batch = get_batch_images(image_list=image_list,
label_list=label_list,
batch_size=batch_size,
labels_nums=labels_nums,
resize_height=resize_height, resize_width=resize_width,
one_hot=False, shuffle=True)
with tf.Session() as sess: # 開始一個會話
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(4):
# 在會話中取出images和labels
images, labels = sess.run([image_batch, labels_batch])
# 這里僅顯示每個batch里第一張圖片
show_image("image", images[0, :, :, :])
print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))
# 停止所有線程
coord.request_stop()
coord.join(threads)
if __name__ == "__main__":
image_dir = "./dataset/train"
filename = "./dataset/train.txt"
batch_test(filename, image_dir)
8,數據輸入管道:pipeline機制解釋如下:
TensorFlow引入了tf.data.Dataset模塊,使其數據讀入的操作變得更為方便,而支持多線程(進程)的操作,也在效率上獲得了一定程度的提高。使用tf.data.Dataset模塊的pipline機制,可實現CPU多線程處理輸入的數據,如讀取圖片和圖片的一些的預處理,這樣GPU可以專注於訓練過程,而CPU去准備數據。
參考資料:
https://blog.csdn.net/u014061630/article/details/80776975 (五星推薦)TensorFlow全新的數據讀取方式:Dataset API入門教程:http://baijiahao.baidu.com/s?id=1583657817436843385&wfr=spider&for=pc
從tfrecord文件創建TFRecordDataset方式如下:
# 用dataset讀取TFRecords文件 dataset = tf.contrib.data.TFRecordDataset(input_file)
解析tfrecord 文件的每條記錄,即序列化后的 tf.train.Example;使用 tf.parse_single_example 來解析:
feats = tf.parse_single_example(serial_exmp, features=data_dict)
其中,data_dict 是一個dict,包含的key 是寫入tfrecord文件時用的key ,相應的value是對應不同的數據類型,我們直接使用代碼看,如下:
def _parse_record(example_photo):
features = {
'name': tf.FixedLenFeature((), tf.string),
'shape': tf.FixedLenFeature([3], tf.int64),
'data': tf.FixedLenFeature((), tf.string)
}
parsed_features = tf.parse_single_example(example_photo,features=features)
return parsed_features
解析tfrecord文件中的所有記錄,我們需要使用dataset 的map 方法,如下:
dataset = dataset.map(_parse_record)
Dataset支持一類特殊的操作:Transformation。一個Dataset通過Transformation變成一個新的Dataset。通常我們可以通過Transformation完成數據變換,打亂,組成batch,生成epoch等一系列操作。常用的Transformation有:map、batch、shuffle和repeat。
map方法可以接受任意函數對dataset中的數據進行處理;另外可以使用repeat,shuffle,batch方法對dataset進行重復,混洗,分批;用repeat賦值dataset以進行多個epoch;如下:
dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
解析完數據后,便可以取出數據進行使用,通過創建iterator來進行,如下:
iterator = dataset.make_one_shot_iterator() features = sess.run(iterator.get_next())
下面分別介紹
8.1,map
使用 tf.data.Dataset.map,我們可以很方便地對數據集中的各個元素進行預處理。因為輸入元素之間時獨立的,所以可以在多個 CPU 核心上並行地進行預處理。map 變換提供了一個 num_parallel_calls參數去指定並行的級別。
dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)
8.2,prefetch
tf.data.Dataset.prefetch 提供了 software pipelining 機制。該函數解耦了 數據產生的時間 和 數據消耗的時間。具體來說,該函數有一個后台線程和一個內部緩存區,在數據被請求前,就從 dataset 中預加載一些數據(進一步提高性能)。prefech(n) 一般作為最后一個 transformation,其中 n 為 batch_size。 prefetch 的使用方法如下:
dataset = dataset.batch(batch_size=FLAGS.batch_size) dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size) # last transformation return dataset
8.3,repeat
repeat的功能就是將整個序列重復多次,主要用來處理機器學習中的epoch,假設原先的數據是一個epoch,使用repeat(5)就可以將之變成5個epoch:
如果直接調用repeat()的話,生成的序列就會無限重復下去,沒有結束,因此也不會拋出tf.errors.OutOfRangeError異常
8.4,完整代碼如下:
# -*-coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import glob
import matplotlib.pyplot as plt
width=0
height=0
def show_image(title, image):
'''
顯示圖片
:param title: 圖像標題
:param image: 圖像的數據
:return:
'''
# plt.figure("show_image")
# print(image.dtype)
plt.imshow(image)
plt.axis('on') # 關掉坐標軸為 off
plt.title(title) # 圖像題目
plt.show()
def tf_read_image(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
image = tf.cast(image_decoded, tf.float32)
if width>0 and height>0:
image = tf.image.resize_images(image, [height, width])
image = tf.cast(image, tf.float32) * (1. / 255.0) # 歸一化
return image, label
def input_fun(files_list, labels_list, batch_size, shuffle=True):
'''
:param files_list:
:param labels_list:
:param batch_size:
:param shuffle:
:return:
'''
# 構建數據集
dataset = tf.data.Dataset.from_tensor_slices((files_list, labels_list))
if shuffle:
dataset = dataset.shuffle(100)
dataset = dataset.repeat() # 空為無限循環
dataset = dataset.map(tf_read_image, num_parallel_calls=4) # num_parallel_calls一般設置為cpu內核數量
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2) # software pipelining 機制
return dataset
if __name__ == '__main__':
data_dir = 'dataset/image/*.jpg'
# labels_list = tf.constant([0,1,2,3,4])
# labels_list = [1, 2, 3, 4, 5]
files_list = glob.glob(data_dir)
labels_list = np.arange(len(files_list))
num_sample = len(files_list)
batch_size = 1
dataset = input_fun(files_list, labels_list, batch_size=batch_size, shuffle=False)
# 需滿足:max_iterate*batch_size <=num_sample*num_epoch,否則越界
max_iterate = 3
with tf.Session() as sess:
iterator = dataset.make_initializable_iterator()
init_op = iterator.make_initializer(dataset)
sess.run(init_op)
iterator = iterator.get_next()
for i in range(max_iterate):
images, labels = sess.run(iterator)
show_image("image", images[0, :, :, :])
print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))
9,AttributeError: module 'tensorflow' has no attribute 'data' 解決方法
當我們使用tf 中的 dataset時,可能會出現如下錯誤:

原因是tf 版本不同導致的錯誤。
在編寫代碼的時候,使用的tf版本不同,可能導致其Dataset API 放置的位置不同。當使用TensorFlow1.3的時候,Dataset API是放在 contrib 包里面,而當使用TensorFlow1.4以后的版本,Dataset API已經從contrib 包中移除了,而變成了核心API的一員。故會產生報錯。
解決方法:
將下面代碼:
# 用dataset讀取TFRecords文件 dataset = tf.data.TFRecordDataset(input_file)
改為此代碼:
# 用dataset讀取TFRecords文件 dataset = tf.contrib.data.TFRecordDataset(input_file)
問題解決。
10,tf.gfile.FastGfile()函數學習
函數如下:
tf.gfile.FastGFile(path,decodestyle)
函數功能:實現對圖片的讀取
函數參數:path:圖片所在路徑
decodestyle:圖片的解碼方式(‘r’:UTF-8編碼; ‘rb’:非UTF-8編碼)
例子如下:
img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
11,Python zip()函數學習
zip() 函數用於將可迭代的對象作為參數,將對象中對應的元素打包成一個個元組,然后返回由這些元組組成的列表。如果各個迭代器的元素個數不一致,則返回列表長度與最短的對象相同,利用*號操作符,可以將元組解壓為列表。
在 Python 3.x 中為了減少內存,zip() 返回的是一個對象。如需展示列表,需手動 list() 轉換。
zip([iterable, ...]) 參數說明: iterabl——一個或多個迭代器 返回值:返回元組列表
實例:
>>>a = [1,2,3] >>> b = [4,5,6] >>> c = [4,5,6,7,8] >>> zipped = zip(a,b) # 打包為元組的列表 [(1, 4), (2, 5), (3, 6)] >>> zip(a,c) # 元素個數與最短的列表一致 [(1, 4), (2, 5), (3, 6)] >>> zip(*zipped) # 與 zip 相反,*zipped 可理解為解壓,返回二維矩陣式 [(1, 2, 3), (4, 5, 6)]
12,下一步計划
1,為什么前面使用Dataset,而用大多數博文中的 QueueRunner 呢?
A:這是因為 Dataset 比 QueueRunner 新,而且是官方推薦的,Dataset 比較簡單。
2,學習了 TFRecord 相關知識,下一步學習什么?
A:可以嘗試將常見的數據集如 MNIST 和 CIFAR-10 轉換成 TFRecord 格式。
參考文獻:https://blog.csdn.net/u012759136/article/details/52232266
https://blog.csdn.net/tengxing007/article/details/56847828/
https://blog.csdn.net/briblue/article/details/80789608 (五星推薦)
https://blog.csdn.net/happyhorizion/article/details/77894055 (五星推薦)
