TensorFlow踩坑系列---TFRecord文件讀寫


補充:TFRecord文件學習(參考資料)

https://blog.csdn.net/briblue/article/details/80789608

import tensorflow as tf
import os
import random
import sys
# Number of TFRecord files (shards) to generate
_NUM_BLOCK = 2
# Directory containing the source images
DATASET_DIR = "./Images/SourceImgs/"
# Directory the generated .tfrecord files are written to and later read from
GOALSET_DIR = "./Images/TFImgs/"

# Directory that receives the images restored by the tf.data-based reader
GOALTFSET_DIR = "./Images/TFSourceImgs/"

# Directory that receives the images restored by the filename-queue reader
QUEUE_DIR = "./Images/QueueImgs/"

一:生成TFRecord文件

(一)獲取圖片信息

# Collect the paths of the entries found in the image directory
def get_file_info(dataSet_dir = DATASET_DIR):
    """Return the path of every entry in ``dataSet_dir``.

    Each name reported by ``os.listdir`` is joined onto ``dataSet_dir``; the
    result keeps the order in which ``os.listdir`` yields the names.
    """
    return [os.path.join(dataSet_dir, entry) for entry in os.listdir(dataSet_dir)]

(二)寫入TFRECORD文件

# Split the source images into _NUM_BLOCK groups and write one TFRecord file per group.
# NOTE(review): _format_record is not defined in the visible part of this file. Judging by
# the feature spec used when the records are parsed back, it presumably builds a
# tf.train.Example with bytes features 'label' and 'data' -- confirm against its definition.
with tf.Session() as sess:
    files_list = get_file_info()
    total_files = len(files_list)
    # Ceiling division: with floor division the remainder files were silently dropped
    # whenever the file count was not a multiple of _NUM_BLOCK. The min() below then
    # clamps the last block to the end of the list.
    num_per_block = (total_files + _NUM_BLOCK - 1) // _NUM_BLOCK
    # Make sure the output directory exists before any writer is opened.
    os.makedirs(GOALSET_DIR, exist_ok=True)
    for _id in range(_NUM_BLOCK):
        tfr_name = "image_%d.tfrecord"%(_id+1)
        tfr_dir = os.path.join(GOALSET_DIR,tfr_name)
        with tf.python_io.TFRecordWriter(tfr_dir) as writer:
            start_idx = _id*num_per_block
            end_idx = min((_id+1)*num_per_block,total_files)

            for i in range(start_idx,end_idx):
                try:
                    sys.stdout.write("\r>>Converting images %d/%d to block %d"%(i+1,total_files,_id+1))
                    sys.stdout.flush()
                    # Read the raw image bytes; the context manager closes the handle.
                    with tf.gfile.FastGFile(files_list[i],'rb') as image_file:
                        image_data = image_file.read()
                    # The label is the file name up to its first dot.
                    label = files_list[i].split("/")[-1].split(".")[0]

                    example = _format_record(image_data,label)

                    # SerializeToString validates the message; the Partial variant skips that check.
                    writer.write(example.SerializeToString())
                except (IOError, tf.errors.OpError) as e:
                    # tf.gfile reports file-system failures as tf.errors.OpError subclasses,
                    # so catch those alongside IOError and skip the unreadable file.
                    print("Could not read:",files_list[i])
                    print("Error",e)
                    print("Skip it\n")
    sys.stdout.write("\n")
    sys.stdout.flush()

二:直接讀取TFRecord文件

(一)解析文件

def _parse_record(example_proto):
    """Parse one serialized example into a dict with 'label' and 'data' entries.

    Both features are declared as scalar strings via
    ``tf.FixedLenFeature((), tf.string)``; the dict returned by
    ``tf.parse_single_example`` is passed back unchanged.
    """
    feature_spec = {
        feature_name: tf.FixedLenFeature((), tf.string)
        for feature_name in ('label', 'data')
    }
    return tf.parse_single_example(example_proto, features=feature_spec)

(二)讀取所有文件

# Read every record back from the TFRecord files and restore each image as "<label>.jpg".
with tf.Session() as sess:
    tf_files = [os.path.join(GOALSET_DIR,fn) for fn in os.listdir(GOALSET_DIR)]

    # A single TFRecordDataset can read all the tfrecord files in one go.
    dataSet = tf.data.TFRecordDataset(tf_files)
    dataSet = dataSet.map(_parse_record) # decode each serialized example

    iterator = dataSet.make_one_shot_iterator()
    # Build the get_next op once: calling iterator.get_next() inside the loop
    # would add a new op to the graph on every iteration.
    next_record = iterator.get_next()

    # Make sure the output directory exists before writing images into it.
    os.makedirs(GOALTFSET_DIR, exist_ok=True)

    sess.run(tf.local_variables_initializer())
    while True:
        try:
            Singledata = sess.run(next_record)
        except tf.errors.OutOfRangeError:
            # Only end-of-dataset terminates the loop; any other error
            # (including KeyboardInterrupt) now propagates and is no longer
            # reported as a normal finish.
            print("Read finish!!!")
            break
        label = Singledata['label'].decode()
        image_data = Singledata['data']
        # The context manager flushes and closes the output file.
        with tf.gfile.GFile(os.path.join(GOALTFSET_DIR,"%s.jpg"%label),"wb") as out_file:
            out_file.write(image_data)

三:使用文件隊列讀取多個tfrecord文件

 

# Read the TFRecord files through a filename queue and dump every image as
# "<label>-<counter>.jpg" into QUEUE_DIR.
tf_files = [os.path.join(GOALSET_DIR,fn) for fn in os.listdir(GOALSET_DIR)]

# string_input_producer builds the filename queue over all tfrecord files
# (shuffled, 3 epochs).
filename_queue = tf.train.string_input_producer(tf_files,shuffle=True,num_epochs=3)

# The reader pulls records from the files named in the queue.
reader = tf.TFRecordReader()
key,value = reader.read(filename_queue) # returns the record key and the serialized record
features = tf.parse_single_example(value,features={
        'label':tf.FixedLenFeature((),tf.string),
        'data':tf.FixedLenFeature((),tf.string)
    })
img_data = features['data']
label = features['label']

image_batch,label_batch = tf.train.shuffle_batch([img_data,label],batch_size=8,num_threads=2,allow_smaller_final_batch=True,
                                                capacity=500,min_after_dequeue=100)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # initialise the global variables created above
    sess.run(tf.local_variables_initializer()) # initialise the local variables created above

    # Make sure the output directory exists before writing images into it.
    os.makedirs(QUEUE_DIR, exist_ok=True)

    coord = tf.train.Coordinator()
    # The queues are only filled once start_queue_runners has been called.
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    j = 1
    try:
        while not coord.should_stop():
            images_data,labels_data = sess.run([image_batch,label_batch])
            for image_bytes,label_bytes in zip(images_data,labels_data):
                with open(os.path.join(QUEUE_DIR,"%s-%d.jpg"%(label_bytes.decode(),j)),"wb") as f:
                    f.write(image_bytes)
                j+=1
    except tf.errors.OutOfRangeError:
        # Raised once the queue is exhausted after the configured epochs. Other
        # errors now propagate and are no longer reported as a normal finish.
        print("read all files")
    finally:
        coord.request_stop() # ask the reader threads to stop
        coord.join(threads) # wait for the reader threads, also when an error propagates

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM