torchvision 是獨立於 PyTorch 的關於圖像操作的一個工具庫,目前包括六個模塊:
1)torchvision.datasets:幾個常用視覺數據集,可以下載和加載,以及如何編寫自己的 Dataset。
2)torchvision.models:經典模型,例如 AlexNet、VGG、ResNet 等,以及訓練好的參數。
3)torchvision.transforms:常用的圖像操作,例隨機切割、旋轉、數據類型轉換、tensor 與 numpy 和 PIL Image 的互換等。
4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。
5)torchvision.io:提供輸入輸出的一些操作,目前針對的是視頻的寫入寫出。
6)torchvision.utils:其他工具,比如產生一個圖像網格等。
這里主要介紹下 torchvision.transforms 模塊。torchvision.transforms 是 pytorch 中的圖像預處理包。一般用 Compose 把多個步驟整合到一起。
""" transforms: list of Transform objects, 是一個列表 """ class torchvision.transforms.Compose(transforms)
事實上,Compose()類會對 transforms 列表里面的 transform 操作進行遍歷。實現的代碼很簡單,截取部分源碼如下:
def __call__(self, img): for t in self.transforms: img = t(img) return img
transforms 中的常見圖像操作:
1. transforms.ToTensor
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]。
這個變換改變了圖像的參數順序,最終得到的圖像形狀為 $(C,H,W)$,並轉換為 Tensor 類型,歸一化至 [0,1] 是直接除以 255,每個像素變成一個 32 位
的浮點類型。
2. transforms.Normalize
""" mean (sequence) – Sequence of means for each channel. std (sequence) – Sequence of standard deviations for each channel. """ torchvision.transforms.Normalize(mean, std)
當數據量很大的時候,每個通道的數據都可以看成正態分布(大數定律),求出每個通道數據對應的均值和標准差,然后利用這兩個值將每個通道數據的分布
轉換為標准正態分布。