Pytorch下微調網絡模型進行圖像分類


利用ImageNet下的預訓練權重采用遷移學習策略,能夠實現模型快速訓練,提高圖像分類性能。下面以vgg和resnet網絡模型為例,微調最后的分類層進行分類。

 注意,微調只對分類層(也就是全連接層)的參數進行更新,前面的參數需要被凍結。

(1)微調VGG模型進行圖像分類(以vgg16為例)

import torch
import torch.nn as nn
import torchvision.models as models

classes_num = 200 # 數據集的類別數

model = models.vgg16(pretrained=True)
for parameter in model.parameters():
parameter.required_grad = False
model.classifier = nn.Sequential(nn.Linear(512*7*7, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, classes_num))
model = model.cuda()
print(model)

 

(2)微調ResNet模型進行圖像分類(以ResNet-34為例)

import torch
import torch.nn as nn
import torchvision.models as models

classes_num = 200 # 數據集的類別數

model = models.resnet34(pretrained=True)
for parameter in model.parameters():
parameter.required_grad = False
model.classifier = nn.Linear(512, classes_num)
model = model.cuda()
print(model)

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM