Python 梯度下降法


題目描述:
自定義一個可微並且存在最小值的一元函數,用梯度下降法求其最小值。並繪制出學習率從0.1到0.9(步長0.1)時,達到最小值時所迭代的次數的關系曲線,根據該曲線給出簡單的分析。

代碼:

# -*- coding: utf-8 -*-
"""
Created on Tue Jun  4 10:19:03 2019

@author: Administrator
"""

import numpy as np
import matplotlib.pyplot as plt
plot_x=np.linspace(-1,6,150)   #在-1到6之間等距的生成150個數
plot_y=(plot_x-2.5)**2+3	   # 同時根據plot_x來生成plot_y(y=(x-2.5)²+3)

plt.plot(plot_x,plot_y)
plt.show()

###定義一個求二次函數導數的函數dJ
def dJ(x):
    return 2*(x-2.5)

###定義一個求函數值的函數J
def J(x):
    try:
        return (x-2.5)**2+3
    except:
        return float('inf')

x=0.0							#隨機選取一個起始點
eta=0.1						    #eta是學習率,用來控制步長的大小
epsilon=1e-8				    #用來判斷是否到達二次函數的最小值點的條件
history_x=[x]                   #用來記錄使用梯度下降法走過的點的X坐標
count=0
min=0
while True:
    gradient=dJ(x)				#梯度(導數)
    last_x=x
    x=x-eta*gradient
    history_x.append(x)
    count=count+1
    if (abs(J(last_x)-J(x)) <epsilon):		#用來判斷是否逼近最低點
        min=x
        break
    
plt.plot(plot_x,plot_y)     
plt.plot(np.array(history_x),J(np.array(history_x)),color='r',marker='*')   #繪制x的軌跡
plt.show()

print'min_x =',(min)
print'min_y =',(J(min))	        #打印到達最低點時y的值
print'count =',(count)

sum_eta=[]
result=[]
for i in range(1,10,1):
    x=0.0							#隨機選取一個起始點
    eta=i*0.1
    sum_eta.append(eta)
    epsilon=1e-8				    #用來判斷是否到達二次函數的最小值點的條件
    num=0
    min=0
    while True:
        gradient=dJ(x)				#梯度(導數)
        last_x=x
        x=x-eta*gradient
        num=num+1
        if (abs(J(last_x)-J(x)) <epsilon):		#用來判斷是否逼近最低點
            min=x
            break
    
    result.append(num)#記錄學習率從0.1到0.9(步長0.1)時,達到最小值時所迭代的次數

plt.scatter(sum_eta,result,c='r')
plt.plot(sum_eta,result,c='r')
plt.title("relation")
plt.xlabel("eta")
plt.ylabel("count")
plt.show
print(result)

  

運行結果:

 

 

結果分析:
函數y=(x-2.5)²+3從學習率和迭代次數的關系圖上我們可以知道當學習率較低時迭代次數較多,隨着學習率的增大,迭代次數開始逐漸減少,當學習率為0.5時迭代次數最少,之后隨着學習率的增加,迭代次數開始增加,當學習率為0.9時迭代次數和0.1時相等。關於0.5成對稱分布。


原文:https://blog.csdn.net/Ferryman23333/article/details/91050219

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM