最小二乘法擬合(python scipy)


行文思路:

  • 最小二乘法原理介紹
  • 利用 leastsq() 函數進行最小二乘法擬合
  • 擬合注意事項
  • 利用curve_fit 進行最小二乘法擬合
  • 總結:
  • 參考文獻
  • 實現代碼

一,最小二乘法擬合

最小二乘法是一種數學優化技術,它通過最小化誤差的平方和尋找數據的最佳函數匹配。優化是找到最小值或等式的數值解的問題。而線性回歸就是要求樣本回歸函數盡可能好地擬合目標函數值,也就是說,這條直線應該盡可能的處於樣本數據的中心位置。因此,選擇最佳擬合曲線的標准可以確定為:使總的擬合誤差(即總殘差)達到最小。

假設有一組實驗數據(xi,yi ), 事先知道它們之間應該滿足某函數關系yi=f(xi),通過這些已知信息,需要確定函數f的一些參數。例如,如果函數f是線性函數f(x)=kx+b, 那么參數 k和b就是需要確定的值。

如果用p表示函數中需要確定的參數,那么目標就是找到一組p,使得下面的函數S的值最小:

[公式]

當誤差最小的時候可以理解為此時的系數為最佳的擬合狀態。

scipy.optimization 子模塊提供了函數最小值(標量或多維)、曲線擬合和尋找等式的根的有用算法。在optimize模塊中可以使用 leastsq() 對數據進行最小二乘擬合計算。leastsq() 函數傳入誤差計算函數和初始值,該初始值將作為誤差計算函數的第一個參數傳入。計算的結果是一個包含兩個元素的元組,第一個元素是一個數組,表示擬合后的參數;第二個元素如果等於1、2、3、4中的其中一個整數,則擬合成功,否則將會返回 mesg。下面是官方的文檔介紹,只截取了主要的參數部分。

代碼實現:

1,導入模塊:

import numpy as np import matplotlib.pyplot as plt from scipy.optimize import leastsq

2,一元二次方程的參數擬合,首先創建擬合數據。

x = np.linspace(-10,10,100)           # 創建時間序列
p_value = [-2,5,10]                   # 原始數據的參數
noise = np.random.randn(len(x))       # 創建隨機噪聲
y = Fun(p_value,x)+noise*2            # 加上噪聲的序列

3,通過函數定義擬合函數的形式。

這里可以擬合任意的函數形式,這要能把它的表達式給出。

def Fun(p,x):                        # 定義擬合函數形式
    a1,a2,a3 = p
    return a1*x**2+a2*x+a3

4,定義殘差項。

一般最小二乘法是求擬合函數和目標函數差的平方,這里之所以沒有平方是應為在擬合函數的內部進行,這里不顯式的表示。

def error (p,x,y):                   # 擬合殘差
    return Fun(p,x)-y 

5, 進行擬合。

其中參數p0 為最小二乘法擬合的初值,初值的選取對於擬合時間和計算量影響很大,有事並對結果產生一定的影響。args() 中是除了初始值之外error() 中的所有參數的集合輸入。

para =leastsq(error, p0, args=(x,y))  # 進行擬合
y_fitted = Fun (para[0],x)            # 畫出擬合后的曲線

返回參數為一個包含擬合后參數的元組,可以通過中括號[] 取值的方式得到。

6,完整的代碼如下:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import leastsq
 
def Fun(p,x):                        # 定義擬合函數形式
    a1,a2,a3 = p
    return a1*x**2+a2*x+a3
def error (p,x,y):                    # 擬合殘差
    return Fun(p,x)-y 
def main():
    x = np.linspace(-10,10,100)  # 創建時間序列
    p_value = [-2,5,10] # 原始數據的參數
    noise = np.random.randn(len(x))  # 創建隨機噪聲
    y = Fun(p_value,x)+noise*2 # 加上噪聲的序列
    p0 = [0.1,-0.01,100] # 擬合的初始參數設置
    para =leastsq(error, p0, args=(x,y)) # 進行擬合
    y_fitted = Fun (para[0],x) # 畫出擬合后的曲線
 
    plt.figure
    plt.plot(x,y,'r', label = 'Original curve')
    plt.plot(x,y_fitted,'-b', label ='Fitted curve')
    plt.legend()
    plt.show()
    print (para[0])
 
if __name__=='__main__':
   main()

最終擬合的參數結果:

[-1.99437662 5.03789895 10.00150115]

二, 使用curve_fit() 進行擬合

Note:使用 curve_fit(),主要的區別在於擬合函數的定義不同

def Fun(x, a1,a2,a3): # 定義擬合函數形式
    return a1*x**2+a2*x+a3

完整的代碼:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
 
def Fun(x,a1,a2,a3):                   # 定義擬合函數形式
    return a1*x**2+a2*x+a3
def error (p,x,y): # 擬合殘差
 return Fun(p,x)-y
def main():
    x = np.linspace(-10,10,100)       # 創建時間序列
    a1,a2,a3 = [-2,5,10]              # 原始數據的參數
    noise = np.random.randn(len(x))   # 創建隨機噪聲
    y = Fun(x,a1,a2,a3)+noise*2       # 加上噪聲的序列
    para,pcov=curve_fit(Fun,x,y)
    y_fitted = Fun(x,para[0],para[1],para[2]) # 畫出擬合后的曲線
 
    plt.figure
    plt.plot(x,y,'r', label = 'Original curve')
    plt.plot(x,y_fitted,'-b', label ='Fitted curve')
    plt.legend()
    plt.show()
    print (para)
 
if __name__=='__main__':
   main()

擬合結果

最終的擬合結果參數為:

[-2.00309373 5.00945061 10.30565526]

三, 多項式擬合

代碼實現:

def main():
    x = np.linspace(-10,10,100) # 創建時間序列
    a1,a2,a3 = [-2,5,10] # 原始數據的參數
    noise = np.random.randn(len(x)) # 創建隨機噪聲
    y = Fun(x,a1,a2,a3)+noise*2 # 加上噪聲的序列
    plt.plot(x,y)
    para=np.polyfit(x, y, deg = 2)
 
    y_fitted = Fun(x,para[0],para[1],para[2])
    plt.figure
    plt.plot(x,y,'ro', label = 'Original curve')
    plt.plot(x,y_fitted,'-b', label ='Fitted curve')
    plt.legend()
    plt.show()
    print(para)

if __name__=='__main__':
    main() 

擬合結果為:

[-2.00532192 5.01626878 10.07612899]

總結:

本文主要講了最小二乘法擬合曲線的實現方法,使用 leastsq() 和 curve_fit(),最后講解了多項式的擬合poly.fit(). 最小二乘法的兩個擬合大體的步驟是一樣的,定義擬合范式,傳入擬合參數,開始擬合得出擬合結果。對於簡單的擬合函數兩者的差別很小,但是復雜的,需要具體的分析。文章還會繼續的分析擬合結果的含義,讓我們對擬合的結果有更加透徹的理解,隨心擬合。

參考文獻:

SciPy v1.3.0 Reference Guide

SciPy v0.19.1 Reference Guide

numpy.polyfit - NumPy v1.16 Manual


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM