Classifying the CIFAR-10 dataset with PyTorch


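CIFAR-10 (CIFAR stands for Canadian Institute for Advanced Research) is an image-recognition dataset collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. It contains 60000 color images of 32×32 pixels, split into 50000 training images and 10000 test images. There are 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck), with 6000 images per class. Compared with MNIST, the images are in color and noisier, and objects of the same class vary in size, viewing angle, and color.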
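As a quick sanity check of the numbers above, the following pure-Python sketch works out the per-class split and the raw size of the data. The constant names are made up for illustration; it assumes the balanced split described on the dataset's homepage (5000 training and 1000 test images per class) and one unsigned byte per color value, which is how the original CIFAR-10 batches store pixels.

```python
# Basic CIFAR-10 figures stated in the text
NUM_CLASSES = 10
IMAGES_PER_CLASS = 6000
TRAIN_IMAGES = 50000
TEST_IMAGES = 10000
HEIGHT, WIDTH, CHANNELS = 32, 32, 3

total_images = NUM_CLASSES * IMAGES_PER_CLASS
assert total_images == TRAIN_IMAGES + TEST_IMAGES  # 60000 = 50000 + 10000

# The dataset is balanced, so every class contributes equally to both splits
train_per_class = TRAIN_IMAGES // NUM_CLASSES   # 5000
test_per_class = TEST_IMAGES // NUM_CLASSES     # 1000

# Each image is 32 x 32 pixels with 3 color channels, one byte per value
values_per_image = HEIGHT * WIDTH * CHANNELS    # 3072
raw_bytes = total_images * values_per_image     # 184,320,000 bytes

print(total_images, train_per_class, test_per_class, values_per_image)
print("raw size: %.1f MB" % (raw_bytes / 1e6))
```

The whole dataset is therefore only about 184 MB of raw pixel data, small enough to keep in memory on most machines.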

 

The task is to classify the images in this dataset.

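The steps are as follows:
1. Load and preprocess the CIFAR-10 dataset with torchvision
2. Define the network
3. Define the loss function and the optimizer
4. Train the network and update its parameters
5. Test the network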

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv                      # contains many ready-made datasets
import torchvision.transforms as transforms   # image transformation utilities
from torchvision.transforms import ToPILImage
from torch import optim

# 1. Load and preprocess CIFAR-10 with torchvision
show = ToPILImage()  # converts a Tensor to a PIL Image for visualization
# ToTensor() maps pixel values from [0, 255] to [0.0, 1.0];
# Normalize(mean=0.5, std=0.5) then maps [0.0, 1.0] to [-1.0, 1.0] via (x - 0.5) / 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
trainset = tv.datasets.CIFAR10(root='data1/', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)
testset = tv.datasets.CIFAR10(root='data1/', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

(data, label) = trainset[100]
print(classes[label])  # prints "ship"
# (data + 1) / 2 undoes the normalization, mapping [-1, 1] back to [0, 1].
# In a Jupyter notebook the image is displayed when it is the last expression of a cell.
show((data + 1) / 2).resize((100, 100))

dataiter = iter(trainloader)
images, labels = next(dataiter)  # dataiter.next() is not available in recent PyTorch versions
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images + 1) / 2)).resize((400, 100))  # make_grid tiles several images into one

# 2. Define the network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3 input channels (RGB), 6 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 32x32 input -> conv1 -> 28x28 -> pool -> 14x14 -> conv2 -> 10x10 -> pool -> 5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)       # one output per class

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)          # flatten to (batch_size, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                    # raw scores (logits); CrossEntropyLoss applies softmax internally
        return x

net = Net()
print(net)

# 3. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 4. Train the network
# torch.autograd.Variable is deprecated since PyTorch 0.4; plain tensors work directly.
for epoch in range(2):
    running_loss = 0.0
    # enumerate yields (index, value) pairs; its second argument sets the starting index
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()               # clear gradients from the previous step
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()                     # backpropagation
        optimizer.step()                    # update the parameters
        running_loss += loss.item()
        if i % 2000 == 1999:                # print the average loss every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print("----------finished training---------")

# 5. Test the network
dataiter = iter(testloader)
images, labels = next(dataiter)
print('Ground truth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
# Undo the normalization with x / 2 + 0.5, i.e. (x + 1) / 2 (x / 2 - 0.5 would be wrong)
show(tv.utils.make_grid(images / 2 + 0.5)).resize((400, 100))
with torch.no_grad():                        # no gradients are needed for inference
    outputs = net(images)
_, predicted = torch.max(outputs, 1)         # returns the max values and their indices along dim 1
print('Predicted:    ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()  # .item() converts to a Python int
print('Accuracy on the 10000 test images: %d %%' % (100 * correct / total))

# Running on the GPU: the model and the data must be on the same device.
# To train on the GPU, move net once and every batch inside the training loop the same way.
if torch.cuda.is_available():
    device = torch.device('cuda')
    net.to(device)
    images = images.to(device)
    labels = labels.to(device)
    output = net(images)
    loss = criterion(output, labels)
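The `(x + 1) / 2` expressions used for display in the code undo `Normalize(mean=0.5, std=0.5)`. The scalar sketch below applies the same per-channel arithmetic in plain Python; the helper names `normalize` and `denormalize` are made up for illustration, and the inputs are assumed to already be in [0, 1] as produced by `ToTensor`. It also shows that `x / 2 - 0.5` is not the inverse, since it shifts values below zero.

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize in the code

def normalize(x):
    """Same arithmetic Normalize applies to each channel value: (x - mean) / std."""
    return (x - MEAN) / STD

def denormalize(y):
    """Inverse of normalize: y * std + mean, which equals (y + 1) / 2 here."""
    return y * STD + MEAN

pixels = [0.0, 0.25, 0.5, 1.0]            # values as produced by ToTensor
normed = [normalize(p) for p in pixels]    # [-1.0, -0.5, 0.0, 1.0]
restored = [denormalize(v) for v in normed]
wrong = [v / 2 - 0.5 for v in normed]      # [-1.0, -0.75, -0.5, 0.0]: shifted below zero

print(normed, restored, wrong)
```

`ToPILImage` expects float values in [0, 1], so only the correct inverse gives a faithful picture; the shifted version would be clipped to a mostly black image.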
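The `16*5*5` input size of `fc1` comes from tracing how each layer shrinks the 32×32 image. This pure-Python sketch uses the output-size formula out = (in + 2·padding − kernel) // stride + 1 from the PyTorch `Conv2d`/`MaxPool2d` documentation (with dilation 1); the helper names `conv_out` and `lenet_feature_size` are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def lenet_feature_size(size=32):
    """Trace the height/width of a square input through the Net defined in the code."""
    size = conv_out(size, kernel=5)            # conv1: 5x5 kernel, stride 1, no padding
    size = conv_out(size, kernel=2, stride=2)  # max_pool2d, 2x2 window (stride defaults to the window size)
    size = conv_out(size, kernel=5)            # conv2
    size = conv_out(size, kernel=2, stride=2)  # second max_pool2d
    return size

side = lenet_feature_size(32)
print(side, 16 * side * side)  # conv2 has 16 output channels
```

For a 32×32 input this gives 5 and 400, matching `nn.Linear(16*5*5, 120)`. A differently sized input needs a different `fc1`: `lenet_feature_size(28)` returns 4, so a 28×28 image would require `16*4*4` inputs.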
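`nn.CrossEntropyLoss` expects the raw scores returned by `fc3`: it applies log-softmax and then takes the negative log-likelihood of the target class, averaged over the batch by default, while `torch.max(..., 1)` returns the index of the highest score as the prediction. The dependency-free sketch below reproduces both steps on a tiny hand-made batch; the logits are made-up numbers and the helper names are hypothetical.

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    """Average over the batch, like CrossEntropyLoss's default reduction='mean'."""
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

def argmax(values):
    """Index of the largest value, i.e. the predicted class."""
    return max(range(len(values)), key=lambda i: values[i])

batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # made-up scores for 3 classes
targets = [0, 1]
predicted = [argmax(z) for z in batch_logits]        # [0, 2]
correct = sum(p == t for p, t in zip(predicted, targets))
print(mean_cross_entropy(batch_logits, targets), 100 * correct / len(targets))
```

A network that gives all 10 classes the same score has a loss of ln(10) ≈ 2.303, which is a useful reference point when reading the first losses printed during training.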

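If the learning rate is too large, the updates overshoot and the loss has trouble settling near the optimum. So when the dataset is small, keep the learning rate small and train for more epochs.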
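To make the learning-rate remark concrete, here is a toy one-dimensional sketch (not the CIFAR-10 model) that minimizes f(x) = x² using the update rule that PyTorch's `optim.SGD` documents for momentum without dampening or Nesterov: v ← μ·v + g, then x ← x − lr·v. The function name `sgd_on_quadratic` is made up for this illustration.

```python
def sgd_on_quadratic(lr, momentum=0.0, steps=100, x0=1.0):
    """Minimize f(x) = x**2 (gradient 2x) with SGD + momentum; return the final x."""
    x, velocity = x0, 0.0
    for _ in range(steps):
        grad = 2.0 * x
        velocity = momentum * velocity + grad  # momentum buffer, as in optim.SGD
        x = x - lr * velocity
    return x

print(sgd_on_quadratic(lr=0.1))                              # shrinks towards 0
print(sgd_on_quadratic(lr=1.1))                              # overshoots and blows up
print(sgd_on_quadratic(lr=0.001))                            # small steps: still far from 0 after 100 steps
print(sgd_on_quadratic(lr=0.001, momentum=0.9, steps=1000))  # more steps plus momentum: converges
```

Without momentum, each step multiplies x by (1 − 2·lr): 0.8 for lr = 0.1, but −1.2 for lr = 1.1, so the iterate flips sign and grows every step. A very small learning rate is stable but needs many more steps, which is why it is paired with more epochs.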

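This example is a demo from Chen Yun's book on the PyTorch deep learning framework. When running it, watch out for the dataset download: downloading from within the program is often very slow or fails outright, so it is better to download the archive directly from its URL (torchvision fetches https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) with a download manager such as Xunlei (Thunder); it takes about half a minute.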
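After a manual download, save `cifar-10-python.tar.gz` into the `root` folder used in the code (`data1/`) before running it; with `download=True`, torchvision verifies the file's MD5 checksum and extracts it instead of downloading again. The sketch below checks the archive yourself with `hashlib`. The expected MD5 is the `tgz_md5` value listed in torchvision's `CIFAR10` class; treat it as an assumption and compare it with the torchvision version you have installed. The helper names `file_md5` and `archive_ok` are hypothetical.

```python
import hashlib
import os

# Assumption: checksum of cifar-10-python.tar.gz as listed in torchvision's CIFAR10 class (tgz_md5)
EXPECTED_MD5 = "c58f30108f718f92721af3b95e74349a"

def file_md5(path, chunk_size=1 << 20):
    """Compute a file's MD5 hex digest, reading it in chunks to keep memory use low."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def archive_ok(root="data1/", filename="cifar-10-python.tar.gz", expected=EXPECTED_MD5):
    """True if the manually downloaded archive exists in root and its MD5 matches."""
    path = os.path.join(root, filename)
    return os.path.isfile(path) and file_md5(path) == expected

if __name__ == "__main__":
    print("archive ready:", archive_ok())
```

If the check fails, the download was most likely truncated; delete the file and download it again rather than letting torchvision retry the slow download.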

