Pytorch卷積神經網絡識別手寫數字集


卷積神經網絡目前被廣泛地用在圖片識別上, 已經有層出不窮的應用, 如果你對卷積神經網絡充滿好奇心,這里為你帶來pytorch實現cnn一些入門的教程代碼

#首先導入包

import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision
import torch.utils.data as Data

 

 

#一、數據准備

#訓練數據:用了torchvision.datasets.MNIST,root是文件路徑,train為True(這是訓練數據),transform是把圖像數據轉換為張量,download(如果本地已有該文件選擇false,沒有就選擇true)

train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=False)

#訓練數據:同上,train為False(這是測試數據)

test_data = torchvision.datasets.MNIST(root='./mnist/',train=False)

# "訓練數據加載器":dataset為訓練數據,shuflle為打亂數據的順序,batch_size是讓數據50個為一組

train_loader = Data.DataLoader(dataset=train_data,shuffle=True,batch_size=50)

test_data.test_data.size()

torch.Size([10000, 28, 28])

#測試數據 test_data下的test_data為測試數據,因為下面conv2d輸入的為4維數據,所以此處用torch.unsqueeze升維

test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)

#測試數據目標值

test_y = test_data.test_labels

 

 

#二、實現模型

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()

    #conv2d參數:輸入1維,輸出16維,5個卷積核(kernel),步長(stride)為1,padding是2(如果想要 con2d 出來的圖片長寬沒有變化, padding=(kernel_size-1)/2 當 stride=1)
    self.conv1 = nn.Sequential(nn.Conv2d(1,16,5,1,2),nn.ReLU(),nn.MaxPool2d(2))
    self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))

    #Linear參數:輸入維數,輸出分的種類數
    self.out = nn.Linear(32*7*7,10)
  def forward(self,x):
    x1 = self.conv1(x)
    x2 = self.conv2(x1)

    #這里給x3降為2維可以讓linear函數使用
    x3 = x2.view(x2.size(0),-1)
    out = self.out(x3)
    return out

 

#自動調整參數,最優化模型

cnn = CNN()

optimizer = torch.optim.Adam(cnn.parameters(),lr = 0.02)
loss_func = nn.CrossEntropyLoss()

 

#三、訓練模型

for step,(x,y) in enumerate(train_loader):
  x = Variable(x)
  y = Variable(y)
  out = cnn(x)
  loss = loss_func(out,y)

  #以下為固定操作,為了訓練每一條數據,不斷調整參數
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

 

 

 

#四、測試

predict = cnn(test_x[:10])
res = torch.max(predict,1)[1]

res #測試數據

tensor([7, 2, 1, 0, 4, 1, 4, 9, 9, 9])

test_y[:10] #真實數據

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

 

#在這里我們發現前十個數據分類准確率達到90

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM