Python:使用piecewise與curve_fit進行三段擬合


x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])

plt.scatter(x,y,s=30,c='b')

得到如下散點圖:

 

定義分段函數

#6個未知參數 x0x1,y0,y1分別是2個分割間斷點的橫縱坐標 k0,k1是第一和第三段直線的斜率
def piecewise(x,x0,x1,y0,y1,k0,k1):
    return np.piecewise(x , [x <= x0, np.logical_and(x0<x, x<= x1),x>x1] ,
                        [lambda x:k0*(x-x0) + y0,#根據點斜式構建函數
                         lambda x:(x-x0)*(y1-y0)/(x1-x0)+y0,#根據兩點式構建函數
                        lambda x:k1*(x-x1) + y1])

 

根據分段函數進行擬合,通過迭代尋找最優的p,即為p_best

注:p(p_best)中包含的是擬合之后求得的所有未知參數

perr_min = np.inf
p_best = None
for n in range(100):
    k = np.random.rand(6)*20
    p , e = optimize.curve_fit(piecewise, x, y,p0=k)
    perr = np.sum(np.abs(y-piecewise(x, *p)))
    if(perr < perr_min):
        perr_min = perr
        p_best = p

 

根據p_best調用curve_fit函數繪制擬合圖像

xd = np.linspace(0, 21, 100)
plt.figure()

plt.plot(xd, piecewise(xd, *p_best))
xx=(p_best[0],p_best[1])
yy=(p_best[2],p_best[3])

plt.scatter(xx,yy,s=30,c='black')
plt.show()

結果如下:

 

完整代碼:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np

#6個未知參數 x0x1,y0,y1分別是2個分割間斷點的橫縱坐標 k0,k1是第一和第三段直線的斜率
def piecewise(x,x0,x1,y0,y1,k0,k1):
    return np.piecewise(x , [x <= x0, np.logical_and(x0<x, x<= x1),x>x1] ,
                        [lambda x:k0*(x-x0) + y0,#根據點斜式構建函數
                         lambda x:(x-x0)*(y1-y0)/(x1-x0)+y0,#根據兩點式構建函數
                        lambda x:k1*(x-x1) + y1])

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])

plt.scatter(x,y,s=30,c='b')

perr_min = np.inf
p_best = None
for n in range(100):
    k = np.random.rand(6)*20
    p , e = optimize.curve_fit(piecewise, x, y,p0=k)
    perr = np.sum(np.abs(y-piecewise(x, *p)))
    if(perr < perr_min):
        perr_min = perr
        p_best = p

xd = np.linspace(0, 21, 100)
plt.figure()


plt.plot(xd, piecewise(xd, *p_best))
xx=(p_best[0],p_best[1])
yy=(p_best[2],p_best[3])

plt.scatter(xx,yy,s=30,c='black')
plt.show()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM