PyTorch 實現異或XOR運算


1. 異或運算

 

 2. 實現

 
         
 1 # 利用Pytorch解決XOR問題
 2 import torch
 3 import torch.nn as nn
 4 import torch.nn.functional as F
 5 import torch.optim as optim
 6 import numpy as np
 7 
 8 data = np.array([[1, 0, 1], [0, 1, 1],
 9                  [1, 1, 0], [0, 0, 0]], dtype='float32')
10 x = data[:, :2]
11 y = data[:, 2]
12 
13 
14 # 初始化權重變量
15 def weight_init_normal(m):
16     classname = m.__class__.__name__ #是獲取類名,得到的結果classname是一個字符串
17     if classname.find('Linear') != -1:  #判斷這個類名中,是否包含"Linear"這個字符串,字符串的find()方法,檢索這個字符串中是否包含另一個字符串
18         m.weight.data.normal_(0.0, 1.)
19         m.bias.data.fill_(0.)
20 
21 
22 class XOR(nn.Module):
23     def __init__(self):
24         super(XOR, self).__init__()
25         self.fc1 = nn.Linear(2, 3)   # 隱藏層 3個神經元
26         self.fc2 = nn.Linear(3, 4)   # 隱藏層 4個神經元
27         self.fc3 = nn.Linear(4, 1)   # 輸出層 1個神經元
28 
29     def forward(self, x):
30         h1 = F.sigmoid(self.fc1(x))  # 之前也嘗試過用ReLU作為激活函數, 太容易死亡ReLU了.
31         h2 = F.sigmoid(self.fc2(h1))
32         h3 = F.sigmoid(self.fc3(h2))
33         return h3
34 
35 
36 net = XOR()
37 net.apply(weight_init_normal) #相當於net.weight_init_normal()
38  #apply方式的調用是遞歸的,即net這個類和其子類(如果有),挨個調用一次weight_init_normal()方法。
39 x = torch.Tensor(x.reshape(-1, 2))
40 y = torch.Tensor(y.reshape(-1, 1))
41 
42 # 定義loss function
43 criterion = nn.BCELoss()  # MSE
44 # 定義優化器
45 optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)  # SGD
46 # 訓練
47 for epoch in range(500):
48     optimizer.zero_grad()   # 清零梯度緩存區
49     out = net(x)
50     loss = criterion(out, y)
51     print(loss)
52     loss.backward()
53     optimizer.step()  # 更新
54 
55 # 測試
56 test = net(x)
57 print("input is {}".format(x.detach().numpy()))
58 print('out is {}'.format(test.detach().numpy()))

來源:(1條消息) PyTorch——解決異或問題XOR_我是大黃同學呀的博客-CSDN博客_pytorch 異或

稍微改了一下網絡結構,添加少量注釋,理解第16-17,37行。

結構如下:

 

 A Neural Network Playground (tensorflow.org)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM