Deep Learning and PyTorch Hands-On (12): Implementing ResNet-18 and Validating It on the CIFAR-10 Dataset


ResNet illustrated

nn.Module explained in detail

1. Building ResNet-18 in PyTorch

1.1 The ResNet block submodule

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    ResNet block submodule
    """
    def __init__(self, ch_in, ch_out, stride=1):
#         super(ResBlk, self).__init__()  # Python 2 style
        # Python 3 style
        super().__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3,  # channel count unchanged
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        # If the input and output channels differ, or the stride is not 1,
        # the shortcut must transform x so the two branches match
        if ch_out != ch_in or stride != 1:
            # map x: [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1,
                          stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # shortcut connection: element-wise addition, then activation
        out = self.extra(x) + out
        out = F.relu(out)
        return out
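As a quick sanity check (my addition, reusing the ResBlk class above): when ch_in == ch_out and stride == 1, self.extra stays an empty nn.Sequential, i.e. an identity shortcut, and the block preserves the input shape:

# Identity-shortcut case: ch_in == ch_out, stride == 1,
# so self.extra is an empty nn.Sequential (identity mapping)
blk_id = ResBlk(64, 64, stride=1)
t = torch.randn(2, 64, 32, 32)
print(blk_id(t).shape)   # torch.Size([2, 64, 32, 32]) -- shape preserved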

1.2 The ResNet18 main module

class ResNet18(nn.Module):
    """
    Main module
    """
    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # followed by 4 blocks
        self.blk1 = ResBlk(64, 128, stride=2)  # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2) # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2) # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2) # [b, 512, h, w] => [b, 512, h/2, w/2]

        self.outlayer = nn.Linear(512*1*1, 10) # fully connected layer, 10 classes in total

    def forward(self, x):
        x = F.relu(self.conv1(x))

        # [b, 64, h, w] => [b, 512, h/16, w/16]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Whatever the incoming feature-map size, adaptive pooling with
        # output size (1, 1) always produces a (1, 1) feature map
        x = F.adaptive_avg_pool2d(x, [1, 1])   # [b, 512, h, w] => [b, 512, 1, 1]
        # Flatten: the 4-D tensor must be reshaped to 2-D before it can
        # be fed to the fully connected layer
        x = x.view(x.size(0), -1)
        # Fully connected layer
        x = self.outlayer(x)

        return x

Test:

blk = ResBlk(64, 128, stride=4)
tmp = torch.randn(2, 64, 32, 32)
out = blk(tmp)
print('block:', out.shape)                # block: torch.Size([2, 128, 8, 8])

x = torch.randn(2, 3, 32, 32)
model = ResNet18()
out = model(x)
print('resnet:', out.shape)               # resnet: torch.Size([2, 10])
block: torch.Size([2, 128, 8, 8])
resnet: torch.Size([2, 10])
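To back up the adaptive-pooling comment in forward() above, here is a small sketch (my addition): F.adaptive_avg_pool2d with output size (1, 1) collapses any spatial size to 1×1, so the linear layer's input size never changes.

# adaptive_avg_pool2d derives its kernel/stride from the input size,
# so every spatial resolution is reduced to the requested (1, 1)
for hw in [32, 10, 7, 1]:
    t = torch.randn(2, 512, hw, hw)
    print(F.adaptive_avg_pool2d(t, [1, 1]).shape)  # always torch.Size([2, 512, 1, 1])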

2. Training on the CIFAR-10 dataset

  • The dataset is CIFAR-10: 60,000 labeled color images, each 32×32 pixels, divided into 10 classes with 6,000 images per class.

  • 50,000 of these images are used for training (5,000 per class); the remaining 10,000 are used for testing (1,000 per class). The sketch below verifies these counts.
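A minimal sketch (my addition) to verify those counts and list the class names via torchvision, assuming the same 'cifar' download directory used in the training script below:

from torchvision import datasets

# Constructor arguments match the training script below
train_set = datasets.CIFAR10('cifar', train=True, download=True)
test_set = datasets.CIFAR10('cifar', train=False, download=True)
print(len(train_set), len(test_set))  # 50000 10000
print(train_set.classes)              # ['airplane', 'automobile', ..., 'truck']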

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn, optim

from resnet import ResNet18


def main():
    batchsz = 128

    # training set
    cifar_train = datasets.CIFAR10('cifar', train=True, download=True, 
                                   transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                             std=[0.229, 0.224, 0.225])
    ]))
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)


    # test set
    cifar_test = datasets.CIFAR10('cifar', train=False, 
                                  transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                             std=[0.229, 0.224, 0.225])
    ]))
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)  # no need to shuffle the test set


    x, label = next(iter(cifar_train))  # peek at one batch
    # x: torch.Size([128, 3, 32, 32])  label: torch.Size([128])
    print('x:', x.shape, 'label:', label.shape)  

    # Define the model: ResNet
    model = ResNet18()

    # Define the loss function and the optimizer
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    # Train the network
    for epoch in range(1000):

        model.train()                               # training mode
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            # label: [b]

            logits = model(x)                       # logits: [b, 10]
            loss = criteon(logits, label)           # scalar

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())          # loss of the last batch in this epoch


        model.eval()                                # evaluation mode
        with torch.no_grad():

            total_correct = 0                       # number of correct predictions
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32]
                # label: [b]

                logits = model(x)                   # [b, 10]
                pred = logits.argmax(dim=1)         # [b]

                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
  • transforms.Normalize: normalizes an image channel by channel (see the sketch after these notes)

    • output = (input - mean) / std

    • mean: the per-channel means; std: the per-channel standard deviations; inplace: whether to operate in place
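A minimal sketch (my addition, not from the original script) showing that Normalize really applies (input - mean) / std per channel:

import torch
from torchvision import transforms

norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
img = torch.rand(3, 32, 32)            # a fake image tensor in [0, 1]
out = norm(img)
# Manual computation for channel 0 matches Normalize's output
manual = (img[0] - 0.485) / 0.229
print(torch.allclose(out[0], manual))  # True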

  • torch.no_grad() is a context manager; gradients are not tracked for any code wrapped in it.

  • torch.no_grad() can also be used as a decorator.

  • For example, placed above an evaluation function:

@torch.no_grad()
def eval():
	...
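For instance, the test loop from main() could be factored into a decorated helper like this (a hedged sketch with assumed names, not the original author's code):

@torch.no_grad()                      # no gradients are tracked inside
def evaluate(model, loader):
    """Return classification accuracy over a DataLoader (hypothetical helper)."""
    model.eval()                      # use running BN statistics
    total_correct, total_num = 0, 0
    for x, label in loader:
        logits = model(x)
        pred = logits.argmax(dim=1)
        total_correct += torch.eq(pred, label).float().sum().item()
        total_num += x.size(0)
    return total_correct / total_num

# acc = evaluate(model, cifar_test)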

Training is too slow here, so only one epoch is run.
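If a GPU is available, training can be sped up considerably; a minimal sketch of the change (my addition; the script above runs on the CPU):

# Move the model and each batch to the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet18().to(device)

# inside the training / test loops:
#     x, label = x.to(device), label.to(device)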

Output:
Files already downloaded and verified
x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
ResNet18(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (blk1): ResBlk(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (blk2): ResBlk(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (blk3): ResBlk(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (blk4): ResBlk(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (outlayer): Linear(in_features=512, out_features=10, bias=True)
)
0 loss: 1.0541729927062988
0 test acc: 0.5873

