torchvision是PyTorch的一個視覺工具包,提供了很多圖像處理的工具。
datasets使用ImageFolder工具(默認PIL Image圖像),獲取定制化的圖片並自動生成類別標簽。如裁剪、旋轉、標准化、歸一化等(使用transforms工具)。
DataLoader可以把datasets數據集打亂,分成batch,並行加速等。
一、datasets獲取原圖或格式化的圖,自動命名標簽
1.1 獲取原圖片
使用torchvision.datasets中的ImageFolder工具,功能:
1、文件夾名就是類別名
2、從上到下自動為文件夾自動創建標簽,0、1、2、...。class_to_idx、imgs屬性可以查看。
3、返回每一幅圖的data、label
from torchvision.datasets import ImageFolder dataset=ImageFolder("E:/data/dogcat_2/train/") #獲取路徑,返回的是所有圖的data、label print(dataset.class_to_idx) #查看類別名,及對應的標簽。 print(dataset.imgs) #查看路徑里所有的圖片,及對應的標簽
print(dataset[0][1]) #第1張圖的label dataset[0][0] #第1張圖的data
1.2 獲取定制化的圖片,啟用ImageFolder的transform參數
使用torchvision的transforms工具,常用功能:
Resize——調整大小
CenterCrop、RandomCrop、RandomSizedCrop——裁剪
Pad——填充
ToTensor——PIL Image轉Tensor,自動[0,255]歸一化到[0,1]
Normalize——標准化,即減均值,除以標准差
ToPILImage——Tensor轉PIL Image
這些操作可以放到一起——Compose
from torchvision import transforms as T #設置格式化條件 transform=T.Compose([T.Resize((200,200)), #縮放為200*200方形 T.RandomHorizontalFlip(), #水平翻轉 T.ToTensor(), #PIL Image轉Tensor,[0,255]自動歸一化為[0,1] T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) #標准化,減均值除標准差 ]) #啟用ImageFolder的transform參數,獲取格式化圖像 dataset=ImageFolder("E:/data/dogcat_2/train/",transform=transform) dataset[0][0].size() #查看圖像大小,3*224*224
#展示圖像,乘標准差加均值,再轉回PIL Image(上述過程的逆過程) show=T.ToPILImage() show(dataset[0][0]*0.5+0.5)
二、DataLoader處理datasets
from torch.utils.data import DataLoader dataloader=DataLoader( dataset,batch_size=4,shuffle=True,num_workers=2 ) #4幅圖為1個batch,打亂,2個進程加速 #### 顯示第1個batch的4幅圖(隨機) from torchvision.transforms import ToPILImage from torchvision.utils import make_grid dataiter = iter(dataloader) #DataLoader是可迭代的 (images, labels) = dataiter.next() #第一個batch print(labels) #打印標簽 show=ToPILImage() show(make_grid(images*0.5+0.5)).resize((4*100,100)) #以100*100展示第一個batch