關於Tensorflow讀取數據,官網給出了三種方法:
- 供給數據(Feeding): 在TensorFlow程序運行的每一步, 讓Python代碼來供給數據。
- 從文件讀取數據: 在TensorFlow圖的起始, 讓一個輸入管線從文件中讀取數據。
- 預加載數據: 在TensorFlow圖中定義常量或變量來保存所有數據(僅適用於數據量比較小的情況)。
在使用Tensorflow訓練數據時,第一步為准備數據,現在我們只討論圖像數據。其數據讀取大致分為:原圖讀取、二進制文件讀取、tf標准存儲文件讀取。
一、原圖文件讀取
很多情況下我們的圖片訓練集就是原始圖片本身,並沒有像cifar dataset那樣存成bin等格式。我們需要根據一個train_list列表,去挨個讀取圖片。這里我用到的方法是首先獲取image list和labellist,然后讀入隊列中,那么對每次dequeue的內容中可以提取當前圖片的路勁和label。
1、獲取文件列表
def get_image_list(fileDir): imageList = [] labelList = [] filelist = os.listdir(fileDir) for var in filelist: imagename = os.path.join(fileDir, var) label = int(os.path.basename(var).split('_')[0]) imageList.append(imagename) labelList.append(label) return imageList, labelList
上述程序是從指定目錄中獲取文件列表和標簽,其我的文件為
總共15個文件,’_’前為文件標簽,記得要轉化為int類型,否則后面程序或報錯。
2、將文件列表加載到內存列表中,並進行讀取
步驟分為:
a、列表轉化為tensor類型,並存到內存中
b、 從內存列表中讀取數據,進行獲取圖像和label
c、 根據訓練要求對數據進行轉化
d、利用batch獲取批次文件
def input_data_imageslist_slice(fileDir): # 獲取文件列表 imageList , labelList = get_image_list(fileDir) # 將文件列表和標簽列表轉為為tensor,進而能存入內存列表中,記得label在上面轉為int,否則下面會出錯,這是相對應的 imagesTensor = tf.convert_to_tensor(imageList, dtype = tf.string) labelsTensor = tf.convert_to_tensor(labelList, dtype = tf.uint8) # 從內存列表中讀取文件,此處只讀取一個文件,並記錄文件位置 queue = tf.train.slice_input_producer([imagesTensor, labelsTensor]) # 提取圖片內容和標簽內容,一定注意數據之間的轉化; image_content = tf.read_file(queue[0]) imageData = tf.image.decode_jpeg(image_content,channels=3) #channels必須要制定,當時沒指定,程序報錯 imageData = tf.image.convert_image_dtype(imageData,tf.uint8) # 圖片數據進行轉化,此處為了顯示而轉化 labelData = tf.cast(queue[1],tf.uint8) # show_single_data(imageData, labelData) #根據數據訓練尺寸,調整圖片大小,此處設置為32*32 new_size = tf.constant([IMAGE_WIDTH,IMAGE_WIDTH], dtype=tf.int32) image = tf.image.resize_images(imageData, new_size) # 這是數據提取關鍵,因為設置了batch_size,決定了每次提取數據的個數,比如此處是3,則每次為3個文件 imageBatch, labelBatch = tf.train.shuffle_batch([image, labelData], batch_size = BATCH_SIZE, capacity = 2000,min_after_dequeue = 1000) return imageBatch, labelBatch
3、文件測試
在文件測試中,必須添加 threads = tf.train.start_queue_runners(sess = sess),會話窗口才會從內存堆棧中讀取數據。
def test_record(filename): image_batch, label_batch = input_data_imageslist_slice(filename) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) threads = tf.train.start_queue_runners(sess = sess) for i in range(5): val, label = sess.run([image_batch, label_batch]) print(val.shape, label)
我在程序中設置了batchsize為3,所以shape為(3,32,32,)后面則是label,此處很好的讀取了數據
(3, 32, 32, 3) [12 3 22] (3, 32, 32, 3) [ 7 2 22] (3, 32, 32, 3) [ 2 5 15] (3, 32, 32, 3) [12 5 13] (3, 32, 32, 3) [12 6 10]
4、校驗對比
為了更好的得知shuffle_batch是否讓文件和label對應,程序中進行了修改
image = tf.image.resize_images(imageData, new_size)
修改為:
image = tf.cast(queue[0],tf.string)
還有
print(val.shape, label)
print(val, label)
結果為:
[b'E:\\010_test_tensorflow\\02_produce_data\\images1\\1_1.jpg' b'E:\\010_test_tensorflow\\02_produce_data\\images1\\12_03.jpg' b'E:\\010_test_tensorflow\\02_produce_data\\images1\\10_2.jpg'] [ 1 12 10] [b'E:\\010_test_tensorflow\\02_produce_data\\images1\\2_2.jpg' b'E:\\010_test_tensorflow\\02_produce_data\\images1\\22_9.jpg' b'E:\\010_test_tensorflow\\02_produce_data\\images1\\33_0.jpg'] [ 2 22 33] [b'E:\\010_test_tensorflow\\02_produce_data\\images1\\33_0.jpg' b'E:\\010_test_tensorflow\\02_produce_data\\images1\\7_0.jpg' b'E:\\010_test_tensorflow\\02_produce_data\\images1\\13_06.jpg'] [33 7 13]
我們發現不但image和label相對應,而且還打亂了順序,真的是很完美啊。
二、TFRecords讀取
對於數據量較小而言,可能一般選擇直接將數據加載進內存,然后再分batch
輸入網絡進行訓練(tip:使用這種方法時,結合yield
使用更為簡潔,大家自己嘗試一下吧,我就不贅述了)。但是,如果數據量較大,這樣的方法就不適用了,因為太耗內存,所以這時最好使用tensorflow提供的隊列queue
,也就是第二種方法 從文件讀取數據。對於一些特定的讀取,比如csv文件格式,官網有相關的描述,在這兒我介紹一種比較通用,高效的讀取方法(官網介紹的少),即使用tensorflow內定標准格式——TFRecords。
FRecords其實是一種二進制文件,雖然它不如其他格式好理解,但是它能更好的利用內存,更方便復制和移動,並且不需要單獨的標簽文件。
TFRecords文件包含了tf.train.Example
協議內存塊(protocol buffer)(協議內存塊包含了字段 Features
)。我們可以寫一段代碼獲取你的數據, 將數據填入到Example
協議內存塊(protocol buffer),將協議內存塊序列化為一個字符串, 並且通過tf.python_io.TFRecordWriter
寫入到TFRecords文件。
從TFRecords文件中讀取數據, 可以使用tf.TFRecordReader
的tf.parse_single_example
解析器。這個操作可以將Example
協議內存塊(protocol buffer)解析為張量。
2.1 生成TFRecords文件
class SaveRecord(object): def __init__(self,recordDir, fileDir, imageSize): self._imageSize = imageSize trainRecord = os.path.join(recordDir,'train.tfrecord') validRecord = os.path.join(recordDir,'valid.tfrecord') # 獲取文件列表 filenames = os.listdir(fileDir) np.random.shuffle(filenames) fileNum = len(filenames) print('the count of images is ' + str(fileNum)) # 獲取訓練和測試樣本,比例為4:1 splitNum = int(fileNum * 0.8) trainImages = filenames[ : splitNum] validImages = filenames[splitNum : ] # 保存數據到制定位置 self.save_data_to_record( fileDir = fileDir, datas = trainImages, recordname = trainRecord) self.save_data_to_record(fileDir = fileDir,datas = validImages, recordname = validRecord) def save_data_to_record(self,fileDir, datas, recordname): writer = tf.python_io.TFRecordWriter(recordname) for var in datas: filename = os.path.join(fileDir, var) label = int(os.path.basename(var).split('_')[0]) image = Image.open(filename) # 打開圖片 image = image.resize((self._imageSize,self._imageSize)) imageArray = image.tobytes() #轉為bytes example = tf.train.Example(features = tf.train.Features(feature = { 'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [imageArray])) ,'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))})) writer.write(example.SerializeToString()) writer.close()
編程為一個類,其核心代碼在於save_data_to_records,其主要流程為:
- 初始化寫入器writer,來源於tf.python_io.TFRecordWriter。
- 遍歷傳入的數據,可以為文件名,意味后面二進制解析也是文件名
- 解析文件名,獲取label,這是之前處理好的
- 利用IPL的Image讀入圖像數據,預處理數據:調整大小,且轉為化二值化數據
- 利用tf中的Example中獲取數據,原理是利用字典對應關系,獲取features,當然里面有點繞,仔細讀讀全是在類型轉化而已
- example二進制化,然后寫入。
- 關閉寫入器
其中里面關鍵點:圖片bytes的轉化,以及example的賦值。
基本的,一個Example
中包含Features
,Features
里包含Feature
(這里沒s)的字典。最后,Feature
里包含有一個 FloatList
, 或者ByteList
,或者Int64List。
2.2 讀取record文件
for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"): example = tf.train.Example() example.ParseFromString(serialized_example) image = example.features.feature['image'].bytes_list.value label = example.features.feature['label'].int64_list.value # 可以做一些預處理之類的 print image, label
上面為一個解析文件的一個例子,主要是利用example直接進行解析,簡單,但是這樣比較耗內存,常用的方法是利用文件隊列讀取。
即利用string_input_produce,結合tf.recordreader進行數據讀取,最后進行解析,其例子為:
def read_and_decode(filename): #根據文件名生成一個隊列 filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized = serialized, features = { 'image' : tf.FixedLenFeature([], tf.string), 'label' : tf.FixedLenFeature([], tf.int64)}) image = tf.decode_raw(features['image'], tf.uint8) image= tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3]) # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 # image = tf.cast(features['image'], tf.string) label = tf.cast(features['label'], tf.int32) img_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=BATCH_SIZE, capacity=2000, min_after_dequeue=1000) return img_batch, label_batch
其流程為:
- record文件放入文件隊列中
- 初始化話RecordReader,發現這個reader和writer初始化方式不一樣
- 從隊列中讀取數據,記得返回兩個值,我們只要第二個,此時數據為二進制數據(前面我們存入的二進制數據)
- 根據約定解析數據,類型即為之前存儲的格式。
- 利用tf的轉化獲取image和label
- 關鍵一步就是tf.train.shuffle_batch,利用此函數可以批量獲取數據,當然是在文件列表中。
此處對文件列表中數據讀取的過程中,我們發現讀取器是不一樣的。比如此次是讀取record的內存文件,代碼為:
filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized = reader.read(filename_queue) #返回文件名和文件
而之前從文件列表和數據列表讀取的時候為:
imageList , labelList = get_image_list(fileDir) queue = tf.train.string_input_producer(imageList) reader = tf.WholeFileReader() _, image_content = reader.read(queue)
而我們使用的slice_input_produce的時候,變成了tf.read_file,一定記得各個的不同。
# 從內存列表中讀取文件,此處只讀取一個文件,並記錄文件位置 queue = tf.train.slice_input_producer([imagesTensor, labelsTensor]) # 提取圖片內容和標簽內容,一定注意數據之間的轉化; image_content = tf.read_file(queue[0]) imageData = tf.image.decode_jpeg(image_content,channels=3)
2.3 測試數據
前面我們解析了shuffle_batch的好處,此處我們即檢測是否讀取了數據。
def test_record(filename): image_batch, label_batch = read_and_decode(filename) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) threads = tf.train.start_queue_runners(sess = sess) for i in range(5): val, label = sess.run([image_batch, label_batch]) print(val.shape, label)
此時的輸出結果為:
(10, 16, 16, 3) [15 7 33 5 4 10 13 7 1 3] (10, 16, 16, 3) [33 22 10 10 4 12 7 4 13 10] (10, 16, 16, 3) [10 10 12 7 5 15 12 22 15 5] (10, 16, 16, 3) [10 10 5 1 12 10 3 5 33 3] (10, 16, 16, 3) [12 4 7 15 4 7 4 13 5 10]
結果表明有效的對數據進行了讀取。