Batch gradient descent and stochastic gradient descent differ mainly in how theta is updated. Batch gradient descent feeds all of the samples into each update of theta as a single batch, whereas stochastic gradient descent randomly picks just one of the samples at each update and takes the derivative with respect to theta using only that sample. Stochastic gradient descent is therefore faster per update, but it may get stuck in a local optimum.
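Written out as update rules, this is what the code below computes (a sketch of the math only; here X denotes the train_data matrix, y the train_label vector, \alpha the learning rate, and m the number of samples):

    \theta \leftarrow \theta - \frac{\alpha}{m} X^{\top}\left(X\theta - y\right)                    % batch: all m samples enter every update
    \theta \leftarrow \theta - \alpha \left(x^{(k)\top}\theta - y^{(k)}\right) x^{(k)}              % stochastic: one randomly chosen sample k per update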
The code implementation is as follows:
# coding:utf-8
import numpy as np
import random


def BGD(x, y, theta, alpha, m, max_iteration):
    """
    Batch gradient descent (batch_Gradient_Descent)
    :param x: train_data
    :param y: train_label
    :param theta: initial weights
    :param alpha: learning rate
    :param m: number of training samples
    :param max_iteration: number of iterations
    :return: the learned theta
    """
    x_train = x.transpose()
    for i in range(0, max_iteration):
        hypothesis = np.dot(x, theta)  # network output
        loss = hypothesis - y          # loss (error term)

        # descent gradient, averaged over all m samples
        gradient = np.dot(x_train, loss) / m

        # take a gradient step to update theta
        theta = theta - alpha * gradient
    return theta


def SGD(x, y, theta, alpha, m, max_iteration):
    """
    Stochastic gradient descent (stochastic_Gradient_Descent)
    :param x: train_data
    :param y: train_label
    :param theta: initial weights
    :param alpha: learning rate
    :param m: number of training samples
    :param max_iteration: number of iterations
    :return: the learned theta
    """
    data = list(range(m))  # indices 0..m-1 of the training samples
    for i in range(0, max_iteration):
        hypothesis = np.dot(x, theta)
        loss = hypothesis - y  # loss (error term)

        index = random.sample(data, 1)[0]  # randomly pick one index from data

        # descent gradient from the single chosen sample
        gradient = loss[index] * x[index]

        # take a gradient step to update theta
        theta = theta - alpha * gradient
    return theta


# The following version also works and involves less computation,
# because it only evaluates the model on the chosen sample:
#
# data = list(range(m))
# for i in range(0, max_iteration):
#     index = random.sample(data, 1)[0]  # randomly pick one index from data
#
#     hypothesis = np.dot(x[index], theta)  # compute the network output (no activation function)
#     loss = hypothesis - y[index]          # loss (error term)
#
#     # descent gradient from the single chosen sample
#     gradient = loss * x[index]
#
#     # take a gradient step to update theta
#     theta = theta - alpha * gradient
# return theta


def main():
    train_data = np.array([[1, 4, 2], [2, 5, 3], [5, 1, 6], [4, 2, 8]])
    train_label = np.array([19, 26, 19, 20])

    m, n = np.shape(train_data)  # read the matrix dimensions; shape[0] is the length of the first dimension

    # initialize all weights to 1
    theta = np.ones(n)  # np.ones() creates an array of the given shape and type; the default dtype is float64

    max_iteration = 500  # number of iterations
    alpha = 0.01         # learning rate

    # --------------------------------------------------------------------------------
    theta1 = BGD(train_data, train_label, theta, alpha, m, max_iteration)
    print(theta1)

    theta2 = SGD(train_data, train_label, theta, alpha, m, max_iteration)
    print(theta2)


if __name__ == "__main__":
    main()
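After training, the returned weights can be sanity-checked by plugging them back into the linear model and comparing the predictions with the labels. The sketch below is only an illustration: check_fit is a hypothetical helper, and the theta passed in stands for the value returned by BGD or SGD.

    import numpy as np

    def check_fit(theta, x, y):
        """Hypothetical helper: compare predictions x.dot(theta) with the labels y."""
        predictions = np.dot(x, theta)
        return predictions, predictions - y

    # same training set as in main(); replace theta with the value returned by BGD or SGD
    train_data = np.array([[1, 4, 2], [2, 5, 3], [5, 1, 6], [4, 2, 8]])
    train_label = np.array([19, 26, 19, 20])
    theta = np.ones(3)  # stand-in for the trained weights

    preds, residuals = check_fit(theta, train_data, train_label)
    print(preds)      # predictions for the four training samples
    print(residuals)  # should be close to zero once training has converged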
Output: