CNN(Pytorch版)實現GTA5的自動駕駛——第二節(torchvision的model使用)


這里我不想涉及太多CNN基礎介紹,因為內容太多了,如果有興趣可以參考以下鏈接學習

  1. 李沐老師的《動手學深度學習》
  2. B站視頻《動手學深度學習》

因為torchvision已經包含了一些model,所以不必在意網絡架構的設計,只需要調用即可

以Alexnet為例

import torchvision.models as models
alexnet = models.alexnet()
print(alexnet) #通過print查看網絡結構

我們可以看到在 classfier下的最后一個Linear的input=4066, output=1000 在本系列中,我們需要的output是9.通過使用add_module來增加

# 給alexnet增加一個模塊,名為our_output, 輸入為1000(也就是上一層的輸出),輸出為9(本系列需要的結果)
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))

完整代碼如下

import torchvision.models as models
import torch.nn as nn
alexnet = models.alexnet()
alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
print(alexnet)

輸出結果,可以看到新增的一層

把alexnet封裝成一個函數,方便我們后續調用

def get_alex():
    alexnet = models.alexnet()
    alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return alexnet

按照同樣的方法,可以構建如下的網絡(截止到2021年12月13日torchvision提供的model)

上述提供的網絡太多了,我選擇了其中的幾個網絡。並用構建alexnet的方法構建了幾個網絡

import torch
from torchvision import models
import torch.nn as nn
from conf import config

cf = config()


def get_alex():
    alexnet = models.alexnet(pretrained=True)
    alexnet.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return alexnet
def get_res18():
    resnet18 = models.resnet18(pretrained=True)
    resnet18.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return resnet18
def get_widerest():
    wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
    wide_resnet50_2.add_module('our_output', nn.Linear(in_features=1000, out_features=9, bias=True))
    return wide_resnet50_2

def build_model(model_name):
    if model_name == 'resnet18':
        model = get_res18().to(cf.DEVICE)
    elif model_name == 'alexnet':
        model = get_alex().to(cf.DEVICE)
    elif model_name == 'wide_resnet50_2':
        model = get_widerest().to(cf.DEVICE)
    return model

參考鏈接

  1. torchvision官方鏈接:https://pytorch.org/vision/stable/models.html#classification


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM