超簡單!pytorch入門教程(四):准備圖片數據集


在訓練神經網絡之前,我們必須有數據,作為資深伸手黨,必須知道以下幾個數據提供源:

一、CIFAR-10


CIFAR-10圖片樣本截圖

CIFAR-10是多倫多大學提供的圖片數據庫,圖片分辨率壓縮至32x32,一共有10種圖片分類,均進行了標注。適合監督式學習。CIFAR-10數據下載頁面

二、ImageNet


imagenet首頁

ImageNet首頁

三、ImageFolder


imagefolder首頁

ImageFolder首頁

四、LSUN Classification


LSUN Classification

LSUN 圖片下載地址

五、COCO (Captioning and Detection)


coco首頁

COCO首頁地址

六、我們進入正題

為了方便加載以上五種數據庫的數據,pytorch團隊幫我們寫了一個torchvision包。使用torchvision就可以輕松實現數據的加載和預處理。

我們以使用CIFAR10為例:

導入torchvision的庫:

import torchvision

import torchvision.transforms as transforms  # transforms用於數據預處理

使用datasets.CIFAR10()函數加載數據庫。CIFAR10有60000張圖片,其中50000張是訓練集,10000張是測試集。

#訓練集,將相對目錄./data下的cifar-10-batches-py文件夾中的全部數據(50000張圖片作為訓練數據)加載到內存中,若download為True時,會自動從網上下載數據並解壓trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=None)

下面簡單講解root、train、download、transform這四個參數

1.root,表示cifar10數據的加載的相對目錄

2.train,表示是否加載數據庫的訓練集,false的時候加載測試集

3.download,表示是否自動下載cifar數據集

4.transform,表示是否需要對數據進行預處理,none為不進行預處理

由於美帝路途遙遠,靠命令台進程下載100多M的數據速度很慢,所以我們可以自己去到cifar10的官網上把CIFAR-10 python version下載下來,然后解壓為cifar-10-batches-py文件夾,並復制到相對目錄./data下。(若設置download=True,則程序會自動從網上下載cifar10數據到相對目錄./data下,但這樣小伙伴們可能要等一個世紀了),並對訓練集進行加載(train=True)。

 


如圖所示,在腳本文件下建一個data文件夾,然后把數據集文件夾丟到里面去就好了,注意cifar-10-batches-py文件夾名字不能自己任意改。

我們在寫完上面三行代碼后,在寫一行print一下trainset的大小看看:

print len(trainset)

#結果:50000

我們在訓練神經網絡時,使用的是mini-batch(一次輸入多張圖片),所以我們在使用一個叫DataLoader的工具為我們將50000張圖分成每四張圖一分,一共12500份的數據包。

#將訓練集的50000張圖片划分成12500份,每份4張圖,用於mini-batch輸入。shffule=True在表示不同批次的數據遍歷時,打亂順序(這個需要在訓練神經網絡時再來講)。num_workers=2表示使用兩個子進程來加載數據

import torch

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=2)

那么我們就寫下了這幾行代碼:


print的結果為50000和12500

下面我們需要對數據進行預處理,什么是預處理?為什么要預處理?如果不知道的小盆友可以看看下面幾個鏈接,或許對你有幫助。神經網絡為什么要歸一化深度學習-----數據預處理。還無法理解也沒關系,只要記住,預處理會幫助我們加快神經網絡的訓練。

在pytorch中我們預處理用到了transforms函數:

transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

compose函數會將多個transforms包在一起。

我們的transforms有好幾種,例如transforms.ToTensor(), transforms.Scale()等,完整列表在。好好學習吧!

我只講現在用到了兩種:

1.ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C) 從0到255的值映射到0到1的范圍內,並轉化成Tensor格式。

2.Normalize(mean,std)是通過下面公式實現數據歸一化

channel=(channel-mean)/std

那么經過上面兩個轉換一折騰,我們的數據中的每個值就變成了[-1,1]的數了。


1到22行,我們從硬盤中讀取數據,並將數據預處理(第13行,transform=transform),然后轉換成4張圖為一批的數據結構。26行到47行,為我們顯示出一個圖片例子,可有可無,不再作代碼解釋。

源代碼下載

 

 

 

 

更多torchvision加載其他數據庫方法



作者:Zen_君
鏈接:http://www.jianshu.com/p/8da9b24b2fb6
來源:簡書
著作權歸作者所有。商業轉載請聯系作者獲得授權,非商業轉載請注明出處。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM