MNIST數據集上卷積神經網絡的簡單實現(使用PyTorch)


設計的CNN模型包括一個輸入層,輸入的是MNIST數據集中28*28*1的灰度圖

兩個卷積層,

第一層卷積層使用6個3*3的kernel進行filter,步長為1,填充1.這樣得到的尺寸是(28+1*2-3)/1+1=28,即6個28*28的feature map

在后面進行池化,尺寸變為14*14

第二層卷積層使用16個5*5的kernel,步長為1,無填充,得到(14-5)/1+1=10,即16個10*10的feature map

池化后尺寸為5*5

后面加兩層全連接層,第一層將16*5*5=400個神經元線性變換為120個,第二層將120個變為84個

最后的輸出層將84個輸出為10個種類

 代碼如下:

###MNIST數據集上卷積神經網絡的簡單實現###

# 配置庫
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets

# 配置參數
torch.manual_seed(1)  # 設置隨機數種子,確保結果可重復
batch_size = 128  # 批處理大小
learning_rate = 1e-2  # 學習率
num_epoches = 10  # 訓練次數

# 加載MNIST數據
# 下載訓練集MNIST手寫數字訓練集
train_dataset = datasets.MNIST(
    root='./data',  # 數據保持的位置
    train=True,  # 訓練集
    transform=transforms.ToTensor(),  # 一個取值范圍是【0,255】的PIL.Image
    # 轉化成取值范圍是[0,1.0]的torch.FloatTensor
    download=True
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,  # 測試集
    transform=transforms.ToTensor()
)
# 數據的批處理中,尺寸大小為batch_size
# 在訓練集中,shuffle必須設置為True,表示次序是隨機的
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# 創建CNN模型
# 使用一個類來創建,這個模型包括1個輸入層,2個卷積層,2個全連接層和1個輸出層。
# 其中卷積層構成為卷積(conv2d)->激勵函數(ReLU)->池化(MaxPooling)
# 全連接層由線性層(Linear)構成

# 定義卷積神經網絡模型
class Cnn(nn.Module):
    def __init__(self, in_dim, n_class):  # 28*28*1
        super(Cnn, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, 6, 3, stride=1, padding=1),  # 28*28
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # 14*14
            nn.Conv2d(6, 16, 5, stride=1, padding=0),  # 10*10*16
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)  # 5*5*16
        )
        self.fc = nn.Sequential(
            nn.Linear(400, 120),
            nn.Linear(120, 84),
            nn.Linear(84, n_class)
        )

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), 400)  # 400=5*5*16
        out = self.fc(out)
        return out


# 圖片大小是28*28,10是數據的種類
model = Cnn(1, 10)
# 打印模型,呈現網絡結構
print(model)

# 模型訓練,將img\label都用Variable包裝起來,放入model中計算out,最后計算loss和正確率

# 定義loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# 開始訓練
for epoch in range(num_epoches):
    running_loss = 0.0
    running_acc = 0.0
    for i, data in enumerate(train_loader, 1):  # 批處理
        img, label = data
        img = Variable(img)
        label = Variable(label)
        # 前向傳播
        out = model(img)
        loss = criterion(out, label)  # loss
        running_loss += loss.item() * label.size(0)
        # total loss,由於loss是batch取均值的,需要把batch_size乘進去
        _, pred = torch.max(out, 1)  # 預測結果
        num_correct = (pred == label).sum()  # 正確結果的數量
        #accuracy = (pred == label).float().mean()  # 正確率
        running_acc += num_correct.item()  # 正確結果的總數
        # 后向傳播
        optimizer.zero_grad()  # 梯度清零,以免影響其他batch
        loss.backward()  # 后向傳播,計算梯度
        optimizer.step()  # 利用梯度更新W,b參數

    # 打印一個循環后,訓練集合上的loss和正確率
    print('Train {} epoch, Loss:{:.6f},Acc:{:.6f}'.format(epoch + 1, running_loss / (len(train_dataset)),
                                                          running_acc / (len(train_dataset))))

# 在測試集上測試識別率
# 模型測試
model.eval()
# 由於訓練和測試BatchNorm,Dropout配置不同,需要說明是否模型測試
eval_loss = 0
eval_acc = 0
for data in test_loader:  # test set批處理
    img, label = data
    with torch.no_grad():
        img = Variable(img)
    # volatile確定你是否調用.backward(),
    # 測試中不需要label=Variable(label,volatile=True)
    #不需要梯度更新改為with torch.no_grad()
    out = model(img)
    loss = criterion(out, label)  # 計算loss
    eval_loss += loss.item() * label.size(0)  # total loss
    _, pred = torch.max(out, 1)  # 預測結果
    num_correct = (pred == label).sum()  # 正確結果
    eval_acc += num_correct.item()  # 正確結果總數
print('Test loss:{:.6f},Acc:{:.6f}'.format(eval_loss / (len(test_dataset)), eval_acc * 1.0 / (len(test_dataset))))

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM