【小白學PyTorch】17 TFrec文件的創建與讀取


【新聞】:機器學習煉丹術的粉絲的人工智能交流群已經建立,目前有目標檢測、醫學圖像、時間序列等多個目標為技術學習的分群和水群嘮嗑的總群,歡迎大家加煉丹兄為好友,加入煉丹協會。微信:cyx645016617.

參考目錄:

本文的代碼已經上傳公眾號后台,回復【PyTorch】獲取。
第一次接觸到TFrec文件,我也是比較蒙蔽的其實:

可以看到文件是.tfrec后綴的,而且先記住這個文件是186.72MB大小的。

1 為什么用tfrec文件

正常情況下我們用於訓練的文件夾內部往往會存着成千上萬的圖片或文本等文件,這些文件通常被散列存放。這種存儲方式有一些缺點:

  • 占用磁盤空間;
  • 一個一個讀取文件消耗時間

而tfrec格式的文件存儲形式會很合理的幫我們存儲數據,核心就是tfrec內部使用Protocol Buffer的二進制數據編碼方案,這個方案可以極大的壓縮存儲空間

之前我們知道一個tfrec文件100多M,這是因為這個tfrec文件內存儲了很多的圖片,類似於壓縮,對tfrec解壓縮后可以獲取到一部分的數據集,當我們把全部的rfrec文件都解壓縮后,可以獲取到全部的數據集。

值得一提的是,rfrec文件內除了可以存儲圖片,還可以存儲其他的數據,比方說圖片的label。字符串,float類型等都可以轉換成二進制的方法,所以什么數據類型基本上都可以存儲到rfrec文件內,從而簡化讀取數據的過程。

2 tfrec文件的內部結構

tfrec文件時tensorflow的數據集存儲格式,tensorflow可以高效的讀取和處理這些數據集,因此我見過有的數據集因為是tfrec文件,所以用TF讀取數據集,然后用pytorch訓練模型。

之前提到了tfrec文件里面是有多個樣本的,所以tfrec可以為是多個tf.train.Example文件組成的序列(每一個example是一個樣本),然后每一個tf.train.Example又是由若干個tf.train.Features字典組成。這個Features可以理解為這個樣本的一些信息,如果是圖片樣本,那么肯定有一個Features是圖片像素值數據,一個Features是圖片的標簽值;如果是預測任務,那么這個Feature可能就是一些字符串類型的特征

3 制作tfrec文件

import tensorflow as tf
import glob
# 先記錄一下要保存的tfrec文件的名字
tfrecord_file = './train.tfrec'
# 獲取指定目錄的所有以jpeg結尾的文件list
images = glob.glob('./*.jpeg')
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for filename in images:
        image = open(filename, 'rb').read()  # 讀取數據集圖片到內存,image 為一個 Byte 類型的字符串
        feature = {  # 建立 tf.train.Feature 字典
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 圖片是一個 Bytes 對象
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            'float':tf.train.Feature(float_list=tf.train.FloatList(value=[1.0,2.0])),
            'name':tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode(filename)]))
        }
        # tf.train.Example 在 tf.train.Features 外面又多了一層封裝
        example = tf.train.Example(features=tf.train.Features(feature=feature))  # 通過字典建立 Example
        writer.write(example.SerializeToString())  # 將 Example 序列化並寫入 TFRecord 文件

代碼中我們需要注意的地方是:

  • 先讀取圖片,然后構建一個字典來作為這個example的格式;
  • 上面代碼中,字典中有四個屬性,首先是image圖片本身的像素值,然后有一個標簽,標簽是int類型,然后有一個float浮點類型,name是一個字符串類型,這個string類型的需要轉換成byte字節類型的才能進行存儲,所以這里使用str.encode來把字符串轉換成字節;
  • 然后這個features再經過Example的封裝,再然后把這個example寫進這個tfrec文件中。

這一段代碼建議保存下來,方便以后的直接參考和復制。構建tfrec文件對於tensorflow處理圖片來說,應該是繞不過的一個步驟。

4 讀取tfrec文件

現在,我們運行完上面的代碼,應該生成了一個./train.tfrec文件,下面我們再對這個文件進行讀取。

import tensorflow as tf

dataset = tf.data.TFRecordDataset('./train.tfrec')

def decode(example):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'float': tf.io.FixedLenFeature([1, 2], tf.float32),
        'name': tf.io.FixedLenFeature([], tf.string)
    }
    feature_dict = tf.io.parse_single_example(example, feature_description)
    feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])  # 解碼 JEPG 圖片
    return feature_dict

dataset = dataset.map(decode).batch(4)
for i in dataset.take(1):
    print(i['image'].shape)
    print(i['label'].shape)
    print(i['float'].shape)
    print(bytes.decode(i['name'][0].numpy()))
  • 首先使用專門用來讀取tfrec文件的方法tf.data.TFRecordDataset,進行讀取,創建了一個dataset,但是這個dataset並不能直接使用,需要對tfrec中的example進行一些解碼;
  • 自己寫一個解碼函數decode,首先寫一個特征描述,我們知道在保存tfrec的時候每一個example有四個特征,這里需要對每一個特征確定他的類型,是string還是int還是float這樣的。
  • 然后通過這個特征描述和tf.io.parse_single_example方法,從example中提取到對應的特征;
  • 因為image是一個圖片張量,而我們讀取的時候是讀取的tf.string的類型,所以使用tf.io.decode_jpeg()來把字符串解碼成一個tensor張量。
  • 最后使用上節課講過的.batch(4)把數據集每一個batch包含四個樣本。

上面代碼輸出的結果為:

需要注意的是這個如何把name轉換成string類型的,如果已經在本地跑完了上面的代碼,可以自己看看i['name']是一個什么類型的,然后自己試試如何轉換成string類型的。上面的代碼是能成功轉換的。

下一次的內容就是如何構建模型,然后怎么把數據集喂給模型。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM