人們都說Chainer是一塊非常靈活you要用的框架,今天接着項目里面的應用,初步接觸一下,漲漲姿勢,直接上源碼吧,看着好理解。其實跟Tensorflow等其他框架都是一個套路,個人感覺更簡潔了。
1 """ 2 測試使用 3 """ 4 import pickle 5 import time 6 import numpy as np 7 import matplotlib.pyplot as plt 8 from chainer import Chain, Variable, optimizers, serializers 9 import chainer.functions as F 10 import chainer.links as L 11 12 # 創建Chainer Variables變]量 13 a = Variable(np.array([3], dtype=np.float32)) 14 b = Variable(np.array([4], dtype=np.float32)) 15 c = a**2 +b**2 16 17 # 5通過data屬性檢查之前定義的變量 18 print('a.data:{0}, b.data{1}, c.data{2}'.format(a.data, b.data, c.data)) 19 20 # 使用backward()方法,對變量c進行反向傳播.對c進行求導 21 c.backward() 22 # 通過在變量中存儲的grad屬性,檢查其導數 23 print('dc/da = {0}, dc/db={1}, dc/dc={2}'.format(a.grad, b.grad, c.grad)) 24 25 # 在chainer中做線性回歸 26 x = 30*np.random.rand(1000).astype(np.float32) 27 y = 7*x + 10 28 y += 10*np.random.randn(1000).astype(np.float32) 29 30 plt.scatter(x, y) 31 plt.xlabel('x') 32 plt.ylabel('y') 33 plt.show() 34 35 36 # 使用chainer做線性回歸 37 38 # 從一個變量到另一個變量建立一個線性連接 39 linear_function = L.Linear(1, 1) 40 # 設置x和y作為chainer變量,以確保能夠變形到特定形態 41 x_var = Variable(x.reshape(1000, -1)) 42 y_var = Variable(y.reshape(1000, -1)) 43 # 建立優化器 44 optimizer = optimizers.MomentumSGD(lr=0.001) 45 optimizer.setup(linear_function) 46 47 48 # 定義一個前向傳播函數,數據作為輸入,線性函數作為輸出 49 def linear_forward(data): 50 return linear_function(data) 51 52 53 # 定義一個訓練函數,給定輸入數據,目標數據,迭代數 54 def linear_train(train_data, train_traget, n_epochs=200): 55 for _ in range(n_epochs): 56 # 得到前向傳播結果 57 output = linear_forward(train_data) 58 # 計算訓練目標數據和實際標數據的損失 59 loss = F.mean_squared_error(train_traget, output) 60 # 在更新之前將梯度取零,線性函數和梯度有非常密切的關系 61 # linear_function.zerograds() 62 linear_function.cleargrads() 63 # 計算並更新所有梯度 64 loss.backward() 65 # 優化器更新 66 optimizer.update() 67 68 69 # 繪制訓練結果 70 plt.scatter(x, y, alpha=0.5) 71 for i in range(150): 72 # 訓練 73 linear_train(x_var, y_var, n_epochs=5) 74 # 預測值 75 y_pred = linear_forward(x_var).data 76 plt.plot(x, y_pred, color=plt.cm.cool(i / 150.), alpha=0.4, lw=3) 77 78 slope = linear_function.W.data[0, 0] # linear_function是之前定義的連接,線性連接有兩個參數W和b,此種形式可以獲取訓練后參數的值,slope是斜率的意思 79 intercept = linear_function.b.data[0] # intercept是截距的意思 80 plt.title("Final Line: {0:.3}x + {1:.3}".format(slope, intercept)) 81 plt.xlabel('x') 82 plt.ylabel('y') 83 plt.show()
