看了Movan大佬的文字教程讓我對pytorch的基本使用有了一定的了解,下面簡單介紹一下二分類用pytorch的基本實現!
希望詳細的注釋能夠對像我一樣剛入門的新手來說有點幫助!
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
n_data = torch.ones(100,2) #生成一個100行2列的全1矩陣
x0 = torch.normal(2*n_data,1)#利用100行兩列的全1矩陣產生一個正態分布的矩陣均值和方差分別是(2*n_data,1)
y0 = torch.zeros(100)#給x0標定標簽確定其分類0
x1 = torch.normal(-2*n_data,1) #利用同樣的方法產生第二個數據類別
y1 = torch.ones(100)#但是x1數據類別的label就標定為1
x = torch.cat((x0,x1),0).type(torch.FloatTensor)#cat方法就是將兩個數據樣本聚合在一起(x0,x1),0這個屬性就是第幾個維度進行聚合
y = torch.cat((y0,y1),).type(torch.LongTensor)#y也是一樣
x = Variable(x)#將它們裝載到Variable的容器里
y = Variable(y)#將它們裝載到Variable的容器里
#plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=y.data.numpy(),s=100,lw=0,cmap='RdYlGn')
#plt.show()
class Net(torch.nn.Module):#開始搭建一個神經網絡
def __init__(self,n_feature,n_hidden,n_output):#神經網絡初始化,設置輸入曾參數,隱藏曾參數,輸出層參數
super(Net,self).__init__()#用super函數調用父類的通用初始化函數初始一下
self.hidden = torch.nn.Linear(n_feature,n_hidden)#設置隱藏層的輸入輸出參數,比如說輸入是n_feature,輸出是n_hidden
self.out = torch.nn.Linear(n_hidden,n_output)#同樣設置輸出層的輸入輸出參數
def forward(self,x):#前向計算過程
x = F.relu(self.hidden(x)) #樣本數據經過隱藏層然后被Relu函數掰彎!
x = self.out(x)經過輸出層返回
return x
net = Net(n_feature=2,n_hidden=10,n_output=2) #two classification has two n_features#實例化一個網絡結構
print(net)
optimizer = torch.optim.SGD(net.parameters(),lr=0.02) #設置優化器參數,lr=0.002指的是學習率的大小
loss_func = torch.nn.CrossEntropyLoss()#損失函數設置為loss_function
plt.ion()
for t in range(100):
out = net(x)#100次迭代輸出
loss = loss_func(out,y)#計算loss為out和y的差異
optimizer.zero_grad()#清除一下上次梯度計算的數值
loss.backward()#進行反向傳播
optimizer.step()#最優化迭代
if t%2 == 0:
plt.cla()
prediction = torch.max(out,1)[1] ##返回每一行中最大值的那個元素,且返回其索引 torch.max()[1], 只返回最大值的每個索引
pred_y = prediction.data.numpy().squeeze()
target_y = y.data.numpy()
plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=pred_y,s=100,lw=0,cmap='RdYlGn')
accuracy = float((pred_y == target_y).astype(int).sum())/float(target_y.size)
plt.text(1.5,-4,'Accuracy=%.2f'%accuracy,fontdict={'size':20,'color':'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
最終運行出來的結果在下面:

