Python圖像處理庫 - Albumentations,可用於深度學習中網絡訓練時的圖片數據增強.
Albumentations 圖像數據增強庫特點:
- 基於高度優化的 OpenCV 庫實現圖像快速數據增強.
- 針對不同圖像任務,如分割,檢測等,超級簡單的 API 接口.
- 易於個性化定制.
- 易於添加到其它框架,比如 PyTorch.
1. Albumentations 的 pip 安裝
sudo pip install albumentations
2. 不同圖片數據增強庫對比
對 ImageNet validation set 中的前 2000 張圖片進行處理,采用 Intel Core i7-7800X CPU.
不同數據增強庫的處理速度對比(以秒為單位,時間越少越好).
3. 使用示例
https://github.com/albu/albumentations/blob/master/notebooks/example.ipynb
import numpy as np
import cv2
from matplotlib import pyplot as plt

from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur,
    IAAPiecewiseAffine, IAASharpen, IAAEmboss, RandomContrast,
    RandomBrightness, Flip, OneOf, Compose
)

# Image transforms.
image = cv2.imread('test.jpg', 1)  # BGR
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Three single transforms, each applied to the same source image.
aug = HorizontalFlip(p=1)
img_HorizontalFlip = aug(image=image)['image']

aug = IAAPerspective(scale=0.2, p=1)
img_IAAPerspective = aug(image=image)['image']

aug = ShiftScaleRotate(p=1)
img_ShiftScaleRotate = aug(image=image)['image']


def augment_flips_color(p=.5):
    """Return a Compose pipeline mixing rotation/transpose, blur, distortion and colour transforms.

    Args:
        p: value forwarded to ``Compose`` as its ``p`` argument.
    """
    return Compose([
        CLAHE(),
        RandomRotate90(),
        Transpose(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
        Blur(blur_limit=3),
        OpticalDistortion(),
        GridDistortion(),
        HueSaturationValue(),
    ], p=p)


aug = augment_flips_color(p=1)
img_augment_flips_color = aug(image=image)['image']


def strong_aug(p=.5):
    """Return a heavier Compose pipeline built from several ``OneOf`` groups.

    Args:
        p: value forwarded to ``Compose`` as its ``p`` argument.
    """
    return Compose([
        RandomRotate90(),
        Flip(),
        Transpose(),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            MotionBlur(p=.2),
            MedianBlur(blur_limit=3, p=.1),
            Blur(blur_limit=3, p=.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=.2),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomContrast(),
            RandomBrightness(),
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=p)


# This must be an assignment: the original `aug == strong_aug(p=1)` only
# compared the objects, so the next line silently reused the previous pipeline.
aug = strong_aug(p=1)
img_strong_aug = aug(image=image)['image']

# Show the original and the five augmented images on a 2x3 grid.
plt.subplot(2, 3, 1)
plt.imshow(image)
plt.subplot(2, 3, 2)
plt.imshow(img_HorizontalFlip)
plt.subplot(2, 3, 3)
plt.imshow(img_IAAPerspective)
plt.subplot(2, 3, 4)
plt.imshow(img_ShiftScaleRotate)
plt.subplot(2, 3, 5)
plt.imshow(img_augment_flips_color)
plt.subplot(2, 3, 6)
plt.imshow(img_strong_aug)
plt.show()
from albumentations import (
    RandomRotate90, Transpose, ShiftScaleRotate, Blur, OpticalDistortion,
    CLAHE, GaussNoise, MotionBlur, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, MedianBlur, IAAPiecewiseAffine, IAASharpen,
    IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose
)
import numpy as np


def strong_aug(p=0.5):
    """Assemble the "strong" augmentation pipeline.

    The transforms are collected in the order they will be handed to
    ``Compose``; ``p`` is forwarded to ``Compose`` unchanged.
    """
    steps = [RandomRotate90(), Flip(), Transpose()]
    steps.append(OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2))
    steps.append(OneOf([
        MotionBlur(p=0.2),
        MedianBlur(blur_limit=3, p=0.1),
        Blur(blur_limit=3, p=0.1),
    ], p=0.2))
    steps.append(ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2))
    steps.append(OneOf([
        OpticalDistortion(p=0.3),
        GridDistortion(p=0.1),
        IAAPiecewiseAffine(p=0.3),
    ], p=0.2))
    steps.append(OneOf([
        CLAHE(clip_limit=2),
        IAASharpen(),
        IAAEmboss(),
        RandomContrast(),
        RandomBrightness(),
    ], p=0.3))
    steps.append(HueSaturationValue(p=0.3))
    return Compose(steps, p=p)


# Dummy inputs: an all-ones image and mask plus two arbitrary extra entries.
image = np.ones((300, 300, 3), dtype=np.uint8)
mask = np.ones((300, 300), dtype=np.uint8)
whatever_data = "my name"

augmentation = strong_aug(p=0.9)
data = dict(image=image, mask=mask, whatever_data=whatever_data, additional="hello")

# Run the augmentation and read every entry back from the returned dict.
augmented = augmentation(**data)
image, mask, whatever_data, additional = (
    augmented[key] for key in ("image", "mask", "whatever_data", "additional")
)
4. 更新的使用示例
https://github.com/albu/albumentations 更新了幾個關於 albumentations 的使用 Demo.
4.1 綜合示例 - showcase
# Import the required libraries and define the visualization helpers.
# -*- coding: utf-8 -*-
import os
import numpy as np
import cv2
from matplotlib import pyplot as plt
from skimage.color import label2rgb
import albumentations as A
import random

BOX_COLOR = (255, 0, 0)
TEXT_COLOR = (255, 255, 255)


def _bbox_corners(bbox):
    """Convert an ``(x_min, y_min, w, h)`` box into integer ``(x_min, y_min, x_max, y_max)``."""
    x_min, y_min, w, h = bbox
    return int(x_min), int(y_min), int(x_min + w), int(y_min + h)


def visualize_bbox(img, bbox, color=BOX_COLOR, thickness=2, **kwargs):
    """Draw one ``(x_min, y_min, w, h)`` box on ``img`` in place and return ``img``.

    Extra keyword arguments are accepted and ignored so that callers can pass
    the same ``**kwargs`` to every drawing helper.
    """
    x_min, y_min, x_max, y_max = _bbox_corners(bbox)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    return img


def visualize_titles(img, bbox, title, color=BOX_COLOR, thickness=2, font_thickness=2,
                     font_scale=0.35, **kwargs):
    """Draw ``title`` on a filled label strip just above the top-left corner of ``bbox``.

    Args:
        img: image array, modified in place.
        bbox: box as ``(x_min, y_min, w, h)``.
        title: text to render.
        color: fill colour of the label strip (previously ignored in favour of
            the global ``BOX_COLOR``; the default keeps the old appearance).
        thickness: accepted for signature parity with ``visualize_bbox``; unused.
        font_thickness: stroke thickness of the text.
        font_scale: OpenCV font scale.

    Returns:
        The same ``img`` object.
    """
    x_min, y_min, _, _ = _bbox_corners(bbox)
    ((text_width, text_height), _) = cv2.getTextSize(
        title, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)),
                  (x_min + text_width, y_min), color, -1)
    cv2.putText(img, title, (x_min, y_min - int(0.3 * text_height)),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, TEXT_COLOR, font_thickness,
                lineType=cv2.LINE_AA)
    return img


def augment_and_show(aug, image, mask=None, bboxes=None, categories=None,
                     category_id_to_name=None, filename=None,
                     font_scale_orig=0.35, font_scale_aug=0.35,
                     show_title=True, **kwargs):
    """Apply ``aug`` and plot the original next to the augmented result.

    Args:
        aug: callable invoked as
            ``aug(image=..., mask=..., bboxes=..., category_id=...)``.
        image: input image; converted with ``COLOR_BGR2RGB`` before plotting.
        mask: optional mask; when given, a second row with the masks is plotted.
        bboxes: boxes as ``(x_min, y_min, w, h)``; defaults to an empty list.
        categories: one category id per box; defaults to an empty list.
        category_id_to_name: lookup indexed by category id, used for the titles.
        filename: if not ``None``, the figure is also saved to this path.
        font_scale_orig: title font scale on the original image.
        font_scale_aug: title font scale on the augmented image.
        show_title: whether to draw the category titles.
        **kwargs: forwarded to ``visualize_bbox`` and ``visualize_titles``.

    Returns:
        Tuple ``(augmented['image'], augmented['mask'], augmented['bboxes'])``.
    """
    # None replaces the former mutable list defaults, which were shared
    # between calls; fresh lists are created per call instead.
    bboxes = [] if bboxes is None else bboxes
    categories = [] if categories is None else categories
    category_id_to_name = [] if category_id_to_name is None else category_id_to_name

    augmented = aug(image=image, mask=mask, bboxes=bboxes, category_id=categories)

    # cvtColor returns new arrays, so the drawing below does not touch the inputs.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_aug = cv2.cvtColor(augmented['image'], cv2.COLOR_BGR2RGB)

    for bbox in bboxes:
        visualize_bbox(image, bbox, **kwargs)

    for bbox in augmented['bboxes']:
        visualize_bbox(image_aug, bbox, **kwargs)

    if show_title:
        for bbox, cat_id in zip(bboxes, categories):
            visualize_titles(image, bbox, category_id_to_name[cat_id],
                             font_scale=font_scale_orig, **kwargs)
        for bbox, cat_id in zip(augmented['bboxes'], augmented['category_id']):
            visualize_titles(image_aug, bbox, category_id_to_name[cat_id],
                             font_scale=font_scale_aug, **kwargs)

    if mask is None:
        f, ax = plt.subplots(1, 2, figsize=(16, 8))

        ax[0].imshow(image)
        ax[0].set_title('Original image')

        ax[1].imshow(image_aug)
        ax[1].set_title('Augmented image')
    else:
        f, ax = plt.subplots(2, 2, figsize=(16, 16))

        if len(mask.shape) != 3:
            # Single-channel label mask: colourize it, treating label 0 as background.
            mask = label2rgb(mask, bg_label=0)
            mask_aug = label2rgb(augmented['mask'], bg_label=0)
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
            mask_aug = cv2.cvtColor(augmented['mask'], cv2.COLOR_BGR2RGB)

        ax[0, 0].imshow(image)
        ax[0, 0].set_title('Original image')

        ax[0, 1].imshow(image_aug)
        ax[0, 1].set_title('Augmented image')

        ax[1, 0].imshow(mask, interpolation='nearest')
        ax[1, 0].set_title('Original mask')

        ax[1, 1].imshow(mask_aug, interpolation='nearest')
        ax[1, 1].set_title('Augmented mask')

    f.tight_layout()
    plt.show()

    if filename is not None:
        f.savefig(filename)

    return augmented['image'], augmented['mask'], augmented['bboxes']


def find_in_dir(dirname):
    """Return the entries of ``dirname`` in sorted order, each joined with ``dirname``."""
    return [os.path.join(dirname, fname) for fname in sorted(os.listdir(dirname))]
顏色增強 - Color Augmentations
# Colour augmentation pipelines.
random.seed(42)

image = cv2.imread('images/parrot.jpg')

# "light": brightness, contrast, gamma and CLAHE.
# RGBShift, ToGray and HueSaturationValue are left disabled here.
light = A.Compose(
    [
        A.RandomBrightness(p=1),
        A.RandomContrast(p=1),
        A.RandomGamma(p=1),
        A.CLAHE(p=1),
    ],
    p=1,
)

# "medium": CLAHE followed by a hue/saturation/value shift.
medium = A.Compose(
    [
        A.CLAHE(p=1),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=50, val_shift_limit=50, p=1),
    ],
    p=1,
)

# "strong": channel shuffle only.
strong = A.Compose([A.ChannelShuffle(p=1)], p=1)
顏色增強 - light:
# Run the "light" colour pipeline on the image and plot the result.
res = augment_and_show(aug=light, image=image)
顏色增強 - medium:
# Run the "medium" colour pipeline on the image and plot the result.
res = augment_and_show(aug=medium, image=image)
顏色增強 - strong:
# Run the "strong" colour pipeline on the image and plot the result.
res = augment_and_show(aug=strong, image=image)
航空遙感圖像 - Inria Aerial Image Labeling Dataset:
random.seed(42)

# Load the aerial image and its mask, then keep the top-left 1024x1024 window of both.
image = cv2.imread('images/inria/inria_tyrol_w4_image.jpg')
mask = cv2.imread('images/inria/inria_tyrol_w4_mask.tif', cv2.IMREAD_GRAYSCALE)
image = image[:1024, :1024]
mask = mask[:1024, :1024]

light = A.Compose(
    [
        A.RandomSizedCrop((412, 612), 512, 512),  # 512 +/- 100
        A.ShiftScaleRotate(),
        A.RGBShift(),
        A.Blur(),
        A.GaussNoise(),
        A.ElasticTransform(),
        A.Cutout(p=1),
    ],
    p=1,
)

res = augment_and_show(aug=light, image=image, mask=mask)
細胞核分割 - 2018 Data Science Bowl
random.seed(42)

# One image plus one grayscale mask file per nucleus.
image = cv2.imread('images/dsb2018/1a11552569160f0b1ea10bedbd628ce6c14f29edec5092034c2309c556df833e/images/1a11552569160f0b1ea10bedbd628ce6c14f29edec5092034c2309c556df833e.png')
masks = [cv2.imread(x, cv2.IMREAD_GRAYSCALE) for x in find_in_dir('images/dsb2018/1a11552569160f0b1ea10bedbd628ce6c14f29edec5092034c2309c556df833e/masks')]

# A bounding box (x, y, w, h) around the non-zero pixels of every mask.
bboxes = [cv2.boundingRect(cv2.findNonZero(mask)) for mask in masks]

# Merge the per-nucleus masks into one label image.
# Labels start at 1: augment_and_show colourizes the mask with
# label2rgb(..., bg_label=0), so an instance labelled 0 (as the first mask was
# before) is indistinguishable from the background.
# NOTE(review): label_image inherits the dtype of masks[0] (presumably uint8),
# so this assumes no more than 255 instances -- confirm for larger samples.
label_image = np.zeros_like(masks[0])
for i, mask in enumerate(masks, start=1):
    label_image += (mask > 0).astype(np.uint8) * i

light = A.Compose([
    A.RGBShift(),
    A.InvertImg(),
    A.Blur(),
    A.GaussNoise(),
    A.Flip(),
    A.RandomRotate90(),
    A.RandomSizedCrop((512 - 100, 512 + 100), 512, 512),
], bbox_params={'format': 'coco', 'min_area': 1, 'min_visibility': 0.5, 'label_fields': ['category_id']}, p=1)

# Every box belongs to the single class "Nuclei" (category id 0).
label_ids = [0] * len(bboxes)
label_names = ['Nuclei']

res = augment_and_show(light, image, label_image, bboxes, label_ids, label_names, show_title=False)
街景數據 - Mapillary Vistas
from PIL import Image

image = cv2.imread('images/vistas/_HnWguqEbRCphUquTMrCCA.jpg')
labels = cv2.imread('images/vistas/_HnWguqEbRCphUquTMrCCA_labels.png', cv2.IMREAD_COLOR)
instances = np.array(Image.open('images/vistas/_HnWguqEbRCphUquTMrCCA_instances.png'), dtype=np.uint16)

# Keep only instances whose class id (value // 256) is 55, 44 or 50;
# every other pixel is overwritten with the IGNORED marker.
IGNORED = 65 * 256
instances[
    (instances // 256 != 55)
    & (instances // 256 != 44)
    & (instances // 256 != 50)
] = IGNORED

# Work on a 1500x1500 window of the image, the label map and the instance map.
image = image[1000:2500, 1000:2500]
labels = labels[1000:2500, 1000:2500]
instances = instances[1000:2500, 1000:2500]

# One (x, y, w, h) box and one class id per remaining instance.
bboxes = [
    cv2.boundingRect(cv2.findNonZero((instances == instance_id).astype(np.uint8)))
    for instance_id in np.unique(instances)
    if instance_id != IGNORED
]
instance_labels = [
    instance_id // 256
    for instance_id in np.unique(instances)
    if instance_id != IGNORED
]

# Class names, indexed by class id.
titles = [
    "Bird", "Ground Animal", "Curb", "Fence", "Guard Rail", "Barrier", "Wall",
    "Bike Lane", "Crosswalk - Plain", "Curb Cut", "Parking", "Pedestrian Area",
    "Rail Track", "Road", "Service Lane", "Sidewalk", "Bridge", "Building",
    "Tunnel", "Person", "Bicyclist", "Motorcyclist", "Other Rider",
    "Lane Marking - Crosswalk", "Lane Marking - General", "Mountain", "Sand",
    "Sky", "Snow", "Terrain", "Vegetation", "Water", "Banner", "Bench",
    "Bike Rack", "Billboard", "Catch Basin", "CCTV Camera", "Fire Hydrant",
    "Junction Box", "Mailbox", "Manhole", "Phone Booth", "Pothole",
    "Street Light", "Pole", "Traffic Sign Frame", "Utility Pole",
    "Traffic Light", "Traffic Sign (Back)", "Traffic Sign (Front)",
    "Trash Can", "Bicycle", "Boat", "Bus", "Car", "Caravan", "Motorcycle",
    "On Rails", "Other Vehicle", "Trailer", "Truck", "Wheeled Slow",
    "Car Mount", "Ego Vehicle", "Unlabeled",
]

# All three pipelines flip, crop to 600x600 and use the same COCO box settings;
# they differ only in the pixel-level transforms that follow.
light = A.Compose(
    [
        A.HorizontalFlip(p=1),
        A.RandomSizedCrop((800 - 100, 800 + 100), 600, 600),
        A.GaussNoise(var_limit=(100, 150), p=1),
    ],
    bbox_params=dict(format='coco', min_area=1, min_visibility=0.5, label_fields=['category_id']),
    p=1,
)

medium = A.Compose(
    [
        A.HorizontalFlip(p=1),
        A.RandomSizedCrop((800 - 100, 800 + 100), 600, 600),
        A.MotionBlur(blur_limit=37, p=1),
    ],
    bbox_params=dict(format='coco', min_area=1, min_visibility=0.5, label_fields=['category_id']),
    p=1,
)

strong = A.Compose(
    [
        A.HorizontalFlip(p=1),
        A.RandomSizedCrop((800 - 100, 800 + 100), 600, 600),
        A.RGBShift(p=1),
        A.Blur(blur_limit=11, p=1),
        A.RandomBrightness(p=1),
        A.CLAHE(p=1),
    ],
    bbox_params=dict(format='coco', min_area=1, min_visibility=0.5, label_fields=['category_id']),
    p=1,
)
街景數據增強 - light:
# Seed Python's random module, then run the "light" street-scene pipeline
# with boxes and class titles drawn.
random.seed(13)
res = augment_and_show(
    aug=light,
    image=image,
    mask=labels,
    bboxes=bboxes,
    categories=instance_labels,
    category_id_to_name=titles,
    thickness=2,
    font_scale_orig=2,
    font_scale_aug=1,
)
街景數據增強 - medium:
# Seed Python's random module, then run the "medium" street-scene pipeline
# with boxes and class titles drawn.
random.seed(13)
res = augment_and_show(
    aug=medium,
    image=image,
    mask=labels,
    bboxes=bboxes,
    categories=instance_labels,
    category_id_to_name=titles,
    thickness=2,
    font_scale_orig=2,
    font_scale_aug=1,
)
街景數據增強 - strong:
# Seed Python's random module, then run the "strong" street-scene pipeline
# with boxes and class titles drawn.
random.seed(13)
res = augment_and_show(
    aug=strong,
    image=image,
    mask=labels,
    bboxes=bboxes,
    categories=instance_labels,
    category_id_to_name=titles,
    thickness=2,
    font_scale_orig=2,
    font_scale_aug=1,
)
4.2 分類 Classification 示例
https://github.com/albu/albumentations/blob/master/notebooks/example.ipynb
import numpy as np
import cv2
import matplotlib.pyplot as plt

from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur,
    IAAPiecewiseAffine, IAASharpen, IAAEmboss, RandomContrast,
    RandomBrightness, Flip, OneOf, Compose
)


def augment_and_show(aug, image):
    """Apply ``aug`` to ``image`` and draw the result on a new 10x10 figure."""
    transformed = aug(image=image)['image']
    plt.figure(figsize=(10, 10))
    plt.imshow(transformed)


# Load the test image and convert OpenCV's BGR layout to RGB for matplotlib.
image = cv2.cvtColor(cv2.imread('test.jpg'), cv2.COLOR_BGR2RGB)

# Panel 1: the untouched image.
plt.subplot(2, 2, 1)
plt.imshow(image)

# Panel 2: HorizontalFlip.
aug = HorizontalFlip(p=1)
image_aug1 = aug(image=image)['image']
plt.subplot(2, 2, 2)
plt.imshow(image_aug1)

# Panel 3: IAAPerspective.
aug = IAAPerspective(scale=0.2, p=1)
image_aug2 = aug(image=image)['image']
plt.subplot(2, 2, 3)
plt.imshow(image_aug2)

# Panel 4: ShiftScaleRotate.
aug = ShiftScaleRotate(p=1)
image_aug3 = aug(image=image)['image']
plt.subplot(2, 2, 4)
plt.imshow(image_aug3)

plt.show()
def augment_flips_color(p=.5):
    """Assemble the flip/colour pipeline; ``p`` is forwarded to ``Compose``."""
    steps = [CLAHE(), RandomRotate90(), Transpose()]
    steps.append(ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75))
    steps.append(Blur(blur_limit=3))
    steps.append(OpticalDistortion())
    steps.append(GridDistortion())
    steps.append(HueSaturationValue())
    return Compose(steps, p=p)


aug = augment_flips_color(p=1)
image_aug = aug(image=image)['image']

# Original on the left, augmented image on the right.
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.imshow(image_aug)
plt.show()
def strong_aug(p=.5):
    """Assemble the "strong" pipeline; ``p`` is forwarded to ``Compose``.

    The transforms are collected in the order they are handed to ``Compose``.
    """
    steps = [RandomRotate90(), Flip(), Transpose()]
    steps.append(OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2))
    steps.append(OneOf([
        MotionBlur(p=.2),
        MedianBlur(blur_limit=3, p=.1),
        Blur(blur_limit=3, p=.1),
    ], p=0.2))
    steps.append(ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=.2))
    steps.append(OneOf([
        OpticalDistortion(p=0.3),
        GridDistortion(p=.1),
        IAAPiecewiseAffine(p=0.3),
    ], p=0.2))
    steps.append(OneOf([
        CLAHE(clip_limit=2),
        IAASharpen(),
        IAAEmboss(),
        RandomContrast(),
        RandomBrightness(),
    ], p=0.3))
    steps.append(HueSaturationValue(p=0.3))
    return Compose(steps, p=p)


aug = strong_aug(p=1)
image_aug = aug(image=image)['image']

# Original on the left, augmented image on the right.
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.imshow(image_aug)
plt.show()
4.3 檢測 Object Detection 示例
https://github.com/albu/albumentations/blob/master/notebooks/example_bboxes.ipynb
# Import the required libraries and define the visualization helpers.
import os
import numpy as np
import cv2
from matplotlib import pyplot as plt
from urllib.request import urlopen

from albumentations import (
    HorizontalFlip,
    VerticalFlip,
    Resize,
    CenterCrop,
    RandomCrop,
    Crop,
    Compose
)

# Helpers that draw bounding boxes and class labels on an image.
BOX_COLOR = (255, 0, 0)
TEXT_COLOR = (255, 255, 255)


def visualize_bbox(img, bbox, class_id, class_idx_to_name, color=BOX_COLOR, thickness=2):
    """Draw one box and its class name on ``img`` in place and return ``img``.

    Args:
        img: image array to draw on.
        bbox: box as ``(x_min, y_min, w, h)``.
        class_id: key looked up in ``class_idx_to_name``.
        class_idx_to_name: mapping from class id to the label text.
        color: colour of the box outline and of the label background
            (the label background previously always used ``BOX_COLOR``;
            the default keeps that appearance).
        thickness: outline thickness of the box.
    """
    x_min, y_min, w, h = bbox
    x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)

    class_name = class_idx_to_name[class_id]
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
    # Filled strip above the box that serves as the label background.
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)
    cv2.putText(img, class_name, (x_min, y_min - int(0.3 * text_height)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.35, TEXT_COLOR, lineType=cv2.LINE_AA)
    return img


def visualize(annotations, category_id_to_name):
    """Plot ``annotations['image']`` with every annotated box and label drawn on a copy.

    Args:
        annotations: dict with the keys ``'image'``, ``'bboxes'`` and
            ``'category_id'`` (one category id per box, same order).
        category_id_to_name: mapping from category id to label text.
    """
    img = annotations['image'].copy()  # leave the caller's image untouched
    for idx, bbox in enumerate(annotations['bboxes']):
        img = visualize_bbox(img, bbox, annotations['category_id'][idx], category_id_to_name)
    plt.figure(figsize=(12, 12))
    plt.imshow(img)
    # Was `plt.imshow()`, which raises TypeError (missing image argument)
    # and never displays the figure.
    plt.show()
對於檢測問題,必須以指定格式定義 bbox_params. 支持的格式有兩種: coco 和 pascal_voc.
coco 的 bounding box 格式為: [x_min, y_min, width, height], e.g. [97, 12, 150, 200].
pascal_voc 的 bounding box 格式為: [x_min, y_min, x_max, y_max], e.g. [97, 12, 247, 212].
def get_aug(aug, min_area=0., min_visibility=0.):
    """Wrap the transform list ``aug`` in a ``Compose`` configured for COCO-format boxes.

    ``min_area`` and ``min_visibility`` are forwarded inside ``bbox_params``;
    box labels are read from the ``'category_id'`` field.
    """
    bbox_settings = dict(
        format='coco',
        min_area=min_area,
        min_visibility=min_visibility,
        label_fields=['category_id'],
    )
    return Compose(aug, bbox_params=bbox_settings)


def download_image(url):
    """Read the resource at ``url``, decode it with OpenCV and return it converted BGR -> RGB."""
    raw = np.frombuffer(urlopen(url).read(), np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)


image = download_image('http://images.cocodataset.org/train2017/000000386298.jpg')

# Annotations for image 386298 from COCO http://cocodataset.org/#explore?id=386298
annotations = dict(
    image=image,
    bboxes=[[366.7, 80.84, 132.8, 181.84], [5.66, 138.95, 147.09, 164.88]],
    category_id=[18, 17],
)
category_id_to_name = {17: 'cat', 18: 'dog'}
可視化原圖標注:
# Plot the original image with its annotated boxes.
visualize(annotations=annotations, category_id_to_name=category_id_to_name)
垂直翻轉增強:
# VerticalFlip applied to the image together with its boxes.
aug = get_aug(aug=[VerticalFlip(p=1)])
augmented = aug(**annotations)
visualize(annotations=augmented, category_id_to_name=category_id_to_name)
水平翻轉增強:
# HorizontalFlip applied to the image together with its boxes.
aug = get_aug(aug=[HorizontalFlip(p=1)])
augmented = aug(**annotations)
visualize(annotations=augmented, category_id_to_name=category_id_to_name)
Resize 數據增強:
# Resize image and boxes to 256x256.
aug = get_aug(aug=[Resize(height=256, width=256, p=1)])
augmented = aug(**annotations)
visualize(annotations=augmented, category_id_to_name=category_id_to_name)
Albumentations 庫還支持 boxes 裁剪與刪除. 主要包括兩個參數:min_area 和 min_visibility.
默認 min_area 和 min_visibility 值均為 0,故,只有超出圖片尺寸之外的 boxes 才會被刪除.
CenterCrop:
# 300x300 CenterCrop with the default box filter settings.
aug = get_aug(aug=[CenterCrop(height=300, width=300, p=1)])
augmented = aug(**annotations)
visualize(annotations=augmented, category_id_to_name=category_id_to_name)
CenterCrop with default filter:
# 224x224 CenterCrop with the default box filter settings;
# print which category ids are still present afterwards.
aug = get_aug(aug=[CenterCrop(height=224, width=224, p=1)])
augmented = aug(**annotations)
print(augmented['category_id'])
visualize(annotations=augmented, category_id_to_name=category_id_to_name)
CenterCrop + filter with min_area:
# 224x224 CenterCrop, filtering boxes with min_area=4000.
aug = get_aug(aug=[CenterCrop(height=224, width=224, p=1)], min_area=4000)
augmented = aug(**annotations)
visualize(annotations=augmented, category_id_to_name=category_id_to_name)
CenterCrop + filter by visibility:
# Only boxes whose visibility after the transform exceeds the threshold are returned.
aug = get_aug(aug=[CenterCrop(height=300, width=300, p=1)], min_visibility=0.3)
augmented = aug(**annotations)
visualize(annotations=augmented, category_id_to_name=category_id_to_name)
# As the plot shows, after the transform the dog box keeps roughly 25% of its
# original area, which is below 0.3, so it is dropped.
# The cat box keeps roughly 36% of its original area, which is above 0.3,
# so it is kept.
4.3 分割 Segmentation 示例
example_kaggle_salt.ipynb: https://github.com/albu/albumentations/blob/master/notebooks/example_kaggle_salt.ipynb
圖片和數據來自: TGS Salt Identification Challenge: https://www.kaggle.com/c/tgs-salt-identification-challenge
# Import the required libraries and define the visualization helper.
import numpy as np
import cv2
from matplotlib import pyplot as plt

from albumentations import (
    PadIfNeeded,
    HorizontalFlip,
    VerticalFlip,
    CenterCrop,
    Crop,
    Compose,
    Transpose,
    RandomRotate90,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    RandomSizedCrop,
    OneOf,
    CLAHE,
    RandomContrast,
    RandomGamma,
    RandomBrightness
)


def visualize(image, mask, original_image=None, original_mask=None):
    """Plot an image/mask pair, optionally next to the originals.

    With neither ``original_image`` nor ``original_mask`` given, the image and
    the mask are stacked on a 2x1 grid. Otherwise a 2x2 grid is drawn with the
    originals in the left column and the transformed pair in the right column.
    """
    fontsize = 18

    if original_image is not None or original_mask is not None:
        f, ax = plt.subplots(2, 2, figsize=(8, 8))
        panels = (
            ((0, 0), original_image, 'Original image'),
            ((1, 0), original_mask, 'Original mask'),
            ((0, 1), image, 'Transformed image'),
            ((1, 1), mask, 'Transformed mask'),
        )
        for position, picture, title in panels:
            ax[position].imshow(picture)
            ax[position].set_title(title, fontsize=fontsize)
    else:
        f, ax = plt.subplots(2, 1, figsize=(8, 8))
        ax[0].imshow(image)
        ax[1].imshow(mask)

    plt.show()


# Original image and mask.
image = cv2.imread('images/kaggle_salt/0fea4b5049_image.png')
mask = cv2.imread('images/kaggle_salt/0fea4b5049.png', 0)
print(image.shape, mask.shape)

# Remember the original size; later snippets crop back to it.
original_height, original_width = image.shape[:2]

visualize(image, mask)
Padding:
# Pad image and mask with PadIfNeeded(min_height=128, min_width=128).
aug = PadIfNeeded(min_height=128, min_width=128, p=1)
augmented = aug(image=image, mask=mask)
image_padded, mask_padded = (augmented[key] for key in ('image', 'mask'))

print(image_padded.shape, mask_padded.shape)
visualize(image_padded, mask_padded, original_image=image, original_mask=mask)
(128, 128, 3) (128, 128)
CenterCrop 和 Crop:
# Centre-crop the padded pair back to the original size.
aug = CenterCrop(height=original_height, width=original_width, p=1)
augmented = aug(image=image_padded, mask=mask_padded)
image_center_cropped, mask_center_cropped = (augmented[key] for key in ('image', 'mask'))

print(image_center_cropped.shape, mask_center_cropped.shape)

# The crop must reproduce the unpadded image and mask exactly.
assert (image - image_center_cropped).sum() == 0
assert (mask - mask_center_cropped).sum() == 0

visualize(
    image_padded,
    mask_padded,
    original_image=image_center_cropped,
    original_mask=mask_center_cropped,
)
(101, 101, 3) (101, 101)
# Crop window centred inside the 128x128 padded image, with the original size.
x_min = (128 - original_width) // 2
x_max = x_min + original_width
y_min = (128 - original_height) // 2
y_max = y_min + original_height

aug = Crop(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, p=1)
augmented = aug(image=image_padded, mask=mask_padded)
image_cropped, mask_cropped = (augmented[key] for key in ('image', 'mask'))

print(image_cropped.shape, mask_cropped.shape)

# The explicit crop must also reproduce the unpadded image and mask exactly.
assert (image - image_cropped).sum() == 0
assert (mask - mask_cropped).sum() == 0

visualize(image_cropped, mask_cropped, original_image=image_padded, original_mask=mask_padded)
無損變換(Non destructive transformations)
對於衛星和遙感圖像,醫療圖像而言,最好是能夠不增加或者損失圖片信息,進行圖像增強變換.
有 8 種不同的方式來表示平面上的同一個方框.
可以采用 HorizontalFlip, VerticalFlip, Transpose, RandomRotate90 實現這八種數據增強.
水平翻轉:
# HorizontalFlip applied to image and mask together.
aug = HorizontalFlip(p=1)
augmented = aug(image=image, mask=mask)
image_h_flipped, mask_h_flipped = (augmented[key] for key in ('image', 'mask'))

visualize(image_h_flipped, mask_h_flipped, original_image=image, original_mask=mask)
垂直翻轉:
# VerticalFlip applied to image and mask together.
aug = VerticalFlip(p=1)
augmented = aug(image=image, mask=mask)
image_v_flipped, mask_v_flipped = (augmented[key] for key in ('image', 'mask'))

visualize(image_v_flipped, mask_v_flipped, original_image=image, original_mask=mask)
隨機旋轉 90 度:
# Randomly rotate by 0, 90, 180 or 270 degrees.
aug = RandomRotate90(p=1)
augmented = aug(image=image, mask=mask)
image_rot90, mask_rot90 = (augmented[key] for key in ('image', 'mask'))

visualize(image_rot90, mask_rot90, original_image=image, original_mask=mask)
轉置 Transpose:
# Swap the X and Y axes.
aug = Transpose(p=1)
augmented = aug(image=image, mask=mask)
image_transposed, mask_transposed = (augmented[key] for key in ('image', 'mask'))

visualize(image_transposed, mask_transposed, original_image=image, original_mask=mask)
非剛性變換:彈性變換、網格變形、光學畸變(Non-rigid transformations: ElasticTransform, GridDistortion, OpticalDistortion)
在醫學圖像問題中,非剛性變換有助於數據增強.
彈性變換(ElasticTransform):
# ElasticTransform with sigma and alpha_affine expressed as fractions of alpha=120.
aug = ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=1)
augmented = aug(image=image, mask=mask)
image_elastic, mask_elastic = (augmented[key] for key in ('image', 'mask'))

visualize(image_elastic, mask_elastic, original_image=image, original_mask=mask)
網格變形GridDistortion:
# GridDistortion applied to image and mask together.
aug = GridDistortion(p=1)
augmented = aug(image=image, mask=mask)
image_grid, mask_grid = (augmented[key] for key in ('image', 'mask'))

visualize(image_grid, mask_grid, original_image=image, original_mask=mask)
光學畸變OpticalDistortion:
# OpticalDistortion applied to image and mask together.
aug = OpticalDistortion(distort_limit=2, shift_limit=0.5, p=1)
augmented = aug(image=image, mask=mask)
image_optical, mask_optical = (augmented[key] for key in ('image', 'mask'))

visualize(image_optical, mask_optical, original_image=image, original_mask=mask)
RandomSizedCrop:
RandomCrop (https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCrop)和 RandomScale (https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomScale)組合.
# RandomSizedCrop with min_max_height=(50, 101) and the original size as output size.
aug = RandomSizedCrop(
    min_max_height=(50, 101),
    height=original_height,
    width=original_width,
    p=1,
)
augmented = aug(image=image, mask=mask)
image_scaled, mask_scaled = (augmented[key] for key in ('image', 'mask'))

visualize(image_scaled, mask_scaled, original_image=image, original_mask=mask)
數據增強 - Light,non destructive augmentations:
# Light pipeline: vertical flip and 90-degree rotation, each with p=0.5.
aug = Compose([
    VerticalFlip(p=0.5),
    RandomRotate90(p=0.5),
])
augmented = aug(image=image, mask=mask)
image_light, mask_light = (augmented[key] for key in ('image', 'mask'))

visualize(image_light, mask_light, original_image=image, original_mask=mask)
數據增強 - Medium:
# Medium pipeline: crop-or-pad, flip/rotate, then one non-rigid distortion.
aug = Compose([
    OneOf(
        [
            RandomSizedCrop(min_max_height=(50, 101), height=original_height,
                            width=original_width, p=0.5),
            PadIfNeeded(min_height=original_height, min_width=original_width, p=0.5),
        ],
        p=1,
    ),
    VerticalFlip(p=0.5),
    RandomRotate90(p=0.5),
    OneOf(
        [
            ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            GridDistortion(p=0.5),
            OpticalDistortion(p=1, distort_limit=1, shift_limit=0.5),
        ],
        p=0.8,
    ),
])
augmented = aug(image=image, mask=mask)
image_medium, mask_medium = (augmented[key] for key in ('image', 'mask'))

visualize(image_medium, mask_medium, original_image=image, original_mask=mask)
數據增強 - Strong:
添加 CLAHE, RandomBrightness, RandomContrast, RandomGamma 等只對圖片進行非空間變換處理,而不對 mask 處理.
# Strong pipeline: the medium steps (with distort_limit=2) followed by
# CLAHE, contrast, brightness and gamma transforms.
aug = Compose([
    OneOf(
        [
            RandomSizedCrop(min_max_height=(50, 101), height=original_height,
                            width=original_width, p=0.5),
            PadIfNeeded(min_height=original_height, min_width=original_width, p=0.5),
        ],
        p=1,
    ),
    VerticalFlip(p=0.5),
    RandomRotate90(p=0.5),
    OneOf(
        [
            ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            GridDistortion(p=0.5),
            OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5),
        ],
        p=0.8,
    ),
    CLAHE(p=0.8),
    RandomContrast(p=0.8),
    RandomBrightness(p=0.8),
    RandomGamma(p=0.8),
])
augmented = aug(image=image, mask=mask)
image_heavy, mask_heavy = (augmented[key] for key in ('image', 'mask'))

visualize(image_heavy, mask_heavy, original_image=image, original_mask=mask)
4.4 Non-8-bit images 示例
https://github.com/albu/albumentations/blob/master/notebooks/example_16_bit_tiff.ipynb
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

import cv2
import numpy as np
from matplotlib import pyplot as plt

from albumentations import (
    Compose, ToFloat, FromFloat, RandomRotate90, Flip, OneOf, MotionBlur, MedianBlur, Blur,
    ShiftScaleRotate, OpticalDistortion, GridDistortion, RandomContrast, RandomBrightness,
    HueSaturationValue,
)

# Download a 16-bit TIFF image (shipped inside a zip archive).
# The HTTP response, the archive and the archive member are all opened in
# `with` blocks so they are closed once their bytes have been read.
with urlopen("http://www.brucelindbloom.com/downloads/DeltaE_16bit_gamma1.0.tif.zip") as url:
    zipfile = ZipFile(BytesIO(url.read()))
with zipfile:
    zip_names = zipfile.namelist()
    file_name = zip_names.pop()
    with zipfile.open(file_name) as extracted_file:
        # The encoded file is a plain byte stream, so it is viewed as uint8
        # (as download_image does elsewhere in this article). Viewing it as
        # uint16 fails when the payload has an odd number of bytes.
        data = np.frombuffer(extracted_file.read(), np.uint8)

# IMREAD_UNCHANGED keeps the stored bit depth instead of converting to 8 bits.
img = cv2.imdecode(data, cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# OpenCV may read incorrectly some TIFF files.
# Consider using `tifffile` - https://github.com/blink1073/tifffile
print(img.dtype)  # dtype('uint16')

# Divide all values by 65535 so we can display the image using matplotlib
plt.imshow(img / 65535)
plt.show()
def strong_tiff_aug(p=.5):
    """Assemble the augmentation pipeline for the 16-bit TIFF image.

    ``p`` is forwarded to ``Compose``. The transforms are collected in the
    order they are handed to ``Compose``.
    """
    # albumentations supports uint8 and float32 inputs. For float32 all values
    # must lie in [0.0, 1.0], so `ToFloat()` runs first and converts the image
    # to a float32 ndarray.
    # Alternatively the maximum input value can be given explicitly:
    # ToFloat(max_value=65535.0)
    steps = [ToFloat()]

    # The actual augmentations.
    steps.append(RandomRotate90())
    steps.append(Flip())
    steps.append(OneOf([
        MotionBlur(p=0.2),
        MedianBlur(blur_limit=3, p=0.1),
        Blur(blur_limit=3, p=0.1),
    ], p=0.2))
    steps.append(ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=.2))
    steps.append(OneOf([
        OpticalDistortion(p=0.3),
        GridDistortion(p=0.1),
    ], p=0.2))
    steps.append(OneOf([
        RandomContrast(),
        RandomBrightness(),
    ], p=0.3))
    steps.append(HueSaturationValue(hue_shift_limit=20, sat_shift_limit=0.1,
                                    val_shift_limit=0.1, p=0.3))

    # `FromFloat` can convert the augmented image back to the original dtype:
    # FromFloat(dtype='uint16')
    # With `max_value` given, all values are multiplied by it:
    # FromFloat(dtype='uint16', max_value=65535.0)
    return Compose(steps, p=p)


augmentation = strong_tiff_aug(p=0.9)
augmented = augmentation(image=img)

plt.figure(figsize=(14, 14))
plt.imshow(augmented['image'])
plt.show()