1. 在 PyTorch 上搭建 ResNet-18
1.1 ResNet block子模塊
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Basic ResNet residual block: two 3x3 conv+BN layers plus a shortcut.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first 3x3 conv; the 1x1 shortcut projection uses
            the same stride so both paths produce the same spatial size.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # Second conv keeps both the channel count and the spatial size.
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default.
        self.extra = nn.Sequential()
        # The shortcut must be projected whenever the main path changes the
        # channel count OR the spatial size (stride != 1). Checking only the
        # channels makes e.g. ResBlk(512, 512, stride=2) add a full-size
        # identity to a downsampled output, which either raises a shape error
        # or silently broadcasts when the downsampled map is 1x1.
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w'] matching the main path.
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Return relu(shortcut(x) + bn2(conv2(relu(bn1(conv1(x))))))."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out
        return F.relu(out)
1.2 ResNet18主模塊
class ResNet18(nn.Module):
    """Simplified ResNet classifier: a conv stem, four ResBlk units, a linear head.

    Despite the name, this stacks only four residual blocks (one per stage)
    rather than the eight of the reference ResNet-18.

    Args:
        num_classes: number of output logits; defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem: 3x3 conv with stride 3 and no padding (e.g. 32x32 -> 10x10).
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64),
        )
        # Four residual blocks; each uses a stride-2 first conv, so the
        # spatial size shrinks (h' < h) while the channel count grows.
        self.blk1 = ResBlk(64, 128, stride=2)   # [b, 64, h, w]  => [b, 128, h', w']
        self.blk2 = ResBlk(128, 256, stride=2)  # [b, 128, h, w] => [b, 256, h', w']
        self.blk3 = ResBlk(256, 512, stride=2)  # [b, 256, h, w] => [b, 512, h', w']
        self.blk4 = ResBlk(512, 512, stride=2)  # [b, 512, h, w] => [b, 512, h', w']
        # Fully connected classifier over the 512 globally pooled features.
        self.outlayer = nn.Linear(512 * 1 * 1, num_classes)

    def forward(self, x):
        """Map images [b, 3, H, W] to class logits [b, num_classes]."""
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 512, h', w']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # Global average pooling: whatever the incoming spatial size, the
        # output is 1x1, so the head's input size does not depend on H and W.
        x = F.adaptive_avg_pool2d(x, [1, 1])  # [b, 512, h', w'] => [b, 512, 1, 1]
        # The linear layer expects 2-D input: [b, 512, 1, 1] => [b, 512].
        x = torch.flatten(x, 1)
        return self.outlayer(x)
測試:
# Shape check of a single residual block: stride 4 maps 32x32 down to 8x8
# and the channel count goes from 64 to 128.
blk = ResBlk(64, 128, stride=4)
tmp = torch.randn(2, 64, 32, 32)
out = blk(tmp)
print('block:', out.shape) # block: torch.Size([2, 128, 8, 8])
# End-to-end check: a batch of two 3x32x32 images yields 10 logits per image.
x = torch.randn(2, 3, 32, 32)
model = ResNet18()
out = model(x)
print('resnet:', out.shape) # resnet: torch.Size([2, 10])
block: torch.Size([2, 128, 8, 8])
resnet: torch.Size([2, 10])
2. 訓練 CIFAR-10 數據集
-
所選數據集為 CIFAR-10,該數據集共有 60000 張帶標籤的彩色圖像,圖像尺寸為 32×32,分為 10 個類別,每類 6000 張。
-
其中 50000 張用於訓練(每類 5000 張),另外 10000 張用於測試(每類 1000 張)。
import torch
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from torch import nn, optim
from resnet import ResNet18
def main(epochs=1000, batchsz=128):
    """Train ResNet18 on CIFAR-10, printing the last-batch loss and the
    test-set accuracy after every epoch.

    Args:
        epochs: number of passes over the training set (default 1000).
        batchsz: mini-batch size used by both data loaders (default 128).
    """
    # Train on the GPU when one is available; CPU training is very slow.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The same preprocessing is used for both splits. CIFAR-10 images are
    # already 32x32, so Resize leaves their size unchanged.
    # NOTE(review): mean/std are the ImageNet channel statistics, not
    # CIFAR-10's own; kept to preserve the original preprocessing.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Training set (download=True fetches it into 'cifar' if it is missing).
    cifar_train = datasets.CIFAR10('cifar', train=True, download=True,
                                   transform=transform)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    # Test set: evaluation order does not affect accuracy, so no shuffling.
    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transform)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    # Sanity check of one batch: x: [batchsz, 3, 32, 32], label: [batchsz].
    # DataLoader iterators do not provide `.next()` in current PyTorch;
    # the builtin next() works on every version.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Model, loss function and optimizer.
    model = ResNet18().to(device)
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(epochs):
        model.train()  # training mode: BatchNorm uses batch statistics
        for x, label in cifar_train:
            x, label = x.to(device), label.to(device)
            logits = model(x)              # [b, 10]
            loss = criteon(logits, label)  # scalar
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Loss of the last mini-batch only, as a rough progress indicator.
        print(epoch, 'loss:', loss.item())

        model.eval()  # evaluation mode: BatchNorm uses running statistics
        with torch.no_grad():  # no autograd graph is needed for evaluation
            total_correct = 0  # number of correct predictions
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                pred = model(x).argmax(dim=1)  # [b]
                # [b] vs [b] => scalar count of matches
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
-
transforms.Normalize
:逐 channel 對圖像進行標準化,公式為:
output = (input - mean) / std
-
mean:各通道的均值;std:各通道的標準差;inplace:是否原地操作
-
-
torch.no_grad()
:是一個上下文管理器,被它包住(wrap)的程式碼不會追蹤(track)梯度。
同時
torch.no_grad()
還可以作為裝飾器使用。
例如,在網絡測試函數前加上:
@torch.no_grad()
def eval():
    # Placeholder body: any evaluation code written here runs with gradient
    # tracking disabled by the decorator above.
    # NOTE(review): the name shadows Python's builtin eval(); a name such as
    # `evaluate` would be safer in real code.
    ...
訓練太慢了,所以只訓練了一個 epoch:
view code
Files already downloaded and verified
x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
ResNet18(
(conv1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(3, 3))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blk1): ResBlk(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(blk2): ResBlk(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(blk3): ResBlk(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(blk4): ResBlk(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(extra): Sequential()
)
(outlayer): Linear(in_features=512, out_features=10, bias=True)
)
0 loss: 1.0541729927062988
0 test acc: 0.5873