本篇涉及的內容主要有小型常用的經典數據集的加載步驟,tensorflow提供了如下接口:keras.datasets、tf.data.Dataset.from_tensor_slices(shuffle、map、batch、repeat),涉及的數據集如下:boston housing、mnist/fashion mnist、cifar10/100、imdb
1.keras.datasets
通過該接口可以直接下載指定數據集。boston housing提供了和房價有關的一些因子(面積、居民來源等),mnist提供了手寫數字的圖片和對應label,fashion mnist提供了10種衣服的灰度圖和對應label,cifar10/100是用來進行簡單圖像識別的數據集,分別包含10類物品和100類物品,imdb是一個類似於淘寶好評的數據集,即通過評語及其標注(好評或差評),來實現一個好評或差評的分類器。
注:通過該接口得到的數據集格式為numpy格式。
2.tf.data.Dataset.from_tensor_slices()
該方法可以用來進行數據的迭代,過程中可以直接將numpy格式轉化為tensor格式,然后通過調用next(iter())方法實現迭代,使用示例如下:
# 加載數據集 (x,y),(x_test,y_test) = keras.datasets.mnist.load_data() # 轉化為tensor並實現迭代 db = tf.data.Dataset.from_tensor_slices(x_test) # 打印迭代數據的shape print(next(iter(db)).shape) # 將img和label封裝為同一次迭代 db = tf.data.Dataset.from_tensor_slices((x_test,y_test)) print(next(iter(db))[0].shape) print(next(iter(db))[1].shape)
3.shuffle
通過shuffle函數可以將數據集打散,從而提高模型的泛化能力,使用方法:db.shuffle(10000),參數設置范圍,通常值設置比較大
4.map
# deep learning一般使用float32,而numpy格式多為float64,所以需要轉化 def preprocess(x,y): x = tf.cast(x,dtype=tf.float32)/255 y = tf.cast(y,dtype=tf.int32) y = tf.one_hot(y,depth=10) return x,y db2 = db.map(preprocess) res = next(iter(db2)) print(res[0].shape,res[1].shape)
5.batch
db3 = db2.batch(32) res = next(iter(db3)) print(res[0].shape,res[1].shape)
6.StopIteration
因為迭代多次后會到達數據集的末尾,如果不進行異常處理則會報StopIteration異常,如下處理方式就是錯誤的:
db_iter = iter(db3) while True: next(db_iter)
只要加上異常處理語句對db_iter重新賦值即可