Keras Data augmentation(數據擴充)


       在深度學習中,我們經常需要用到一些技巧(比如將圖片進行旋轉,翻轉等)來進行data augmentation, 來減少過擬合。 在本文中,我們將主要介紹如何用深度學習框架keras來自動的進行data augmentation。

keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-6,
    rotation_range=0.,
    width_shift_range=0.,
    height_shift_range=0.,
    shear_range=0.,
    zoom_range=0.,
    channel_shift_range=0.,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=False,
    vertical_flip=False,
    rescale=None,
    preprocessing_function=None,
    data_format=K.image_data_format())
  • 生成批次的帶實時數據增益的張量圖像數據。數據將按批次無限循環。
  • 參數
    • featurewise_center: 布爾值。將輸入數據的均值設置為 0,逐特征進行。
    • samplewise_center: 布爾值。將每個樣本的均值設置為 0。
    • featurewise_std_normalization: 布爾值。將輸入除以數據標准差,逐特征進行。
    • samplewise_std_normalization: 布爾值。將每個輸入除以其標准差。
    • zca_epsilon: ZCA 白化的 epsilon 值,默認為 1e-6。
    • zca_whitening: 布爾值。應用 ZCA 白化。
    • rotation_range: 整數。隨機旋轉的度數范圍。
    • width_shift_range: 浮點數(總寬度的比例)。隨機水平移動的范圍。
    • height_shift_range: 浮點數(總高度的比例)。隨機垂直移動的范圍。
    • shear_range: 浮點數。剪切強度(以弧度逆時針方向剪切角度)。
    • zoom_range: 浮點數 或 [lower, upper]。隨機縮放范圍。如果是浮點數,[lower, upper] = [1-zoom_range, 1+zoom_range]
    • channel_shift_range: 浮點數。隨機通道轉換的范圍。
    • fill_mode: {"constant", "nearest", "reflect" or "wrap"} 之一。輸入邊界以外的點根據給定的模式填充:
      • "constant": kkkkkkkk|abcd|kkkkkkkk (cval=k)
      • "nearest": aaaaaaaa|abcd|dddddddd
      • "reflect": abcddcba|abcd|dcbaabcd
      • "wrap": abcdabcd|abcd|abcdabcd
    • cval: 浮點數或整數。用於邊界之外的點的值,當 fill_mode = "constant" 時。
    • horizontal_flip: 布爾值。隨機水平翻轉。
    • vertical_flip: 布爾值。隨機垂直翻轉。
    • rescale: 重縮放因子。默認為 None。如果是 None 或 0,不進行縮放,否則將數據乘以所提供的值(在應用任何其他轉換之前)。
    • preprocessing_function: 應用於每個輸入的函數。這個函數會在任何其他改變之前運行。這個函數需要一個參數:一張圖像(秩為 3 的 Numpy 張量),並且應該輸出一個同尺寸的 Numpy 張量。
    • data_format: {"channels_first", "channels_last"} 之一。"channels_last" 模式表示輸入尺寸應該為 (samples, height, width, channels),"channels_first" 模式表示輸入尺寸應該為 (samples, channels, height, width)。默認為 在 Keras 配置文件 ~/.keras/keras.json 中的 image_data_format 值。如果你從未設置它,那它就是 "channels_last"。
       
  • 方法:
  • fit(x): 根據一組樣本數據,計算與數據相關轉換有關的內部數據統計信息。當且僅當 featurewise_center 或 featurewise_std_normalization 或 zca_whitening 時才需要。
  • flow(x, y): 傳入 Numpy 數據和標簽數組,生成批次的 增益的/標准化的 數據。在生成的批次數據上無限制地無限次循環。
  • flow_from_directory(directory): 以目錄路徑為參數,生成批次的 增益的/標准化的 數據。在生成的批次數據上無限制地無限次循環。
from keras.preprocessing.image import ImageDataGenerator,array_to_img,img_to_array,load_img

datagen=ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

img=load_img("test.jpg")
x=img_to_array(img) # 把PIL圖像格式轉換成numpy格式
x=x.reshape((1,)+x.shape)

i=0
for batch in datagen.flow(x,batch_size=2,save_to_dir="datagen",save_prefix="cat",save_format="jpeg"):
    i+=1
    if i>10:
        break

其他注意api:

compile

compile(self, optimizer, loss, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None) 

用於配置訓練模型。

fit

fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None) 

以固定數量的輪次(數據集上的迭代)訓練模型。

fit_generator

fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0) 

使用 Python 生成器逐批生成的數據,按批次訓練模型。

evaluate

evaluate(self, x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None) 

在測試模式下返回模型的誤差值和評估標准值。

evaluate_generator

evaluate_generator(self, generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False) 

在數據生成器上評估模型。

predict

predict(self, x, batch_size=None, verbose=0, steps=None) 

為輸入樣本生成輸出預測。

predict_generator

predict_generator(self, generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0) 

為來自數據生成器的輸入樣本生成預測。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM