補充：TFRecord 文件學習
https://blog.csdn.net/briblue/article/details/80789608
import os
import random
import sys

import tensorflow as tf
# Number of TFRecord shard files to generate.
_NUM_BLOCK = 2

# Directory holding the source images to convert.
DATASET_DIR = "./Images/SourceImgs/"

# Output locations:
#   GOALSET_DIR   - where the .tfrecord shards are written
#   GOALTFSET_DIR - images decoded back out of the shards via tf.data
#   QUEUE_DIR     - images decoded back out via the queue-runner pipeline
GOALSET_DIR = "./Images/TFImgs/"
GOALTFSET_DIR = "./Images/TFSourceImgs/"
QUEUE_DIR = "./Images/QueueImgs/"
一：生成 TFRecord 文件
(一)獲取圖片信息
def get_file_info(dataSet_dir = DATASET_DIR):
    """Return the paths of the regular files directly inside ``dataSet_dir``.

    Args:
        dataSet_dir: directory to scan (defaults to DATASET_DIR).

    Returns:
        List of ``os.path.join(dataSet_dir, name)`` paths, sorted by name.
        Sorting makes the later split into TFRecord shards reproducible, since
        ``os.listdir`` returns entries in arbitrary order. Sub-directories and
        other non-regular entries are skipped, because the writer would fail
        trying to read them as image bytes.
    """
    candidates = (os.path.join(dataSet_dir, name) for name in sorted(os.listdir(dataSet_dir)))
    return [path for path in candidates if os.path.isfile(path)]
(二)寫入 TFRecord 文件
# Pack every source image into _NUM_BLOCK TFRecord shards under GOALSET_DIR.
# NOTE(review): _format_record(image_data, label) is not defined in this
# section; it is assumed to build a tf.train.Example - confirm it exists.
files_list = get_file_info()
num_files = len(files_list)
num_per_block = num_files // _NUM_BLOCK
for _id in range(_NUM_BLOCK):
    tfr_name = "image_%d.tfrecord" % (_id + 1)
    tfr_dir = os.path.join(GOALSET_DIR, tfr_name)
    start_idx = _id * num_per_block
    # The last shard also takes the num_files % _NUM_BLOCK leftover files;
    # capping every shard at (_id + 1) * num_per_block silently dropped them.
    if _id == _NUM_BLOCK - 1:
        end_idx = num_files
    else:
        end_idx = (_id + 1) * num_per_block
    with tf.python_io.TFRecordWriter(tfr_dir) as writer:
        for i in range(start_idx, end_idx):
            sys.stdout.write("\r>>Converting images %d/%d to block %d"
                             % (i + 1, num_files, _id + 1))
            sys.stdout.flush()
            try:
                # Read the raw encoded image bytes; `with` closes the handle.
                with tf.gfile.FastGFile(files_list[i], 'rb') as image_file:
                    image_data = image_file.read()
            except (IOError, tf.errors.OpError) as e:
                # tf.gfile reports failures as tf.errors.OpError subclasses,
                # not IOError, so both must be caught to skip bad files.
                print("Could not read:", files_list[i])
                print("Error", e)
                print("Skip it\n")
                continue
            # Label = file name without directory and without its final
            # extension (keeps inner dots, e.g. "a.b.jpg" -> "a.b").
            label = os.path.splitext(os.path.basename(files_list[i]))[0]
            example = _format_record(image_data, label)
            writer.write(example.SerializeToString())
    sys.stdout.write("\n")
    sys.stdout.flush()

二：直接讀取 TFRecord 文件
(一)解析文件
def _parse_record(example_proto):
    """Decode one serialized record into its 'label' and 'data' fields.

    Args:
        example_proto: a serialized tf.train.Example (scalar string tensor).

    Returns:
        Dict mapping 'label' and 'data' to scalar tf.string tensors.
    """
    # Each field is parsed as a single fixed-length (shape ()) byte string.
    schema = {
        'label': tf.FixedLenFeature((), tf.string),
        'data': tf.FixedLenFeature((), tf.string),
    }
    return tf.parse_single_example(example_proto, features=schema)
(二)讀取所有文件
# Read every record from all shards in GOALSET_DIR with tf.data and write each
# image back out to GOALTFSET_DIR as "<label>.jpg".
with tf.Session() as sess:
    tf_files = [os.path.join(GOALSET_DIR, fn) for fn in os.listdir(GOALSET_DIR)]
    # TFRecordDataset accepts a list of files, so all shards are read together.
    dataSet = tf.data.TFRecordDataset(tf_files)
    dataSet = dataSet.map(_parse_record)  # decode each record into a dict
    iterator = dataSet.make_one_shot_iterator()
    # Build the get_next op ONCE: calling iterator.get_next() inside the loop
    # adds a new op to the graph on every iteration.
    next_element = iterator.get_next()
    sess.run(tf.local_variables_initializer())
    while True:
        try:
            Singledata = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            # Signals that the dataset is exhausted; any other error propagates.
            print("Read finish!!!")
            break
        label = Singledata['label'].decode()
        image_data = Singledata['data']
        out_path = os.path.join(GOALTFSET_DIR, "%s.jpg" % label)
        with tf.gfile.GFile(out_path, "wb") as out_file:
            out_file.write(image_data)

三：使用文件隊列讀取多個 TFRecord 文件
# Read the shards through a TF1 filename queue + shuffle_batch pipeline and
# write every decoded image to QUEUE_DIR as "<label>-<n>.jpg".
tf_files = [os.path.join(GOALSET_DIR, fn) for fn in os.listdir(GOALSET_DIR)]
# string_input_producer builds a queue of file names over all shards, shuffled,
# cycling through them num_epochs=3 times. The epoch counter is a local
# variable, hence local_variables_initializer() below.
filename_queue = tf.train.string_input_producer(tf_files, shuffle=True, num_epochs=3)
# The reader pulls file names from the queue and yields (key, serialized record).
reader = tf.TFRecordReader()
_, value = reader.read(filename_queue)
# Same feature spec as the tf.data path, so reuse its parser.
features = _parse_record(value)
img_data = features['data']
label = features['label']
image_batch, label_batch = tf.train.shuffle_batch(
    [img_data, label], batch_size=8, num_threads=2,
    allow_smaller_final_batch=True, capacity=500, min_after_dequeue=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    # The queues only start filling once the queue runners are started.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    j = 1
    try:
        while not coord.should_stop():
            images_data, labels_data = sess.run([image_batch, label_batch])
            for img_bytes, lbl in zip(images_data, labels_data):
                # j counts images, not batches: with num_epochs=3 the same
                # label can occur twice in one batch, and a per-batch counter
                # made the second write overwrite the first.
                out_path = os.path.join(QUEUE_DIR, "%s-%d.jpg" % (lbl.decode(), j))
                with open(out_path, "wb") as f:
                    f.write(img_bytes)
                j += 1
    except tf.errors.OutOfRangeError:
        # Raised once all epochs have been consumed; other errors propagate.
        print("read all files")
    finally:
        coord.request_stop()  # ask the reader threads to stop
        coord.join(threads)   # wait for them to finish

