[Pytorch]Pytorch加載預訓練模型(轉)


轉自:https://blog.csdn.net/Vivianyzw/article/details/81061765

東風的地方

1. 直接加載預訓練模型

在訓練的時候可能需要中斷一下,然后繼續訓練,也就是簡單的從保存的模型中加載參數權重:


   
   
   
           
  1. net = SNet()
  2. net.load_state_dict(torch.load( "model_1599.pkl"))

這種方式是針對於之前保存模型時以保存參數的格式使用的:

torch.save(net.state_dict(), "model/model_1599.pkl")

  
  
  
          

pytorch官網更推薦上述模型保存方法,也據說這種方式比下一種更快一點。

下面介紹第二種模型保存和加載的方式:


   
   
   
           
  1. net = SNet()
  2. torch.save(net, "model_1599.pkl")
  3. snet = torch.load( "model_1599.pkl")

這種方式會將整個網絡保存下來,數據量會更大,會消耗更多的時間,占用內存也更高。

2. 加載一部分預訓練模型

模型可能是一些經典的模型改掉一部分,比如一般算法中提取特征的網絡常見的會直接使用vgg16的features extraction部分,也就是在訓練的時候可以直接加載已經在imagenet上訓練好的預訓練參數,這種方式實現如下:


   
   
   
           
  1. net = SNet()
  2. model_dict = net.state_dict()
  3. vgg16 = models.vgg16(pretrained= True)
  4. pretrained_dict = vgg16.state_dict()
  5. pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  6. model_dict.update(pretrained_dict)
  7. net.load_state_dict(model_dict)
也就是在網絡中state_dict部分,屬於vgg16的,替換成vgg16預訓練模型里的參數(代碼里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不變。

3. 微調經典網絡

因為pytorch中的torchvision給出了很多經典常用模型,並附加了預訓練模型。利用好這些訓練好的基礎網絡可以加快不少自己的訓練速度。

首先比如加載vgg16(帶有預訓練參數的形式):


   
   
   
           
  1. import torchvision.models as models
  2. vgg16 = models.vgg16(pretrained= True)

比如,網絡第一層本來是Conv2d(3, 64, 3, 1, 1),想修改成Conv2d(4, 64, 3, 1 ,1),那直接賦值就可以了:


   
   
   
           
  1. import torch.nn as nn
  2. vgg16.features[ 0]=nn.Conv2d( 4, 64, 3, 1, 1)

4. 修改經典網絡

這個比上面微調修改的地方要多一些,但是想介紹一下這樣的修改方式。

先簡單介紹一下我需要需改的部分,在vgg16的基礎模型下,每一個卷積都要加一個dropout層,並將ReLU激活函數換成PReLU,最后兩層的Pooling層stride改成1。直接上代碼:


   
   
   
           
  1. def feature_layer():
  2. layers = []
  3. pool1 = [ '4', '9', '16']
  4. pool2 = [ '23', '30']
  5. vgg16 = models.vgg16(pretrained= True).features
  6. for name, layer in vgg16._modules.items():
  7. if isinstance(layer, nn.Conv2d):
  8. layers += [layer, nn.Dropout2d( 0.5), nn.PReLU()]
  9. elif name in pool1:
  10. layers += [layer]
  11. elif name == pool2[ 0]:
  12. layers += [nn.MaxPool2d( 2, 1, 1)]
  13. elif name == pool2[ 1]:
  14. layers += [nn.MaxPool2d( 2, 1, 0)]
  15. else:
  16. continue
  17. features = nn.Sequential(*layers)
  18. #feat3 = features[0:24]
  19. return features

大概的思路就是,創建一個新的網絡(layers列表), 遍歷vgg16里每一層,如果遇到卷積層(if isinstance(layer, nn.Conv2d)就先把該層(Conv2d)保持原樣加進去,隨后增加一個dropout層,再加一個PReLU層。然后如果遇到最后兩層pool,就修改響應參數加進去,其他的pool正常加載。 最后將這個layers列表轉成網絡的nn.Sequential的形式,最后返回features。然后再你的新的網絡層就可以用以下方式來加載:


   
   
   
           
  1. class SNet(nn.Module):
  2. def __init__(self):
  3. super(SNet, self).__init__()
  4. self.features = feature_layer()
  5. def forward(self, x):
  6. x = self.features(x)
  7. return x


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM