torchvision 之 transforms 模塊詳解


torchvision 是獨立於 PyTorch 的關於圖像操作的一個工具庫,目前包括六個模塊:

   1)torchvision.datasets:幾個常用視覺數據集,可以下載和加載,以及如何編寫自己的 Dataset。

   2)torchvision.models:經典模型,例如 AlexNet、VGG、ResNet 等,以及訓練好的參數。

   3)torchvision.transforms:常用的圖像操作,例隨機切割、旋轉、數據類型轉換、tensor 與 numpy 和 PIL Image 的互換等。

   4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。

   5)torchvision.io:提供輸入輸出的一些操作,目前針對的是視頻的寫入寫出。

   6)torchvision.utils:其他工具,比如產生一個圖像網格等。

這里主要介紹下 torchvision.transforms 模塊。torchvision.transforms 是 pytorch 中的圖像預處理包。一般用 Compose 把多個步驟整合到一起。

"""
transforms: list of Transform objects, 是一個列表
"""
class torchvision.transforms.Compose(transforms)

 事實上,Compose()類會對 transforms 列表里面的 transform 操作進行遍歷。實現的代碼很簡單,截取部分源碼如下:

def __call__(self, img):
    for t in self.transforms:   
        img = t(img)
    return img

transforms 中的常見圖像操作:

1. transforms.ToTensor 

   Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]。

   這個變換改變了圖像的參數順序,最終得到的圖像形狀為 $(C,H,W)$,並轉換為 Tensor 類型,歸一化至 [0,1] 是直接除以 255,每個像素變成一個 32 位

   的浮點類型。

 

2. transforms.Normalize

"""
mean (sequence) – Sequence of means for each channel.
std (sequence) – Sequence of standard deviations for each channel.
"""
torchvision.transforms.Normalize(mean, std)

   當數據量很大的時候,每個通道的數據都可以看成正態分布(大數定律),求出每個通道數據對應的均值和標准差,然后利用這兩個值將每個通道數據的分布

   轉換為標准正態分布。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM