pyhton scipy最小二乘法(scipy.linalg.lstsq模塊)


最小二乘法則是一種統計學習優化技術,它的目標是最小化誤差平方之和來作為目標J(θ)J(θ),從而找到最優模型。

 

7. SciPy最小二乘法

最小二乘法則是一種統計學習優化技術,它的目標是最小化誤差平方之和來作為目標J(θ),從而找到最優模型。

1、線性最小二乘法

假設真實的模型是y=2x+1,我們有一組數據(xi,yi)共100個,看能否基於這100個數據找出xiyi的線性關系方程y=2x+1?我們可以通過以下幾步來完成。

1).首先是通過程序構造出100個(xi,yi)數據。

xi = x + np.random.normal(0, 0.05, 100)

yi = 1 + 2 * xi + np.random.normal(0, 0.05, 100)

2).接下來給出模型f(x)=a+bx的矩陣A,由於有100個觀測(xi,yi)的數據,那么就有:

將以上式子寫成如下矩陣的形式:

 

A = np.vstack([xi**0, xi**1]) 

AT即100×2的那個矩陣

3).調用scipy.linalg.lstsq傳入AT和觀測值里的yii即程序里的yi變量即可求得f(x)=a+bx里的a和b。a和b記錄在lstsq函數的第一個返回值里。

sol, r, rank, s = la.lstsq(A.T, yi)

4). scipy.linalg.lstsq的第一個返回值sol共有兩個值,sol[0]即是估計出來的f(x)=a+bx里a,sol[1]代表f(x)=a+bx里b。因此f(x)為:

y_fit = sol[0] + sol[1] * x

至此找到了這100個(xi,yi)的模型方程。從print sol語句的輸出結果可以看出數據還是比較接近y=2x+1的。

完整的代碼如下所示:

import scipy.linalg as la
import numpy as np
import matplotlib.pyplot as plt
m = 100
x = np.linspace(-1, 1, m)
y_exact = 1 + 2 * x
xi = x + np.random.normal(0, 0.05, 100)
yi = 1 + 2 * xi + np.random.normal(0, 0.05, 100)
A = np.vstack([xi**0, xi**1])
sol, r, rank, s = la.lstsq(A.T, yi) #求取各個系數大小
y_fit = sol[0] + sol[1] * x
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(xi, yi, 'go', alpha=0.5, label='Simulated data')
ax.plot(x, y_exact, 'k', lw=2, label='True value y = 1 + 2x')
ax.plot(x, y_fit, 'b', lw=2, label='Least square fit')
ax.set_xlabel("x", fontsize=18)
ax.set_ylabel(”y", fontsize=18)
ax.legend(loc=2) #設置曲線標注位置
plt.show()
2、二次函數最小二乘法
這個程序和上面的程序差不多,只不過模型變成了f(xi)=a+bx+cx2f(xi)=a+bx+cx2了而已,請自己分析分析。
完整程序如下:
import scipy.linalg as la
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-1, 1, 100)
a, b, c = 1, 2, 3
y_exact = a + b * x + c * x**2
m = 100
xi=1 - 2 * np.random.rand(m)
yi=a + b * xi + c * xi**2 + np.random.randn(m)
A = np.vstack([xi**0, xi**1, xi**2])
sol, r, rank, s = la.lstsq(A.T, yi)
y_fit = sol[0] + sol[1] * x + sol[2] * x**2
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(xi, yi, 'go', alpha=0.5, label='Simulated data')
ax.plot(x, y_exact, 'k', lw=2, label='True value $y = 1 + 2x + 3x^2$')
ax.plot(x, y_fit, 'b', lw=2, label='Least square fit')
ax.set_xlabel("x", fontsize=18)
ax.set_ylabel("y", fontsize=18)
ax.legend(loc=2)
plt.show()
具體結果展示如下:
 
          
 

 

 


 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM