7. pytorch 現有網絡模型的使用與修改和模型的保存與加載


  PyTorch是一個開源的Python機器學習庫,基於Torch,用於自然語言處理等應用程序。他提供了大量的模型供我們所使用,如下圖所示:


下面,我們選擇其中一個網絡進行使用,介紹如何使用、並修改 pytorch 本身為我們提供的現有網絡。最后介紹一下模型的保存和修改。

pytorch 現有網絡的使用與修改

  下面我們以 VGG(Very Deep Convolutional Networks for Large-Scale Image Recognition)的使用為例,進行介紹該網絡。

     VGG 16 簡介

  VGG16網絡是14年牛津大學計算機視覺組和Google DeepMind公司研究員一起研發的深度網絡模型。該網絡一共有16個訓練參數的網絡,該網絡的具體網絡結構如下所示:


  不難看出,該網絡主要用於對 224 x 224 的圖像進行 1000 分類。下面我們查看 VGG 在 pytorch 上的官方文檔。

     VGG 16 doc

  從幫助文檔中,我們可以清楚的看到 pytorch 為我們提供了各種版本的 VGG,我們選擇 VGG 16 進行查看。


     VGG16 的簡單使用

  從 vgg 16的幫助文檔可以得知,該模型訓練的數據是 ImageNet,我們進入 torchvision.datasets 查看 ImageNet


但是該數據集實在是太大了,根本下不了,還是不搞了。建立一個該網絡的模型查看參數: ```python import torch import torchvision import torch.nn as nn # import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)

print(vgg_model_original)
print(vgg_model_pretrained)

vgg_model_pretrained.add_module()



<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091424359-823205952.png" style="zoom:100%"/>
</p>
<br/>


<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091441147-673036772.png" style="zoom:100%"/>
</p>
<br/>

仔細查看這個網絡的組成,你可以發現,組成該網絡的一個個小 module 就是我們之前所介紹過的`Conv2d`, `ReLU`, `MaxPool2d`, `Linear`, `Dropout` 等等函數,


### &nbsp;&nbsp;&nbsp;&nbsp; VGG16 模型修改
&nbsp;&nbsp;經過上面的代碼,我們可以較為輕松的看到 VGG16 神經網絡的結構框架,那么我們如何修改別人已經寫好的模型呢?
&nbsp;&nbsp;想要修改別人寫好的模型,主要有一下這幾種操作


<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113092149496-1776250931.png" style="zoom:100%"/>
</p>
<br/>

選中模型,進行 add_module() 或者是直接對模型進行修改

<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094536565-291308585.png" style="zoom:100%"/>
</p>
<br/>


<p align="center">
	<img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094752626-1549811574.png" style="zoom:100%"/>
</p>
<br/>

```python
import torch
import torchvision
import torch.nn as nn
# import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)

print(vgg_model_original)
print(vgg_model_pretrained)
# vgg_model_pretrained.add_module()
vgg_model_original.classifier.add_module('15', nn.Linear(in_features=1000, out_features=10, bias=True))
print(vgg_model_original)
vgg_model_original.classifier[7] = nn.Linear(in_features=1000, out_features=15, bias=True)
print(vgg_model_original)

根據上訴代碼,我們就將 1000 分類問題的網絡修改成了 10 分類或者是 15 分類問題的網絡了。

模型的保存和加載

  當我們利用數據將模型訓練好之后,往往需要保存模型。同時,當我們創建模型的時候,也可能需要加載我們之前已經訓練好的參數,下面我來介紹一下操作方法。

     保留模型結構和模型參數

通過 torch.save() 和 torch.load() 進行保存模型和參數

import torch
import torchvision
import torch.nn as nn
# import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)

torch.save(vgg_model_pretrained, "../../models_param/vgg_model_pretrained.pth")

vgg_model_load = torch.load(f="../../models_param/vgg_model_pretrained.pth")
print(111)

打一個斷點,查看保存模型和加載模型的參數情況



     僅保留模型參數

  同樣是使用 save 和 load 參數,但是用法有所不同,他所保存的是一個模型參數,以字典dict 的形式保存

import torch
import torchvision
import torch.nn as nn
# import torchvision.models

vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)

torch.save(vgg_model_pretrained.state_dict(), "../../models_param/vgg_model_pretrained_method2.pth")

vgg_model_load_method2 = torchvision.models.vgg16()
vgg_model_load_method2.load_state_dict(torch.load("../../models_param/vgg_model_pretrained_method2.pth"))
print("this is a breakpoint!")

斷點查看 save 和 load 模型的參數情況



一模一樣,沒有問題。

Author:luckylight(xyg)
Date:2021/11/13


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM