pytorch實戰(2)-----回歸例子


一、回歸任務介紹:

擬合一個二元函數 y = x ^ 2.

二、步驟:

  1. 導入包
  2. 創建數據
  3. 構建網絡
  4. 設置優化器和損失函數
  5. 前向和后向傳播訓練網絡
  6. 畫圖

三、代碼:

導入包:

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

創建數據

#torch中的數據要是二維的,unsqueeze是將一維數據轉化成二維數據
tmp = torch.linspace(-1,1,100)
x = torch.unsqueeze(tmp,dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())

print(tmp)  #torch.Size([100])
print(x)  #torch.Size([100, 1])
#轉成向量
x,y = Variable(x),Variable(y)

   查看數據圖像:

plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

構建網絡

#Net類繼承了Module這個模塊
class Net(torch.nn.Module):
    def __init__(self,n_feature,n_hidden,n_output):
        #在搭建模型之前需要繼承的一些信息,super表示繼承nn.Module的信息,此步驟必須有
        super(Net,self).__init__()
        self.hidden = torch.nn.Linear(n_feature,n_hidden)
        self.predict = torch.nn.Linear(n_hidden,n_output)
    #神經網絡前向傳遞的一個過程,流程圖
    def forward(self,x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
net = Net(1,10,1)
plt.ion()
plt.show()
#可以看到搭建的圖流程
print(net)
 打印的結果:
Net(
  (hidden): Linear(in_features=1, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=1, bias=True)
)

 設置優化器和損失函數

optimizer = torch.optim.SGD(net.parameters(),lr = 0.5)  #傳入網絡的參數來優化它們
loss_func = torch.nn.MSELoss()

前向和后向傳播訓練網絡

for t in range(100):
    
    #forward
    prediction = net(x)
    loss = loss_func(prediction,y)  #預測值pre在前,實際值y在后,不然結果會不一樣
    
    #backward()
    optimizer.zero_grad()   #梯度全部設為0
    loss.backward()  #loss計算參數的梯度
    optimizer.step()  #采用優化器以lr=0.5來優化梯度
    
###########################以下為可視化過程##################################
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(),y.data.numpy())
        plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
        plt.text(0.5,0,'Loss=%.4f' % loss.data[0],fontdict={'size':20,'color':'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

訓練結果:

第一次:

最后一次:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM