圖像分類任務中,大多數教程是直接導入深度學習庫中的數據集直接用於模型訓練,如果采用自己的數據集,會難以下手,這篇博客主要介紹使用Tensorflow2.1或Keras來讀取自己的數據集。
1、Tensorflow方法制作數據集
Tensorflow制作數據集,主要用到tf.data進行操作。步驟為制作csv文件、讀取csv、讀取數據、數據處理。
需要用到的庫
import os import random import glob import csv import tensorflow as tf
1.1 制作csv文件
# 創建csv文件,輸入分別為路徑和要創建的csv文件名 def build_csv(root, filename): # 對種類進行編號,相當於用0,1,n-1分別表示分類任務的類別 name2label = {} for name in sorted(os.listdir(os.path.join(root))): # 判斷文件夾下的對象是否是一個文件夾 # 不是文件夾,直接進行下一次判斷 # 是文件夾,對該目錄進行編號 if not os.path.isdir(os.path.join(root, name)): continue name2label[name] = len(name2label.keys()) # 准備從每個文件夾中讀取圖片路徑與編號 images = [] # 遍歷數據集中的每個文件夾 for name in name2label.keys(): # 讀取所有的png,jpg,jpeg格式的文件 images += glob.glob(os.path.join(root, name, '*.png')) images += glob.glob(os.path.join(root, name, '*.jpg')) images += glob.glob(os.path.join(root, name, '*.jpeg')) print(len(images), images) random.shuffle(images) # 創建並寫csv文件 with open(os.path.join(root, filename), mode='w', newline='') as f: writer = csv.writer(f) for img in images: # 更改路徑的分隔符 name = img.split(os.sep)[-2] label = name2label[name] writer.writerow([img, label]) print('written into csv file:', filename)
1.2讀取csv文件
# 輸入分別為路徑和剛剛創建的csv文件名 def load_csv(root, filename): images, labels = [], [] with open(os.path.join(root, filename)) as f: reader = csv.reader(f) for row in reader: img, label = row label = int(label) images.append(img) labels.append(label) return images, labels
1.3將數據集轉換為tf.data格式
# 讀取csv文件 images, labels = load_csv(root, filename) # 轉換為tf.data格式 dataset = tf.data.Dataset.from_tensor_slices((images, labels)) # 數據處理操作,其中preprocessing是需要自己編寫的一個實現數據處理功能的函數 dataset = dataset.shuffle(1000).map(preprocess).batch(32)
1.4數據處理操作
# 輸入為路徑和標簽 def preprocess(x, y): # 根據路徑讀取圖片 x = tf.io.read_file(x) # 將圖片數值轉換為張量 x = tf.image.decode_jpeg(x, channels=3) # 更改尺寸 x = tf.image.resize(x, [244, 244]) # 歸一化 x = tf.cast(x, dtype=tf.float32) / 255. y = tf.convert_to_tensor(y) return x, y
2、Keras方法制作數據集
Keras制作數據集,使用Keras進行導入數據集。使用keras導入數據集,過程簡單方便。
需要用到的庫
from keras.preprocessing.image import ImageDataGenerator
2.1讀取數據
# 將照片[0-255]數據縮放為[0-1] train_datagen = ImageDataGenerator(rescale=1./255) test_datagen = ImageDataGenerator(rescale=1./255) # 訓練集與驗證集路徑 train_dir = "train/" validation_dir = "validation/" # 生成了224x224的RGB圖像,形狀為[20,224,224,3]與二進制標簽[20,]的批量,每個批量包含20個樣本 train_generator = train_datagen.flow_from_directory( train_dir, # 訓練集路徑 target_size=(224, 224), # 訓練集樣本尺寸大小為(224, 224) batch_size=32, # 訓練集每批包含32個樣本 class_mode='categorical') validation_generator = test_datagen.flow_from_directory( validation_dir, target_size=(224, 224), batch_size=16, class_mode='categorical')
2.2 輸入數據到模型
history = model.fit_generator( train_generator, validation_data=validation_generator,
......
)
