TF Boys (TensorFlow Boys ) 養成記(二): TensorFlow 數據讀取


TensorFlow 的 How-Tos,講解了這么幾點:

1. 變量:創建,初始化,保存,加載,共享;

2. TensorFlow 的可視化學習,(r0.12版本后,加入了Embedding Visualization)

3. 數據的讀取;

4. 線程和隊列;

5. 分布式的TensorFlow;

6. 增加新的Ops;

7. 自定義數據讀取;

由於各種原因,本人只看了前5個部分,剩下的2個部分還沒來得及看,時間緊任務重,所以匆匆發車了,以后如果有用到的地方,再回過頭來研究。學習過程中深感官方文檔的繁雜冗余極多多,特別是第三部分數據讀取,又臭又長,花了我好久時間,所以我想把第三部分整理如下,方便乘客們。

TensorFlow 有三種方法讀取數據:1)供給數據,用placeholder;2)從文件讀取;3)用常量或者是變量來預加載數據,適用於數據規模比較小的情況。供給數據沒什么好說的,前面已經見過了,不難理解,我們就簡單的說一下從文件讀取數據。

官方的文檔里,從文件讀取數據是一段很長的描述,鏈接層出不窮,看完這個鏈接還沒看幾個字,就出現了下一個鏈接。

自己花了很久才認識路,所以想把這部分總結一下,帶帶我的乘客們。

 

首先要知道你要讀取的文件的格式,選擇對應的文件讀取器;

然后,定位到數據文件夾下,用

["file0", "file1"]        # or 
[("file%d" % i) for i in range(2)])    # or 
tf.train.match_filenames_once

選擇要讀取的文件的名字,用 tf.train.string_input_producer 函數來生成文件名隊列,這個函數可以設置shuffle = Ture,來打亂隊列,可以設置epoch = 5,過5遍訓練數據。

最后,選擇的文件讀取器,讀取文件名隊列並解碼,輸入 tf.train.shuffle_batch 函數中,生成 batch 隊列,傳遞給下一層。

 

1)假如你要讀取的文件是像 CSV 那樣的文本文件,用的文件讀取器和解碼器就是 TextLineReaderdecode_csv

2)假如你要讀取的數據是像 cifar10 那樣的 .bin 格式的二進制文件,就用 tf.FixedLengthRecordReadertf.decode_raw 讀取固定長度的文件讀取器和解碼器。如下列出了我的參考代碼,后面會有詳細的解釋,這邊先大致了解一下:

class cifar10_data(object):
    def __init__(self, filename_queue):
        self.height = 32
        self.width = 32
        self.depth = 3
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.depth
        self.record_bytes = self.label_bytes + self.image_bytes
        self.label, self.image = self.read_cifar10(filename_queue)
        
    def read_cifar10(self, filename_queue):
        reader = tf.FixedLengthRecordReader(record_bytes = self.record_bytes)
        key, value = reader.read(filename_queue)
        record_bytes = tf.decode_raw(value, tf.uint8)
        label = tf.cast(tf.slice(record_bytes, [0], [self.label_bytes]), tf.int32)
        image_raw = tf.slice(record_bytes, [self.label_bytes], [self.image_bytes])
        image_raw = tf.reshape(image_raw, [self.depth, self.height, self.width])
        image = tf.transpose(image_raw, (1,2,0))        
        image = tf.cast(image, tf.float32)
        return label, image
        
def inputs(data_dir, batch_size, train = True, name = 'input'):

    with tf.name_scope(name):
        if train:    
            filenames = [os.path.join(data_dir,'data_batch_%d.bin' % ii) 
                        for ii in range(1,6)]
            for f in filenames:
                if not tf.gfile.Exists(f):
                    raise ValueError('Failed to find file: ' + f)
                    
            filename_queue = tf.train.string_input_producer(filenames)
            read_input = cifar10_data(filename_queue)
            images = read_input.image
            images = tf.image.per_image_whitening(images)
            labels = read_input.label
            num_preprocess_threads = 16
            image, label = tf.train.shuffle_batch(
                                    [images,labels], batch_size = batch_size, 
                                    num_threads = num_preprocess_threads, 
                                    min_after_dequeue = 20000, capacity = 20192)
        
            
            return image, tf.reshape(label, [batch_size])
            
        else:
            filenames = [os.path.join(data_dir,'test_batch.bin')]
            for f in filenames:
                if not tf.gfile.Exists(f):
                    raise ValueError('Failed to find file: ' + f)
                    
            filename_queue = tf.train.string_input_producer(filenames)
            read_input = cifar10_data(filename_queue)
            images = read_input.image
            images = tf.image.per_image_whitening(images)
            labels = read_input.label
            num_preprocess_threads = 16
            image, label = tf.train.shuffle_batch(
                                    [images,labels], batch_size = batch_size, 
                                    num_threads = num_preprocess_threads, 
                                    min_after_dequeue = 20000, capacity = 20192)
        
            
            return image, tf.reshape(label, [batch_size])
    

 

3)如果你要讀取的數據是圖片,或者是其他類型的格式,那么可以先把數據轉換成 TensorFlow 的標准支持格式 tfrecords ,它其實是一種二進制文件,通過修改 tf.train.Example 的Features,將 protocol buffer 序列化為一個字符串,再通過 tf.python_io.TFRecordWriter 將序列化的字符串寫入 tfrecords,然后再用跟上面一樣的方式讀取tfrecords,只是讀取器變成了tf.TFRecordReader,之后通過一個解析器tf.parse_single_example ,然后用解碼器 tf.decode_raw 解碼。

 

例如,對於生成式對抗網絡GAN,我采用了這個形式進行輸入,部分代碼如下,后面會有詳細解釋,這邊先大致了解一下:

def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
    

def convert_to(data_path, name):
    
    """
    Converts s dataset to tfrecords
    """
    
    rows = 64
    cols = 64
    depth = DEPTH
    for ii in range(12):
        writer = tf.python_io.TFRecordWriter(name + str(ii) + '.tfrecords')
        for img_name in os.listdir(data_path)[ii*16384 : (ii+1)*16384]:
            img_path = data_path + img_name
            img = Image.open(img_path)
            h, w = img.size[:2]
            j, k = (h - OUTPUT_SIZE) / 2, (w - OUTPUT_SIZE) / 2
            box = (j, k, j + OUTPUT_SIZE, k+ OUTPUT_SIZE)
            
            img = img.crop(box = box)
            img = img.resize((rows,cols))
            img_raw = img.tobytes()
            example = tf.train.Example(features = tf.train.Features(feature = {
                                    'height': _int64_feature(rows),
                                    'weight': _int64_feature(cols),
                                    'depth': _int64_feature(depth),
                                    'image_raw': _bytes_feature(img_raw)}))
            writer.write(example.SerializeToString())
        writer.close()


def read_and_decode(filename_queue):
    
    """
    read and decode tfrecords
    """
    
#    filename_queue = tf.train.string_input_producer([filename_queue])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    
    features = tf.parse_single_example(serialized_example,features = {
                        'image_raw':tf.FixedLenFeature([], tf.string)})
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    
    return image

這里,我的data_path下面有16384*12張圖,通過12次寫入Example操作,把圖片數據轉化成了12個tfrecords,每個tfrecords里面有16384張圖。

 

4)如果想定義自己的讀取數據操作,請參考

好了,今天的車到站了,請帶好隨身物品准備下車,明天老司機還有一趟車,請記得准時乘坐,車不等人。

 

 

參考文獻:

1.

2. 沒了

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM