torchvision.transforms模塊
官網地址:https://pytorch.org/docs/stable/torchvision/transforms.html#
torchvision是獨立於PyTorch的關於圖像操作的一個工具庫,目前包括六個模塊:
- torchvision.datasets:幾個常用視覺數據集,可以下載和加載,以及如何編寫自己的Dataset。
- torchvision.models:經典模型,例如AlexNet、VGG、ResNet等,以及訓練好的參數。
- torchvision.transforms:常用的圖像操作,例隨機切割、旋轉、數據類型轉換、tensor與numpy 和PIL Image的互換等。
- torchvision.ops:提供CV中常用的一些操作,比如NMS、ROI_Align、ROI_Pool等。
- torchvision.io:提供輸入輸出的一些操作,目前針對的是視頻的寫入寫出。
- torchvision.utils:其他工具,比如產生一個圖像網格等。
這里主要介紹torchvision.transforms模塊。
torchvision.transforms模塊按照功能,可分為5個部分,所有轉換均可用torchvision.transforms.Compose() 來組合。
- Transforms on PIL Image:在PIL Image上進行的轉換,比如隨機翻轉、剪切等。
- Transforms on torch.Tensor:在tensor上進行的轉換,最常用的是歸一化操作transforms.Normalize(mean, std, inplace=False)。
- Conversion Transforms:PIL.Image/numpy.ndarray與Tensor的相互轉換。
- Generic Transforms:提供自定義轉換接口。
- Functional Transforms:不同於前面的轉換,這里可以提供更細粒度的控制,需要自己提供隨機生成器或指定參數。
下面重點介紹PIL.Image/numpy.ndarray與Tensor的相互轉換,歸一化,對PIL.Image進行裁剪、縮放等操作。
1 PIL.Image/numpy.ndarray與Tensor的相互轉換
PIL.Image/numpy.ndarray轉化為Tensor,常常用在訓練模型階段的數據讀取,而Tensor轉化為PIL.Image/numpy.ndarray則用在驗證模型階段的數據輸出。
from torchvision import transforms
transform1 = transforms.Compose([
transforms.ToTensor() #PIL Image/ndarray (H,W,C) [0,255] to tensor (C,H,W) [0.0,1.0]
])
##numpy.ndarray與Tensor的相互轉換
import cv2
import numpy as np
img_path = 'Lenna.png'
img1 = cv2.imread(img_path) #img1格式為ndarray (512,512,3) uint8 BGR
img_1 = transform1(img1) #tensor (3,512,512) float32 范圍是[0.0,1.0]
#將轉換后的tensor還原成ndarray
img_11 = (img_1.numpy() * 255).astype('uint8')
img_11 = np.transpose(img_11, (1,2,0))
#判斷兩者是否相等
print((img1==img_11).all()) #True
#顯示
cv2.imshow('img_11', img_11)
cv2.waitKey()
##PIL.Image與Tensor的相互轉換
from PIL import Image
img2 = Image.open(img_path) #為PIL圖像對象,即PIL.PngImagePlugin.PngImageFile,默認RGB
img_2 = transform1(img2) #tensor (3,512,512) float32 范圍是[0.0,1.0]
#將轉換后的tensor還原成PIL Image
img_22 = transforms.ToPILImage()(img_2) #PIL.Image.Image
img_22.show()
2 歸一化 transforms.Normalize
transforms.Normalize使用該公式進行歸一化:channel = (channel-mean) / std.
上面的示例中,將transform1改成下面的transform2,即可將tensor數據的范圍由[0.0,1.0]歸一化到[-1.0, 1.0]
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])
3 PIL.Image的縮放裁剪等操作
transforms還提供了裁剪縮放等操作,以便進行數據增強。下面就看一個隨機裁剪的例子,這個例子中,仍然使用 Compose 將 transforms 組合在一起。注意,這里對圖像的操作主要是針對PIL.Image對象,所以需要先轉換成PIL.Image格式。
transform3 = transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop((300,300)),
])
img = Image.open(img_path)
img3 = transform3(img)
img3.show()
Reference: