1、知識點
""" TFRecords介紹: TFRecords是Tensorflow設計的一種內置文件格式,是一種二進制文件,它能更好的利用內存, 更方便復制和移動,為了將二進制數據和標簽(訓練的類別標簽)數據存儲在同一個文件中 CIFAR-10批處理結果存入tfrecords流程: 1、構造存儲器 a)TFRecord存儲器API:tf.python_io.TFRecordWriter(path) 寫入tfrecords文件 參數: path: TFRecords文件的路徑 return:寫文件 方法: write(record):向文件中寫入一個字符串記錄 record:字符串為一個序列化的Example,Example.SerializeToString() close():關閉文件寫入器 2、構造每一個樣本的Example協議塊 a)tf.train.Example(features=None)寫入tfrecords文件 features:tf.train.Features類型的特征實例 return:example格式協議塊 b)tf.train.Features(feature=None)構建每個樣本的信息鍵值對 feature:字典數據,key為要保存的名字, value為tf.train.Feature實例 return:Features類型 c)tf.train.Feature(**options) **options:例如 bytes_list=tf.train.BytesList(value=[Bytes]) int64_list=tf.train.Int64List(value=[Value]) 數據類型: tf.train.Int64List(value=[Value]) tf.train.BytesList(value=[Bytes]) tf.train.FloatList(value=[value]) 3、寫入序列化的Example writer.write(example.SerializeToString()) 報錯: 1、ValueError: Protocol message Feature has no "Bytes_list" field. 因為沒有Bytes_list屬性字段,只有bytes_list字段 讀取tfrecords流程: 1、構建文件隊列 file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords]) 2、構造TFRecords閱讀器 reader = tf.TFRecordReader() 3、解析Example,獲取數據 a) tf.parse_single_example(serialized,features=None,name=None)解析TFRecords的example協議內存塊 serialized:標量字符串Tensor,一個序列化的Example features:dict字典數據,鍵為讀取的名字,值為FixedLenFeature return:一個鍵值對組成的字典,鍵為讀取的名字 b)tf.FixedLenFeature(shape,dtype) 類型只能是float32,int64,string shape:輸入數據的形狀,一般不指定,為空列表 dtype:輸入數據類型,與存儲進文件的類型要一致 4、轉換格式,bytes解碼 image = tf.decode_raw(features["image"],tf.uint8) #固定圖像大小,有利於批處理操作 image_reshape = tf.reshape(image,[self.height,self.width,self.channel]) label = tf.cast(features["label"],tf.int32) 5、批處理 image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20) 報錯: 1、ValueError: Shape () must have rank at least 1 """
2、代碼
# coding = utf-8 import tensorflow as tf import os FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string("cifar_dir","./cifar10/", "文件的目錄") tf.app.flags.DEFINE_string("cifar_tfrecords", "./tfrecords/cifar.tfrecords", "存進tfrecords的文件") class CifarRead(object): """ 完成讀取二進制文件,寫進tfrecords,讀取tfrecords """ def __init__(self,file_list): self.file_list = file_list #圖片屬性 self.height = 32 self.width = 32 self.channel = 3 #二進制字節 self.label_bytes = 1 self.image_bytes = self.height*self.width*self.channel self.bytes = self.label_bytes + self.image_bytes def read_and_encode(self): """ 讀取二進制文件,並進行解碼操作 :return: """ #1、創建文件隊列 file_quque = tf.train.string_input_producer(self.file_list) #2、創建閱讀器,讀取二進制文件 reader = tf.FixedLengthRecordReader(self.bytes) key, value = reader.read(file_quque)#key為文件名,value為文件內容 #3、解碼操作 label_image = tf.decode_raw(value,tf.uint8) #分割圖片和標簽數據, tf.cast(),數據類型轉換 tf.slice()tensor數據進行切片 label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32) image = tf.slice(label_image,[self.label_bytes],[self.image_bytes]) #對圖像進行形狀改變 image_reshape = tf.reshape(image,[self.height,self.width,self.channel]) # 4、批處理操作 image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20) print(image_batch,label_batch) return image_batch,label_batch def write_ro_tfrecords(self,image_batch,label_batch): """ 將讀取的二進制文件寫入 tfrecords文件中 :param image_batch: 圖像 (32,32,3) :param label_batch: 標簽 :return: """ # 1、構造存儲器 writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords) #循環寫入 for i in range(5): image = image_batch[i].eval().tostring() label = int(label_batch[i].eval()[0]) # 2、構造每一個樣本的Example example = tf.train.Example(features=tf.train.Features(feature={ "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])) , "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label])) , })) # 3、寫入序列化的Example writer.write(example.SerializeToString()) #關閉流 writer.close() return None def read_from_tfrecords(self): """ 從tfrecords文件讀取數據 :return: """ #1、構建文件隊列 file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords]) #2、構造TFRecords閱讀器 reader = tf.TFRecordReader() key , value = reader.read(file_queue) #3、解析Example features = tf.parse_single_example(value,features={ "image":tf.FixedLenFeature([],tf.string), "label":tf.FixedLenFeature([],tf.int64) }) #4、解碼內容, 如果讀取的內容格式是string需要解碼, 如果是int64,float32不需要解碼 image = tf.decode_raw(features["image"],tf.uint8) #固定圖像大小,有利於批處理操作 image_reshape = tf.reshape(image,[self.height,self.width,self.channel]) label = tf.cast(features["label"],tf.int32) #5、批處理 image_batch , label_batch = tf.train.batch([image_reshape,label],batch_size=5,num_threads=1,capacity=20) return image_batch,label_batch if __name__ == '__main__': #################二進制文件讀取############### # file_name = os.listdir(FLAGS.cifar_dir) # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"] # cf = CifarRead(file_list) # image_batch, label_batch = cf.read_and_encode() # with tf.Session() as sess: # # 創建協調器 # coord = tf.train.Coordinator() # # 開啟線程 # threads = tf.train.start_queue_runners(sess, coord=coord) # # print(sess.run([image_batch, label_batch])) # # 回收線程 # coord.request_stop() # coord.join(threads) ############################################# #####二進制文件讀取,並寫入tfrecords文件###### # file_name = os.listdir(FLAGS.cifar_dir) # file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"] # cf = CifarRead(file_list) # image_batch, label_batch = cf.read_and_encode() # with tf.Session() as sess: # # 創建協調器 # coord = tf.train.Coordinator() # # 開啟線程 # threads = tf.train.start_queue_runners(sess, coord=coord) # #########保存文件到tfrecords########## # cf.write_ro_tfrecords(image_batch, label_batch) # #########保存文件到tfrecords########## # # print(sess.run([image_batch, label_batch])) # # 回收線程 # coord.request_stop() # coord.join(threads) ############################################## #############從tfrecords文件讀取############### file_name = os.listdir(FLAGS.cifar_dir) file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"] cf = CifarRead(file_list) image_batch, label_batch = cf.read_from_tfrecords() with tf.Session() as sess: # 創建協調器 coord = tf.train.Coordinator() # 開啟線程 threads = tf.train.start_queue_runners(sess, coord=coord) print(sess.run([image_batch, label_batch])) # 回收線程 coord.request_stop() coord.join(threads) ##############################################