簡單處理API
讀取圖像:
image.imdecode(open('../img/cat1.jpg', 'rb').read())
圖像類型轉換:
img.astype('float32')
圖像增強流程
具體增強方式教程有很詳細的示意,不再贅述
輔助函數,用於將增強函數應用於單張圖片:
def apply_aug_list(img, augs):
for f in augs:
img = f(img)
return img
對於訓練圖片我們隨機水平翻轉和剪裁。對於測試圖片僅僅就是中心剪裁。我們假設剪裁成28×28×3用於輸入網絡:
train_augs = [
image.HorizontalFlipAug(.5),
image.RandomCropAug((28,28))
]
test_augs = [
image.CenterCropAug((28,28))
]
使用如下閉包來增強:
def get_transform(augs):
def transform(data, label):
# data: sample x height x width x channel
# label: sample
data = data.astype('float32')
if augs is not None:
# apply to each sample one-by-one and then stack
data = nd.stack(*[
apply_aug_list(d, augs) for d in data])
data = nd.transpose(data, (0,3,1,2))
return data, label.astype('float32')
return transform
基本邏輯就是這樣。
