目錄
這里我不想涉及太多CNN基礎介紹,因為內容太多了,如果有興趣可以參考以下鏈接學習
因為torchvision已經包含了一些model,所以不必在意網絡架構的設計,只需要調用即可
以Alexnet為例
import torchvision.models as models
alexnet = models.alexnet()
print(alexnet) #通過print查看網絡結構
我們可以看到在 classfier下的最后一個Linear的input=4066, output=1000 在本系列中,我們需要的output是9.通過使用add_module
來增加
# 給alexnet增加一個模塊,名為our_output, 輸入為1000(也就是上一層的輸出),輸出為9(本系列需要的結果)
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
完整代碼如下
import torchvision.models as models
import torch.nn as nn
alexnet = models.alexnet()
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
print(alexnet)
輸出結果,可以看到新增的一層
把alexnet封裝成一個函數,方便我們后續調用
def get_alex():
alexnet = models.alexnet()
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return alexnet
按照同樣的方法,可以構建如下的網絡(截止到2021年12月13日torchvision提供的model)
上述提供的網絡太多了,我選擇了其中的幾個網絡。並用構建alexnet的方法構建了幾個網絡
import torch
from torchvision import models
import torch.nn as nn
from conf import config
cf = config()
def get_alex():
alexnet = models.alexnet(pretrained=True)
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return alexnet
def get_res18():
resnet18 = models.resnet18(pretrained=True)
resnet18.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return resnet18
def get_widerest():
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
wide_resnet50_2.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
return wide_resnet50_2
def build_model(model_name):
if model_name == 'resnet18':
model = get_res18().to(cf.DEVICE)
elif model_name == 'alexnet':
model = get_alex().to(cf.DEVICE)
elif model_name == 'wide_resnet50_2':
model = get_widerest().to(cf.DEVICE)
return model
參考鏈接
- torchvision官方鏈接:https://pytorch.org/vision/stable/models.html#classification