1.保存神經網絡
速度較慢
2.只保存神經網絡參數
速度快,這種方式將會提取所有的參數, 然后再放到你的新建網絡中
代碼:
import torch import matplotlib.pyplot as plt import torch.nn.functional as F # 激勵函數都在這 torch.manual_seed(1) # reproducible # 假數據 用了torch.manual_seed(1)所以假數據是固定不變的 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1) y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) def save(): #構建神經網絡 net1=torch.nn.Sequential( torch.nn.Linear(1,10), torch.nn.ReLU(), torch.nn.Linear(10,1) ) #訓練 optimizer = torch.optim.SGD(net1.parameters(),lr=0.5) loss_func=torch.nn.MSELoss() for t in range(100): prediction=net1(x) loss=loss_func(prediction,y) optimizer.zero_grad() loss.backward() optimizer.step() #保存 torch.save(net1, 'net.pkl')#整個網絡 torch.save(net1.state_dict(),'net_params.pkl')#網絡的參數 def restore_net(): # restore entire net1 to net2 net2 = torch.load('net.pkl') prediction = net2(x) def restore_params(): # 新建 net3 net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) # 將保存的參數復制到 net3 net3.load_state_dict(torch.load('net_params.pkl')) prediction = net3(x) save() restore_net() restore_params()
輸出圖: