torchvision.transforms模塊介紹


torchvision.transforms模塊

官網地址:https://pytorch.org/docs/stable/torchvision/transforms.html#

torchvision是獨立於PyTorch的關於圖像操作的一個工具庫,目前包括六個模塊:

  • torchvision.datasets:幾個常用視覺數據集,可以下載和加載,以及如何編寫自己的Dataset。
  • torchvision.models:經典模型,例如AlexNet、VGG、ResNet等,以及訓練好的參數。
  • torchvision.transforms:常用的圖像操作,例隨機切割、旋轉、數據類型轉換、tensor與numpy 和PIL Image的互換等。
  • torchvision.ops:提供CV中常用的一些操作,比如NMS、ROI_Align、ROI_Pool等。
  • torchvision.io:提供輸入輸出的一些操作,目前針對的是視頻的寫入寫出。
  • torchvision.utils:其他工具,比如產生一個圖像網格等。

這里主要介紹torchvision.transforms模塊。

torchvision.transforms模塊按照功能,可分為5個部分,所有轉換均可用torchvision.transforms.Compose() 來組合。

  • Transforms on PIL Image:在PIL Image上進行的轉換,比如隨機翻轉、剪切等。
  • Transforms on torch.Tensor:在tensor上進行的轉換,最常用的是歸一化操作transforms.Normalize(mean, std, inplace=False)。
  • Conversion Transforms:PIL.Image/numpy.ndarray與Tensor的相互轉換。
  • Generic Transforms:提供自定義轉換接口。
  • Functional Transforms:不同於前面的轉換,這里可以提供更細粒度的控制,需要自己提供隨機生成器或指定參數。

下面重點介紹PIL.Image/numpy.ndarray與Tensor的相互轉換,歸一化,對PIL.Image進行裁剪、縮放等操作。

1 PIL.Image/numpy.ndarray與Tensor的相互轉換

PIL.Image/numpy.ndarray轉化為Tensor,常常用在訓練模型階段的數據讀取,而Tensor轉化為PIL.Image/numpy.ndarray則用在驗證模型階段的數據輸出。

from torchvision import transforms

transform1 = transforms.Compose([
    transforms.ToTensor() #PIL Image/ndarray (H,W,C) [0,255] to tensor (C,H,W) [0.0,1.0]
    ]) 


##numpy.ndarray與Tensor的相互轉換
import cv2
import numpy as np

img_path = 'Lenna.png'
img1 = cv2.imread(img_path) #img1格式為ndarray  (512,512,3)  uint8  BGR
img_1 = transform1(img1) #tensor  (3,512,512)  float32  范圍是[0.0,1.0]
#將轉換后的tensor還原成ndarray
img_11 = (img_1.numpy() * 255).astype('uint8')
img_11 = np.transpose(img_11, (1,2,0))
#判斷兩者是否相等 
print((img1==img_11).all()) #True
#顯示
cv2.imshow('img_11', img_11)
cv2.waitKey()


##PIL.Image與Tensor的相互轉換
from PIL import Image

img2 = Image.open(img_path) #為PIL圖像對象,即PIL.PngImagePlugin.PngImageFile,默認RGB
img_2 = transform1(img2) #tensor  (3,512,512)  float32  范圍是[0.0,1.0] 
#將轉換后的tensor還原成PIL Image
img_22 = transforms.ToPILImage()(img_2)  #PIL.Image.Image
img_22.show()

2 歸一化 transforms.Normalize

transforms.Normalize使用該公式進行歸一化:channel = (channel-mean) / std.

上面的示例中,將transform1改成下面的transform2,即可將tensor數據的范圍由[0.0,1.0]歸一化到[-1.0, 1.0]

transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
    ])

3 PIL.Image的縮放裁剪等操作

transforms還提供了裁剪縮放等操作,以便進行數據增強。下面就看一個隨機裁剪的例子,這個例子中,仍然使用 Compose 將 transforms 組合在一起。注意,這里對圖像的操作主要是針對PIL.Image對象,所以需要先轉換成PIL.Image格式。

transform3 = transforms.Compose([
    transforms.ToTensor(), 
    transforms.ToPILImage(),
    transforms.RandomCrop((300,300)),
    ])

img = Image.open(img_path)
img3 = transform3(img)
img3.show()

Reference:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM