pytorch解決鳶尾花分類


半年前用numpy寫了個鳶尾花分類200行。。每一步計算都是手寫的  python構建bp神經網絡_鳶尾花分類

現在用pytorch簡單寫一遍,pytorch語法解釋請看上一篇pytorch搭建簡單網絡

 1 import pandas as pd
 2 import torch.nn as nn
 3 import torch
 4 
 5 
 6 class MyNet(nn.Module):
 7     def __init__(self):
 8         super(MyNet, self).__init__()
 9         self.fc = nn.Sequential(
10             nn.Linear(4, 3),
11             nn.Sigmoid(),
12             nn.Linear(3, 3),
13             nn.Sigmoid(),
14             nn.Linear(3, 1),
15         )
16         self.mls = nn.MSELoss()
17         self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)
18 
19     def get_data(self):
20         inputs = []
21         labels = []
22         with open('flower.csv') as file:
23             df = pd.read_csv(file, header=None)
24             x = df.iloc[:, 0:4].values
25             y = df.iloc[:, 4].values
26             for i in range(len(x)):
27                 inputs.append(x[i])
28             for j in range(len(y)):
29                 a = []
30                 a.append(y[j])
31                 labels.append(a)
32 
33         return inputs, labels
34 
35     def forward(self, inputs):
36         out = self.fc(inputs)
37         return out
38 
39     def train(self, x, label):
40         out = self.forward(x)
41         loss = self.mls(out, label)
42         self.opt.zero_grad()
43         loss.backward()
44         self.opt.step()
45 
46     def test(self, x):
47         return self.fc(x)
48 
49 
50 if __name__ == '__main__':
51     net = MyNet()
52     inputs, labels = net.get_data()
53     for i in range(1000):
54         for index, input in enumerate(inputs):
55             # 這里不加.float()會報錯,可能是數據格式的問題吧
56             input = torch.from_numpy(input).float()
57             label = torch.Tensor(labels[index])
58             net.train(input, label)
59     # 簡單測試一下
60     c = torch.Tensor([[5.6, 2.7, 4.2, 1.3]])
61     print(net.test(c))

運行結果趨近於0.5  正確,單純練一下pytorch,就沒有分訓練集,測試集

1 tensor([[0.5392]], grad_fn=<AddmmBackward>)

不用手寫反向傳播和梯度下降 是多么幸福一件事~


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM