tensorflow數據集加載


本篇涉及的內容主要有小型常用的經典數據集的加載步驟,tensorflow提供了如下接口:keras.datasets、tf.data.Dataset.from_tensor_slices(shuffle、map、batch、repeat),涉及的數據集如下:boston housing、mnist/fashion mnist、cifar10/100、imdb

1.keras.datasets

通過該接口可以直接下載指定數據集。boston housing提供了和房價有關的一些因子(面積、居民來源等),mnist提供了手寫數字的圖片和對應label,fashion mnist提供了10種衣服的灰度圖和對應label,cifar10/100是用來進行簡單圖像識別的數據集,分別包含10類物品和100類物品,imdb是一個類似於淘寶好評的數據集,即通過評語及其標注(好評或差評),來實現一個好評或差評的分類器。

注:通過該接口得到的數據集格式為numpy格式。

2.tf.data.Dataset.from_tensor_slices()

該方法可以用來進行數據的迭代,過程中可以直接將numpy格式轉化為tensor格式,然后通過調用next(iter())方法實現迭代,使用示例如下:

# 加載數據集
(x,y),(x_test,y_test) = keras.datasets.mnist.load_data()
# 轉化為tensor並實現迭代
db = tf.data.Dataset.from_tensor_slices(x_test)
# 打印迭代數據的shape
print(next(iter(db)).shape)
# 將img和label封裝為同一次迭代
db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
print(next(iter(db))[0].shape)
print(next(iter(db))[1].shape)

 

3.shuffle

通過shuffle函數可以將數據集打散,從而提高模型的泛化能力,使用方法:db.shuffle(10000),參數設置范圍,通常值設置比較大

4.map

# deep learning一般使用float32,而numpy格式多為float64,所以需要轉化
def preprocess(x,y):
    x = tf.cast(x,dtype=tf.float32)/255
    y = tf.cast(y,dtype=tf.int32)
    y = tf.one_hot(y,depth=10)
    return x,y

db2 = db.map(preprocess)
res = next(iter(db2))
print(res[0].shape,res[1].shape)

 

5.batch

db3 = db2.batch(32)
res = next(iter(db3))
print(res[0].shape,res[1].shape)

6.StopIteration

因為迭代多次后會到達數據集的末尾,如果不進行異常處理則會報StopIteration異常,如下處理方式就是錯誤的:

db_iter = iter(db3)
while True:
    next(db_iter)

 

只要加上異常處理語句對db_iter重新賦值即可

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM