1. 准備數據集
1.1 MNIST數據集獲取:
-
torchvision.datasets接口直接下載,該接口可以直接構建數據集,推薦
-
其他途徑下載后,編寫程序進行讀取,然后由Datasets構建自己的數據集
本文使用第一種方法獲取數據集,並使用Dataloader進行按批裝載。如果使用程序下載失敗,請將其他途徑下載的MNIST數據集 [文件] 和 [解壓文件] 放置在 <data/MNIST/raw/> 位置下,本文的程序及文件結構圖如下:
其中,model文件夾用來存儲每個epoch訓練的模型參數,根文件夾下包含model.py用於訓練模型,test.py為測試集測試,show.py為展示部分
1.2 程序部分
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
# 1. 准備數據集
## 1.1 使用torchvision自動下載MNIST數據集
train_data = datasets.MNIST(root='data\\',
train=True,
transform=transforms.ToTensor(),
download=True)
## 1.2 構建數據集裝載器
train_loader = DataLoader(dataset=train_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4)
if __name__ == "__main__":
print("===============數據統計===============")
print("訓練集樣本:",train_data.__len__(), train_data.data.shape)
【代碼解析】
-
root為存放MNIST的路徑,trian=True代表下載的為訓練集和訓練集標簽,False則代表測試集和標簽
-
transforms.ToTensor()表示將shape為(H, W, C)的 numpy 數組或 img 轉為shape為(C, H, W)的tensor,並將數值歸一化為[0,1]
-
download為True則代表自動下載,若該文件夾下已經下載,則直接跳過下載步驟
-
shuffle=True,表示對分好的batch進行洗牌操作,drop_last=True表示對最后不足batch大小的剩余樣本舍去,False表示保留
-
num_works表示每次讀取的進程數,和核心數有關
Dataset和Dataloader詳細說明,請移步:[Pytorch Dataset和Dataloader 學習筆記(二)]
2. 設計網絡結構
2.1 網絡設計
網絡結構如上圖所示,輸入圖像—>卷積1—>池化1—>卷積2—>池化2—>全連接1—>全連接2—>softmax,每次卷積通道數都增加一倍,最后送入全連接層實現分類
2.2 程序部分
# 2. Design model using class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.max_pooling1 = nn.MaxPool2d(2)
self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.max_pooling2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(1568, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.max_pooling1(F.relu(self.conv_layer1(x)))
x = self.max_pooling2(F.relu(self.conv_layer2(x)))
x = x.view(-1, 32*7*7)
x = F.relu(self.fc1(x))
y_hat = self.fc2(x) # CrossEntropyLoss會自動激活最后一層的輸出以及softmax處理
return y_hat
net = Net()
# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
【代碼解析】
-
fc1的1568維度是因為最后一次池化后的shape為32*7*7=1568
-
在最后一層,並沒有進行relu激活以及接入softmax,是因為,在CrossEntropyLoss中會自動激活最后一層的輸出以及softmax處理

CrossEntropyLoss圖參考:《PyTorch深度學習實踐》完結合集
詳細網絡結構搭建說明,請移步:Pytorch線性規划模型 學習筆記(一)
3. 迭代訓練
# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
# 4. Training
if __name__ == "__main__":
print("Training...")
for epoch in range(20):
strat = time.time()
total_correct = 0
for x, y in train_loader:
y_hat = net(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) # 統計當前epoch下的正確個數
loss = criterion(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (float(total_correct) / train_data.__len__())*100
save_path = "model/net" + str(epoch+1) + ".pth"
torch.save(obj=net.state_dict(), f=save_path)
print("epoch:", str(epoch + 1) + "/20",
" \n time:", "%.1f" % (time.time() - strat) + "s"
" train_loss:", loss.item(),
" acc:%.3f%%" % acc,)
print("we are done!")
【代碼解析】
- total_correct變量用於統計每個epoch下正確預測值的個數,每進行epoch進行一次清零
- torch.argmax(y_hat, dim=1)用於選取y_hat下每一行的最大值(每個樣本的最高得分),並返回與y相同維度的tensor
- torch.eq(y_pre, y)用於比較兩個矩陣元素是否相同,相同則返回True,不同則返回False,用於判斷預測值與真實值是否相同
- torch.save保存了每個epoch的網絡權重參數
4. 測試集預測部分
# 測試模型,測試集為test_data
import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net
test_data = datasets.MNIST(root='data\\',
train=False,
transform=transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4)
if __name__ == "__main__":
print("---------------預測分析---------------")
print("測試集樣本:", test_data.__len__(), test_data.data.shape)
model = Net()
model.load_state_dict(torch.load("model/net20.pth"))
model.eval()
total_correct = 0
for x, y in test_loader:
y_hat = model(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y))
acc = (float(total_correct) / test_data.__len__())*100
print("total_test_samples:", test_data.__len__(),
" test_acc:", "%.3f%%" % acc)
經過20個epoch的訓練,在測試集上達到了98.590%的准確率,部分batch真實值與預測值展示如下:


5. 全部代碼
鏈接:鏈接:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw
提取碼:82l4
轉載請說明出處