tensorflow二進制文件讀取與tfrecords文件讀取


1、知識點

"""
TFRecords介紹:
    TFRecords是Tensorflow設計的一種內置文件格式,是一種二進制文件,它能更好的利用內存,
    更方便復制和移動,為了將二進制數據和標簽(訓練的類別標簽)數據存儲在同一個文件中

CIFAR-10批處理結果存入tfrecords流程:
    1、構造存儲器
         a)TFRecord存儲器API:tf.python_io.TFRecordWriter(path) 寫入tfrecords文件
            參數:   
                path: TFRecords文件的路徑
                return:寫文件
            方法:
                write(record):向文件中寫入一個字符串記錄
                    record:字符串為一個序列化的Example,Example.SerializeToString()
                close():關閉文件寫入器

    2、構造每一個樣本的Example協議塊
         a)tf.train.Example(features=None)寫入tfrecords文件
                features:tf.train.Features類型的特征實例
                return:example格式協議塊

         b)tf.train.Features(feature=None)構建每個樣本的信息鍵值對
                feature:字典數據,key為要保存的名字,
                value為tf.train.Feature實例
                return:Features類型

         c)tf.train.Feature(**options)
                **options:例如
                    bytes_list=tf.train.BytesList(value=[Bytes])
                    int64_list=tf.train.Int64List(value=[Value])
                數據類型:
                    tf.train.Int64List(value=[Value])
                    tf.train.BytesList(value=[Bytes]) 
                    tf.train.FloatList(value=[value]) 

    3、寫入序列化的Example
         writer.write(example.SerializeToString())
   
報錯: 
        1、ValueError: Protocol message Feature has no "Bytes_list" field.
                因為沒有Bytes_list屬性字段,只有bytes_list字段
                
讀取tfrecords流程:
    1、構建文件隊列
        file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
    2、構造TFRecords閱讀器
        reader = tf.TFRecordReader()
    3、解析Example,獲取數據
        a) tf.parse_single_example(serialized,features=None,name=None)解析TFRecords的example協議內存塊
            serialized:標量字符串Tensor,一個序列化的Example
            features:dict字典數據,鍵為讀取的名字,值為FixedLenFeature
            return:一個鍵值對組成的字典,鍵為讀取的名字
        b)tf.FixedLenFeature(shape,dtype) 類型只能是float32,int64,string
            shape:輸入數據的形狀,一般不指定,為空列表
            dtype:輸入數據類型,與存儲進文件的類型要一致     
    4、轉換格式,bytes解碼
        image = tf.decode_raw(features["image"],tf.uint8)
        #固定圖像大小,有利於批處理操作
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
        label = tf.cast(features["label"],tf.int32)
    5、批處理
        image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)

報錯:
    1、ValueError: Shape () must have rank at least 1
        
"""

2、代碼

# coding = utf-8
import tensorflow as tf
import  os

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("cifar_dir","./cifar10/", "文件的目錄")
tf.app.flags.DEFINE_string("cifar_tfrecords", "./tfrecords/cifar.tfrecords", "存進tfrecords的文件")
class CifarRead(object):
    """
    完成讀取二進制文件,寫進tfrecords,讀取tfrecords
    """
    def __init__(self,file_list):
        self.file_list = file_list
        #圖片屬性
        self.height = 32
        self.width = 32
        self.channel = 3

        #二進制字節
        self.label_bytes = 1
        self.image_bytes = self.height*self.width*self.channel
        self.bytes = self.label_bytes + self.image_bytes


    def read_and_encode(self):
        """
        讀取二進制文件,並進行解碼操作
        :return:
        """
        #1、創建文件隊列
        file_quque = tf.train.string_input_producer(self.file_list)
        #2、創建閱讀器,讀取二進制文件
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_quque)#key為文件名,value為文件內容
        #3、解碼操作
        label_image = tf.decode_raw(value,tf.uint8)

        #分割圖片和標簽數據, tf.cast(),數據類型轉換   tf.slice()tensor數據進行切片
        label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])

        #對圖像進行形狀改變
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])

        # 4、批處理操作
        image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)
        print(image_batch,label_batch)
        return image_batch,label_batch

    def write_ro_tfrecords(self,image_batch,label_batch):
        """
        將讀取的二進制文件寫入 tfrecords文件中
        :param image_batch: 圖像 (32,32,3)
        :param label_batch: 標簽
        :return:
        """
        # 1、構造存儲器
        writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)

        #循環寫入
        for i in range(5):
            image = image_batch[i].eval().tostring()
            label = int(label_batch[i].eval()[0])
            # 2、構造每一個樣本的Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])) ,
                "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label])) ,
            }))

            # 3、寫入序列化的Example
            writer.write(example.SerializeToString())

        #關閉流
        writer.close()
        return None

    def read_from_tfrecords(self):
        """
        從tfrecords文件讀取數據
        :return:
        """
        #1、構建文件隊列
        file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
        #2、構造TFRecords閱讀器
        reader = tf.TFRecordReader()
        key , value = reader.read(file_queue)
        #3、解析Example
        features = tf.parse_single_example(value,features={
            "image":tf.FixedLenFeature([],tf.string),
            "label":tf.FixedLenFeature([],tf.int64)
        })
        #4、解碼內容, 如果讀取的內容格式是string需要解碼, 如果是int64,float32不需要解碼
        image = tf.decode_raw(features["image"],tf.uint8)
        #固定圖像大小,有利於批處理操作
        image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
        label = tf.cast(features["label"],tf.int32)

        #5、批處理
        image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20)
        return image_batch,label_batch


if __name__ == '__main__':
    #################二進制文件讀取###############
    # file_name = os.listdir(FLAGS.cifar_dir)
    # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
    # cf = CifarRead(file_list)
    # image_batch, label_batch = cf.read_and_encode()
    # with tf.Session() as sess:
    #     # 創建協調器
    #     coord = tf.train.Coordinator()
    #     # 開啟線程
    #     threads = tf.train.start_queue_runners(sess, coord=coord)
    #
    #     print(sess.run([image_batch, label_batch]))
    #     # 回收線程
    #     coord.request_stop()
    #     coord.join(threads)
    #############################################

    #####二進制文件讀取,並寫入tfrecords文件######
    # file_name = os.listdir(FLAGS.cifar_dir)
    # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
    # cf = CifarRead(file_list)
    # image_batch, label_batch = cf.read_and_encode()
    # with tf.Session() as sess:
    #     # 創建協調器
    #     coord = tf.train.Coordinator()
    #     # 開啟線程
    #     threads = tf.train.start_queue_runners(sess, coord=coord)
    #     #########保存文件到tfrecords##########
    #     cf.write_ro_tfrecords(image_batch, label_batch)
    #     #########保存文件到tfrecords##########
    #
    #     print(sess.run([image_batch, label_batch]))
    #     # 回收線程
    #     coord.request_stop()
    #     coord.join(threads)
    ##############################################

    #############從tfrecords文件讀取###############
    file_name = os.listdir(FLAGS.cifar_dir)
    file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
    cf = CifarRead(file_list)
    image_batch, label_batch = cf.read_from_tfrecords()
    with tf.Session() as sess:
        # 創建協調器
        coord = tf.train.Coordinator()
        # 開啟線程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        print(sess.run([image_batch, label_batch]))
        # 回收線程
        coord.request_stop()
        coord.join(threads)
    ##############################################

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM