pytorch實現網絡的保存和提取


代碼如下:

#實現網絡的保存和提取
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt

#設置隨機種子實現結果復現,在神經網絡中,參數默認是進行隨機初始化的。
# 不同的初始化參數往往會導致不同的結果,當得到比較好的結果時我們通常希望這個結果是可以復現的,
# 在pytorch中,通過設置隨機數種子也可以達到這么目的
torch.manual_seed(1)

#生成數據
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x,y = Variable(x, requries_grad=False), Variable(y,requries_grad=False)

#保存網絡
def save():
    net1 = torch.nn.Sequential(   #順序搭建層
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y) #預測值和真實值
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(net1, 'net1.pkl')                       #保存整個神經網絡到net1.pkl中
    torch.save(net1.state_dict(), 'net1_paras.pkl')    #保存網絡里的參數到net1_paras.pkl中

    #畫圖


#提取方式1
#提取整個網絡
def restore_net():
    net2 = torch.load('net1.pkl')    
    
#提取方式2
#先建一個一樣的網絡,再把保存的參數放進去
def restore_paras():
    net3 = torch.nn.Sequential(   #建立和net1一樣的層,不過參數肯定不同
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    net3.load_state_dict(torch.load('net1_paras.pkl'))

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM