tensorflow學習筆記--dataset使用,創建自己的數據集


數據讀入需求

我們在訓練模型參數時想要從訓練數據集中一次取出一小批數據(比如50條、100條)做梯度下降,不斷地分批取出數據直到損失函數基本不再減小並且在訓練集上的正確率足夠高,取出的n條數據還要是預處理過的,一次取出的要包含輸入數據和對應的lable,並且希望在達到訓練效果之前可以不斷地取出數據而不會因數據集取空了提前結束訓練,最好取出的數據還是亂序的。

基於上面的要求,我們可以利用TensorFlow的dataset模塊創建我們所需的數據集。

Dataset簡介

TensorFlow程序數據導入的方法有多種。一是通過 feed_dict 傳入具體值。二是利用tf的Queues創建數據隊列,一次取出batch個數據進行訓練,隊列可以用多線程讀數據,速度比較快,但是隊列模塊的用法比較復雜,要修改程序的時候就感覺很亂。

Dataset與隊列相比就簡單多了,Dataset(數據集) API 在 TensorFlow 1.4版本中已經從tf.contrib.data遷移到了tf.data之中,增加了對於Python的生成器的支持,官方強烈建議使用Dataset API 為 TensorFlow模型創建輸入管道。

dataset用法

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))

創建了一個dataset,這個dataset中含有5個元素1….,5,為了將5個元素取出,方法是從Dataset中示例化一個iterator,然后對iterator進行迭代。

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(one_element))    

語句iterator = dataset.make_one_shot_iterator()從dataset中實例化了一個Iterator,這個Iterator是一個“one shot iterator”,即只能從頭到尾讀取一次。one_element = iterator.get_next()表示從iterator里取出一個元素。這里取5次后dataset里的元素就空了,再取的話就就會拋出tf.errors.OutOfRangeError異常。

除了one-hot iterator,tf還支持其他三種iterator

  • initializable
  • reinitializable
  • feedable

這三個迭代器比one-hot復雜,這里就不介紹他們了。

 

dataset元素變換

dataset數據集API還有一些操作元素的函數來滿足我們的對輸入數據的需求。

  • map
  • shuffle
  • batch
  • repeat

1. map

map接收一個函數,Dataset中的每個元素都會被當作這個函數的輸入,並將函數返回值作為新的Dataset,如我們可以對dataset中每個元素的值加1:

def add1(x):
    return x+1

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) dataset = dataset.map(add1)

2. shuffle

shuffle的功能為打亂dataset中的元素,它有一個參數buffersize,表示打亂時使用的buffer的大小:

dataset = dataset.shuffle(500)

3. batch

使用一次iterator返回一批數據的數量:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess: for i in range(10): print(sess.run(one_element)) # 這樣就一次獲取兩個數,可以取3次,第三次取到一個數

4. repeat

上面的代碼取3次數就取完了,再取得話就會拋出異常,如果想重復取數,可以用dataset.repeat(count),count的值表示將全部的數在dataset中重復幾次:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2).repeat(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.next()
with tf.Session() as sess:
    for i in range(10):
        print(sess.run(one_element))

這樣就將5個數重復了兩遍。這里需要注意的一點是它雖然重復了兩次,但並不是可以取5次,一次取兩個數,而是:[1,2], [3,4] , [5],  [1,2], [3,4] , [5] 。這樣再取到數據集末尾的時候得到的數據數量不是我們設置的batch_size 條數據。要想重復取數並且每次得到的都是batch_size條數據,可以設置batch_size的大小能被總數據量整除。

repeat()中的參數如果是None,則可以無限取數。

 

讀入圖片和lable,創建自己的數據集

import tensorflow as tf
import os

batch_size = 50
img_resize = [100,100]
epoch_num = None   # dataset.repeat() 的參數,設置為None,可以不斷取數

# 傳入圖片名,返回正則化后的圖片的像素值
def read_img(img_name, lable): image = tf.read_file(img_name) image = tf.image.decode_jpeg(image) image = tf.image.resize_images(image, img_resize) image = tf.image.per_image_standardization(image) return image,lable
# 傳入圖片所在的文件夾,圖片名含有圖片的lable,返回利用文件夾中圖片創建的dataset
def create_dataset(path): files = os.listdir(path) # 列出文件夾中所有的圖片 img_names = [] lables = [] for f in files: img_names.append(os.path.join(path,f)) # 圖片的完整路徑append到文件名list中 lable = f.split('.')[0] lables.append([int(i) for i in lable]) # 根據規則得到圖片的lable img_names = tf.convert_to_tensor(img_names, dtype=tf.string) lables = tf.convert_to_tensor(lables, dtype=tf.float32) # 將圖片名list和lable的list轉換成Tensor類型
dataset
= tf.data.Dataset.from_tensor_slices((img_names,lables)) # 創建dataset,傳入的需要是tensor類型
dataset
= dataset.map(read_img) # 傳入read_img函數,將圖片名轉為像素
  
  # 將dataset打亂,設置一次獲取batch_size條數據 dataset
= dataset.shuffle(buffer_size=800).batch(batch_size).repeat(epoch_num)
return dataset
dataset
= create_dataset('./img') # 圖片所在的路徑為./img iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() # 創建dataset是batch_size 為多少這里一次就能獲取多少個數據

在程序中,sess.run(one_element) 一次就能獲取到batch_size條數據和對應的lable

 

參考鏈接

https://blog.csdn.net/ssmixi/article/details/80572813

https://www.jianshu.com/p/d80ea5d73446


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM