Chainer的初步學習


  人們都說Chainer是一塊非常靈活you要用的框架,今天接着項目里面的應用,初步接觸一下,漲漲姿勢,直接上源碼吧,看着好理解。其實跟Tensorflow等其他框架都是一個套路,個人感覺更簡潔了。

 

 1 """
 2     測試使用
 3 """
 4 import pickle
 5 import time
 6 import numpy as np
 7 import matplotlib.pyplot as plt
 8 from chainer import Chain, Variable, optimizers, serializers
 9 import chainer.functions as F
10 import chainer.links as L
11 
12 # 創建Chainer Variables變]量
13 a = Variable(np.array([3], dtype=np.float32))
14 b = Variable(np.array([4], dtype=np.float32))
15 c = a**2 +b**2
16 
17 # 5通過data屬性檢查之前定義的變量
18 print('a.data:{0}, b.data{1}, c.data{2}'.format(a.data, b.data, c.data))
19 
20 # 使用backward()方法,對變量c進行反向傳播.對c進行求導
21 c.backward()
22 # 通過在變量中存儲的grad屬性,檢查其導數
23 print('dc/da = {0}, dc/db={1}, dc/dc={2}'.format(a.grad, b.grad, c.grad))
24 
25 # 在chainer中做線性回歸
26 x = 30*np.random.rand(1000).astype(np.float32)
27 y = 7*x + 10
28 y += 10*np.random.randn(1000).astype(np.float32)
29 
30 plt.scatter(x, y)
31 plt.xlabel('x')
32 plt.ylabel('y')
33 plt.show()
34 
35 
36 # 使用chainer做線性回歸
37 
38 # 從一個變量到另一個變量建立一個線性連接
39 linear_function = L.Linear(1, 1)
40 # 設置x和y作為chainer變量,以確保能夠變形到特定形態
41 x_var = Variable(x.reshape(1000, -1))
42 y_var = Variable(y.reshape(1000, -1))
43 # 建立優化器
44 optimizer = optimizers.MomentumSGD(lr=0.001)
45 optimizer.setup(linear_function)
46 
47 
48 # 定義一個前向傳播函數,數據作為輸入,線性函數作為輸出
49 def linear_forward(data):
50     return linear_function(data)
51 
52 
53 # 定義一個訓練函數,給定輸入數據,目標數據,迭代數
54 def linear_train(train_data, train_traget, n_epochs=200):
55     for _ in range(n_epochs):
56         # 得到前向傳播結果
57         output = linear_forward(train_data)
58         # 計算訓練目標數據和實際標數據的損失
59         loss = F.mean_squared_error(train_traget, output)
60         # 在更新之前將梯度取零,線性函數和梯度有非常密切的關系
61         # linear_function.zerograds()
62         linear_function.cleargrads()
63         # 計算並更新所有梯度
64         loss.backward()
65         # 優化器更新
66         optimizer.update()
67 
68 
69 # 繪制訓練結果
70 plt.scatter(x, y, alpha=0.5)
71 for i in range(150):
72     # 訓練
73     linear_train(x_var, y_var, n_epochs=5)
74     # 預測值
75     y_pred = linear_forward(x_var).data
76     plt.plot(x, y_pred, color=plt.cm.cool(i / 150.), alpha=0.4, lw=3)
77 
78 slope = linear_function.W.data[0, 0]        # linear_function是之前定義的連接,線性連接有兩個參數W和b,此種形式可以獲取訓練后參數的值,slope是斜率的意思
79 intercept = linear_function.b.data[0]       # intercept是截距的意思
80 plt.title("Final Line: {0:.3}x + {1:.3}".format(slope, intercept))
81 plt.xlabel('x')
82 plt.ylabel('y')
83 plt.show()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM