卷積神經網絡初體驗——使用pytorch搭建CNN


〇、基本流程

加載數據->搭建模型->訓練->測試

 

一、加載數據

通過使用torch.utils.data.DataLoader和torchvision.datasets兩個模塊可以很方便地去獲取常用數據集(手寫數字MNIST、分類CIFAR),以及將其加載進來。

1.加載內置數據集

 
         
 import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

 1 train_loader = torch.utils.data.DataLoader(
 2     torchvision.datasets.MNIST('mnist_data', train=True, download=True,
 3                                transform=torchvision.transforms.Compose([
 4                                    torchvision.transforms.ToTensor(),
 5                                    torchvision.transforms.Normalize(
 6                                        (0.1307,), (0.3081,))
 7                                ])),
 8     batch_size=batch_size, shuffle=True)
 9 # train 是否為訓練集,         download 數據集不存在時是否下載數據集
10 # ToTensor() 轉換成tensor格式,Normalize() 歸一化,將數據作(data-mean)/std
11 # batch_size 加載一批數量,    shuffle 是否打散數據

2.加載自定義數據集

用torchvision.datasets.ImageFolder加載圖片數據集

 

二、搭建模型

一個模型可以表示為python的一個類,這個類要繼承torch.nn.modules.Module,並且實現forward( )方法

 1 class Lenet5(nn.Module):
 2     """
 3     for CIFAR10
 4     """
 5     def __init__(self):
 6         super(Lenet5, self).__init__()
 7 
 8         # 兩層卷積
 9         self.conv_unit = nn.Sequential(
10             
11             # 3表示input,可以理解為圖片的通道數量,即我的卷積核一次要到幾個tensor去作卷積
12             # 6表示有多少個卷積核
13             # stride表示卷積核移動步長,padding表示邊緣擴充
14             nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),    # 卷積 15             nn.AvgPool2d(kernel_size=2, stride=2, padding=0),      # 池化 16 
17             nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
18             nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
19         )
20 
21         # 3層全連接層
22         self.fc_unit = nn.Sequential(
23             nn.Linear(16*5*5, 120),
24             nn.ReLU(),
25             nn.Linear(120, 84),
26             nn.ReLU(),
27             nn.Linear(84, 10)
28         )
31 
32 
33     def forward(self, x):                # 數據從此進來,經過定義好的各層網絡,最終輸出 34         batchsz = x.size(0)
35         x = self.conv_unit(x)
36         x = x.view(batchsz, 16*5*5)          # 經過卷積層后,對數據維度作處理,以適應全連接層 37         logits = self.fc_unit(x)
38         return logits

 

三、訓練

訓練過程可以認為是對參數優化的過程,通過輸入數據,得到輸出,計算損失(誤差),再經過誤差反向傳播得到梯度信息,以更新參數。

 1     # 實例模型、配置損失函數、優化器
 2     device = torch.device('cuda')                       # 轉為GPU上執行
 3     model = Lenet5().to(device)                         # 實例化模型
 4     criteon = nn.CrossEntropyLoss().to(device)          # 損失函數
 5     optimizer = optim.Adam(model.parameters(), lr=1e-3) # 優化器
 6     print(model)
 7 
 8     # 訓練
 9     for epoch in range(1000):                   # 迭代1000次
10         model.train()                           # 模型切換為訓練模式
11         for batchidx, (x, label) in enumerate(cifar_train):
12             x, label = x.to(device), label.to(device)
13             logist = model(x)                   # 得到模型的輸出
14             loss = criteon(logist, label)       # 計算損失
15             optimizer.zero_grad()               # 舊梯度清零
16             loss.backward()                     # 誤差反向傳播
17             optimizer.step()                    # 梯度更新
18 
19         print(epoch, loss.item())

 

四、測試

當模型訓練完畢后,進行數據測試。

 1      model.eval()                 # 切換為驗證模式  2         with torch.no_grad():            # 不進行梯度更新  3             total_correct = 0                   # 記錄正確的數據量
 4             total_num = 0                       # 記錄總數據量
 5             for x, label in cifar_test:
 6                 x, label = x.to(device), label.to(device)
 7                 logist = model(x)               # 獲得模型輸出
 8                 pred = logist.argmax(dim=1)     # 取值最大的下標,在這里恰好對應圖片標簽
 9                 
10                 # eq(pred, label)表示比較預測值和實際標簽
11                 total_correct += torch.eq(pred, label).float().sum().item()
12                 total_num += x.size(0)
13             acc = total_correct / total_num     # 計算正確率

 

五、其他

1、保存與加載模型

即當模型訓練好之后,將模型保存,下一次可以直接使用。

1 torch.save(model.state_dict(), 'best.mdl')      # 保存模型
2 
3 model.load_state_dict(torch.load('best.mdl'))   # 加載模型

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM