Pytorch學習:CIFAR-10分類


最近在學習Pytorch,先照着別人的代碼過一遍,加油!!!

 

加載數據集

# 加載數據集及預處理
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch as t
show=ToPILImage() #可以將Tensor轉成Image,方便可視化

划分數據集為訓練集和測試集

#定義對數據的預處理
transform=transforms.Compose([
    transforms.ToTensor(),  #轉為Tensor
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)), #歸一化
])

#訓練集
trainset=tv.datasets.CIFAR10(
    root='/home/cy/data',
    train=True,
    download=True,
    transform=transform
)

trainloader=t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)

testset=tv.datasets.CIFAR10(
    '/home/cy/data/',
    train=False,
    download=True,
    transform=transform
)

testloader=t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2
)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
Files already downloaded and verified
Files already downloaded and verified

可視化看下圖片效果
(data, label)=trainset[100]
print(classes[label])

#(data+1)是為了還原被歸一化的數據
show((data+1)/2).resize((100,100))

展示一個mini-batch中的圖片

dataiter=iter(trainloader)
images,labels=dataiter.next() #返回4張圖片及標簽
print(' '.join('%11s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images+1)/2)).resize((400,100))

 

定義網絡結構,挺方便的

## 定義網絡
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
        
        
    def forward(self,x):
        x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        x=F.max_pool2d(F.relu(self.conv2(x)),2)
        x=x.view(x.size()[0],-1)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x

net=Net()
print(net)
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

定義損失函數和優化器
## 定義損失函數和優化器
from torch import optim
criterion=nn.CrossEntropyLoss()  # 交叉熵損失函數
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #隨機梯度下降,stochastic gradient descent

開始訓練網絡

一共有三個步驟。輸入數據,前向傳播+反向傳播,更新參數

from torch.autograd import Variable

for epoch in range(2):
    running_loss=0.0
    for i,data in enumerate(trainloader,0):
        #輸入數據
        inputs,labels=data
        inputs,labels=Variable(inputs),Variable(labels)
        
        #梯度清零
        optimizer.zero_grad()
        
        #forward+backward
        outputs=net(inputs)
        loss=criterion(outputs,labels)
        loss.backward()
        
        #更新參數
        optimizer.step()
        
        #打印log信息
        #running_loss +=loss.data[0]
        running_loss +=loss.item()
        if i%2000 ==1999:   #每2000個batch打印一次訓練狀態
            print('[%d, %5d] loss: %.3f' \
                 %(epoch+1,i+1,running_loss / 2000))
            running_loss=0.0
print('Finished Training')

 

檢查一下網絡在一個batch內的效果如何

## 檢驗網絡效果
dataiter=iter(testloader)
images,labels=dataiter.next() #一個batch返回4張圖片
print('實際的label: ',' '.join(\
            '%08s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid(images/2 -0.5)).resize((400,100))

# 計算網絡預測的label
outputs=net(Variable(images))
_,predicted=t.max(outputs.data,1)
print('預測結果: ',' '.join('%5s'\
        % classes[predicted[j]] for j in range(4)))

 

測試集上計算正確率

correct=0
total=0
for data in testloader:
    images,labels=data
    outputs=net(Variable(images))
    _,predicted=t.max(outputs.data,1)
    total +=labels.size(0)
    correct +=(predicted==labels).sum()
    
print('1000張測試集中的准確率為: %d  %%' %(100* correct/total))
1000張測試集中的准確率為: 52  %

 

可以看到,在CIFAR-10上的正確率為52%,網絡訓練還是有些效果的。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM