Pytorch 搭建 LeNet-5 網絡


1 數據集

Mnist 數據集是一個手寫數字圖片數據集,數據集的下載和解讀詳見 Mnist數據集解讀

這里為了對接 pytorch 的神經網絡,需要將數據集制作成可以批量讀取的 tensor 數據。采用 torch.utils.data.Dataset 構建。

data.py

import os
import numpy as np
from torch.utils.data import Dataset
import gzip


class Mnist(Dataset):
    def __init__(self, root, train=True, transform=None):

        # 根據是否為訓練集,得到文件名前綴
        self.file_pre = 'train' if train == True else 't10k'
        self.transform = transform

        # 生成對應數據集的圖片和標簽文件路徑
        self.label_path = os.path.join(root,
                                       '%s-labels-idx1-ubyte.gz' % self.file_pre)
        self.image_path = os.path.join(root,
                                       '%s-images-idx3-ubyte.gz' % self.file_pre)

        # 讀取文件數據,返回圖片和標簽
        self.images, self.labels = self.__read_data__(
            self.image_path,
            self.label_path)

    def __read_data__(self, image_path, label_path):
        # 數據集讀取
        with gzip.open(label_path, 'rb') as lbpath:
            labels = np.frombuffer(lbpath.read(), np.uint8,
                                   offset=8)
        with gzip.open(image_path, 'rb') as imgpath:
            images = np.frombuffer(imgpath.read(), np.uint8,
                                   offset=16).reshape(len(labels), 28, 28)
        return images, labels

    def __getitem__(self, index):
        image, label = self.images[index], int(self.labels[index])

        # 如果需要轉成 tensor 則使用 tansform
        if self.transform is not None:
            image = self.transform(np.array(image))  # 此處需要用 np.array(image)
        return image, label

    def __len__(self):
        return len(self.labels)


if __name__ == '__main__':

    # 生成實例
    train_set = Mnist(
        root=r'H:\\Dataset\\Mnist',
        train=False,
    )

    # 取一組數據並展示
    (data, label) = train_set[0]
    import matplotlib.pyplot as plt
    plt.imshow(data.reshape(28, 28), cmap='gray')
    plt.title('label is :{}'.format(label))
    plt.show()

總體思路:指定Mnist數據集的存儲路徑后,根據是否為訓練集,找到對應的壓縮包(圖像和標簽),解壓文件並讀取數據,利用 Dataset 構造迭代器,從而實現根據索引號返回一組圖像和標簽的數據。

Dataset 是一個抽象類,需要繼承並重寫。其中,根據Mnist數據集文件的命名和存儲結構,構造了一個__read_data__ 私有函數,用來讀取數據,返回圖像和標簽值;在__init__ 中,初始化數據集,獲取到原始的數據;在__getitem__ 中,根據 index ,返回一組圖像和標簽,這里可以對圖像進行變換(可選,例如轉成tensor, 歸一化等等);在 __len__ 中返回數據集的樣本個數。

為了看懂最后輸出的內容,生成了一個實例,取出一組數據,並展示,結果如下:
圖1.1 從數據集中取出一張圖展示

2 模型構建

圖2.1 LeNet-5模型架構圖

LeNet-5 神經網絡一共五層,其中卷積層和池化層可以考慮為一個整體,網絡的結構為 :

輸入 → 卷積 → 池化 → 卷積 → 池化 → 全連接 → 全連接 → 全連接 → 輸出。

pytorch 中,圖像數據集的存儲順序為:(batch, channels, height, width),依次為批大小、通道數、高度、寬度。所以,按照網絡結構,各層的參數和輸入輸出關系,可以整理得到下表:

表2.1 LeNet-5模型參數表
操作 操作參數 輸入/輸出尺寸
input batch: ?
channels: 1
height: 28
width: 28
input:(batch, 1, 28, 28)
output: (batch, 1, 28, 28)
conv1 in_channels: 1
out_channels: 6
kernel_size: 5×5
padding: 0
stride: 1
input: (batch, 1, 28, 28)
output:(batch, 6, 24, 24)
pool1 kernel_size: 2×2 input:(batch, 6, 24, 24)
output:(batch, 6, 12, 12)
conv2 in_channels: 6
out_channels: 16
kernel_size: 5×5
padding: 0
stride: 1
input:(batch, 6, 12, 12)
output:(batch, 16, 8, 8)
pool2 kernel_size: 2×2 input:(batch, 16, 8, 8)
output:(batch, 16, 4, 4)
fc1 in: 16×4×4
out: 120
input:(batch, 16*4*4)
output:(batch, 120)
fc2 in: 120
out: 84
input:(batch, 120)
output:(batch,84)
fc3 in: 84
out:10
input:(batch,84)
output:(batch, 10)

如上表所示,輸入的Mnist數據集是灰度圖,通道為1,長和寬都為28。經過pytorch處理后,可以生成批量數據,從而多出一個batch的維度數據。

這里需要特別注意的是,從第二次卷積池化后,與全連接層fc1進行數據傳遞時,是先把池化pool2的輸出,除了batch之外的其他維度數據,展平到一個維度。然后送入全連接層。而這個數據的大小跟輸入大小有關,因此在設計時,需要仔細推算每一層的輸出。

由上面的分析,就可以搭建網絡了。

model.py

import torch.nn as nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5,self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16*4*4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


if __name__ == '__main__':
    net = LeNet5()
    print(net)

主要思路

網絡的構建需要繼承 torch.nn.Module ,在 _init__ 中和forward 中其實都是可以定義網絡的,但是,一般是在__init__ 里定義一些主要的操作,然后在 forward 里輸入數據,進行前向傳播的表達。其中展平的操作利用 view() 實現,前面的 -1 表示默認,即batch的大小,后面則是其余維度展平后的大小。

為了看清楚網絡的各層參數,將其打印了:

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

3 訓練與測試

神經網絡的訓練主要包括了導入批數據,前向傳播,反向傳播,權重更新,如此循環迭代。遍歷到一定的epoch數量后停止,得到訓練好的模型。

隨后,將圖像送進網絡進行測試即可。

main.py

import torch
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from data import Mnist
from model import LeNet5


# 生成訓練集
train_set = Mnist(
    root=r'H:\\Dataset\\Mnist',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1037,), (0.3081,))
    ])
)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=32,
    shuffle=True
)


# 實例化一個網絡
net = LeNet5()

# 定義損失函數和優化器
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9
)

# 3 訓練模型
loss_list = []
for epoch in range(10):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, start=0):

        images, labels = data                       # 讀取一個batch的數據
        optimizer.zero_grad()                       # 梯度清零,初始化
        outputs = net(images)                       # 前向傳播
        loss = loss_function(outputs, labels)       # 計算誤差
        loss.backward()                             # 反向傳播
        optimizer.step()                            # 權重更新
        running_loss += loss.item()                 # 誤差累計

        # 每300個batch 打印一次損失值
        if batch_idx % 300 == 299:
            print('epoch:{} batch_idx:{} loss:{}'
                  .format(epoch+1, batch_idx+1, running_loss/300))
            loss_list.append(running_loss/300)
            running_loss = 0.0                  #誤差清零

print('Finished Training.')


# 打印損失值變化曲線
import matplotlib.pyplot as plt
plt.plot(loss_list)
plt.title('traning loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()


# 測試
test_set = Mnist(
    root='H:\\Dataset\\Mnist',
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1037,), (0.3081,))
    ])
)
test_loader = DataLoader(
    dataset=test_set,
    batch_size=32,
    shuffle=True
)

correct = 0  # 預測正確數
total = 0    # 總圖片數

for data in test_loader:
    images, labels = data
    outputs = net(images)
    _, predict = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predict == labels).sum()

print('測試集准確率 {}%'.format(100*correct // total))


# 測試自己手動設計的手寫數字
from PIL import Image
I = Image.open('8.jpg')
L = I.convert('L')
plt.imshow(L, cmap='gray')

transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1037,), (0.3081,))
])
    
im = transform(L)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

with torch.no_grad():
    outputs = net(im)
    _, predict = torch.max(outputs.data, 1)
    print(predict)

總體思路

利用 torch.utils.data.DataLoader 從數據集中划分批次,然后打亂順序,每次送入一個批次的數據到神經網絡進行訓練,每300個批次計算一次損失值。訓練結束后,測試了在測試集上的准確率。最后又測試自己手動制作的單一的手寫數字圖像。

結果如下:
圖3.1 loss 迭代曲線

測試集准確率 99%
tensor([8])

圖3.2 測試手寫圖

4 運行界面

圖4.1 運行界面
spyder 簡直是 matlab轉python的絕佳選擇!

5 總結

本次LeNet-5 網絡的是最基礎的,其構建過程是所有其他網絡的基本范式。通過這次搭建,我們熟悉了如何導入自己制作的數據集(雖然是數據是網上下載的,但也需要一定的過程轉成可用的數據格式);了解網絡的搭建方法,分析了其參數和輸入輸出關系,弄懂了其中卷積池化后與全連接的之間維度上的匹配問題;最后成功地實現了較高的識別准確率。

需要改進的地方:

  1. 模型評估改進,希望生成具體的測試集和訓練集損失函數迭代曲線,以及准確率的迭代曲線。

  2. 代碼優化,希望將數據集、訓練、測試、評價、應用等環節模塊化。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM