將圖片數據轉化為TFRecord格式與讀取


將圖片數據轉化為TFRecord格式與讀取

一、問題情景描述

  目錄下有一個叫做“Original”的文件夾,文件夾里有十個子文件,分別命名為1,2···一直到10(為了做10輪取平均),這10個子文件夾里還有四個子文件夾,分別命名為“train0”,"train1","test0","test1"。其中含義如其命名所示。這四個子文件夾里一共有若干張JPG格式圖像數據。現欲將這份圖像數據轉化為TFRecord格式,用來做CNN訓練。

二、實現代碼

# 導入相關庫
import os import tensorflow as tf import numpy as np from PIL import Image

轉為為TFRecord格式代碼:

for i in range(1, 11):             # 用來表示文件夾1到10
 cwd = 'Original/'+str(i)+'/'                           # 第i個文件夾路徑
    path_tfrecord = 'Original_tfrecord/'+str(i)+'/'        # tfrecord文件路徑
    
    if not os.path.exists(path_tfrecord): os.makedirs(path_tfrecord) print(path_tfrecord+" 開始轉換") else: print(path_tfrecord+" 開始轉換") #f = open(path_tfrecord+'fileQueue', 'w') # 用寫的方式打開fileQueue這個文件,並賦給f
    with open(path_tfrecord+'fileQueue', 'w') as f: # 創建一個writer來寫 TFRecords 文件
        writer1 = tf.python_io.TFRecordWriter(path_tfrecord+"train.tfrecords") writer2 = tf.python_io.TFRecordWriter(path_tfrecord+"test.tfrecords") class_path1 = cwd + 'train0' + '/' class_path2 = cwd + 'train1' + '/' class_path3 = cwd + 'test0' + '/' class_path4 = cwd + 'test1' + '/'

        # os.listdir返回指定的文件夾包含的文件或文件夾的名字的列表,它不包括 '.' 和'..'
        for img in os.listdir(class_path1): # print(img)
            f.writelines(img + 'train0' + '\n') img_path = class_path1 + img     # 每張圖片的地址
            # 讀取img文件
            img_raw = Image.open(img_path).convert('L') img_raw = img_raw.resize((28, 28))     # 轉換圖片大小
            img_raw_new = img_raw.tobytes()       # 將圖片轉化為原生bytes
            
            # tf.train.Example來定義我們要填入的數據格式,然后使用tf.python_io.TFRecordWriter來寫入
            example = tf.train.Example( # 一個Example中包含Features,Features里包含Feature(這里沒s)的字典。最后,Feature里包含有一個 FloatList, 
                # 或者ByteList,或者Int64List
                features=tf.train.Features( feature={ # example對象對label和image數據進行封裝
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw_new]))})) writer1.write(example.SerializeToString()) # 序列化為字符串

        for img in os.listdir(class_path2): # print(img)
            f.writelines(img + 'train1' + '\n') img_path = class_path2 + img img_raw = Image.open(img_path).convert('L') img_raw = img_raw.resize((28, 28)) img_raw_new = img_raw.tobytes() example = tf.train.Example( features=tf.train.Features( feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])), "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw_new]))})) writer1.write(example.SerializeToString()) writer1.close() for img in os.listdir(class_path3): # print(img)
            f.writelines(img + 'test0' + '\n') img_path = class_path3 + img img_raw = Image.open(img_path).convert('L') img_raw = img_raw.resize((28, 28)) img_raw_new = img_raw.tobytes() example = tf.train.Example( features=tf.train.Features( feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw_new]))})) writer2.write(example.SerializeToString()) for img in os.listdir(class_path4): # print(img)
            f.writelines(img + 'test1' + '\n') img_path = class_path4 + img img_raw = Image.open(img_path).convert('L') img_raw = img_raw.resize((28, 28)) img_raw_new = img_raw.tobytes() example = tf.train.Example( features=tf.train.Features( feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])), "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw_new]))})) writer2.write(example.SerializeToString()) writer2.close() #f.close()
    print("結束")

定義解析TFRecord函數

# 生成了TFRecords文件,接下來就可以使用隊列(queue)讀取數據了
def read_and_decode11(filename): filename_queue = tf.train.string_input_producer([filename])  # 根據文件名生成一個隊列
    reader = tf.TFRecordReader()                                 # 定義一個 reader ,讀取下一個 record
    _, serialized_example = reader.read(filename_queue) # 解析讀入的一個record
    features = tf.parse_single_example( serialized_example, features={"label": tf.FixedLenFeature([], tf.int64), "img_raw": tf.FixedLenFeature([], tf.string)}) img = tf.decode_raw(features["img_raw"], np.int8)           # 將字符串解析成圖像對應的像素組
    img = tf.reshape(img, [28 * 28 * 1]) # img = tf.reshape(img,[28,28,1])
    img = tf.cast(img, tf.float32) * (1. / 255) label = tf.cast(features["label"], tf.int32) return img, label

調用上述函數即可在代碼中使用解析好的數據

for i in range(1,11: path_tfrecord = 'Original_tfrecord/'+str(i)+'/' img_train, label_train = read_and_decode11(path_tfrecord+"train.tfrecords") img_test, label_test = read_and_decode11(path_tfrecord+"test.tfrecords") label_train = tf.one_hot(indices=tf.cast(label1, tf.int32), depth=2)  # 將一個值化為一個概率分布的向量
    label_test = tf.one_hot(indices=tf.cast(label2, tf.int32), depth=2) # 隨機打亂生成batch
    img_batch_train, label_batch_train = tf.train.shuffle_batch([img_train, label_train], batch_size=64, capacity=1000, min_after_dequeue=500) img_batch_test, label_batch_test = tf.train.shuffle_batch([img_test, label_test], batch_size=13, capacity=13, min_after_dequeue=0)

三、未完待續

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM