''' 多項式回歸:若希望回歸模型更好的擬合訓練樣本數據,可以使用多項式回歸器。 一元多項式回歸: 數學模型:y = w0 + w1 * x^1 + w2 * x^2 + .... + wn * x^n 將高次項看做對一次項特征的擴展得到: y = w0 + w1 * x1 + w2 * x2 + .... + wn * xn 那么一元多項式回歸即可以看做為多元線性回歸,可以使用LinearRegression模型對樣本數據進行模型訓練。 所以一元多項式回歸的實現需要兩個步驟: 1. 將一元多項式回歸問題轉換為多元線性回歸問題(只需給出多項式最高次數即可)。 2. 將1步驟得到多項式的結果中 w1,w2,w3,...,wn當做樣本特征,交給線性回歸器訓練多元線性模型。 選擇合適的最高次數其模型R2評分會高於一元線性回歸模型評分,如果次數過高,會出現過擬合現象,評分會低於一元線性回歸評分 使用sklearn提供的"數據管線"實現兩個步驟的順序執行: import sklearn.pipeline as pl import sklearn.preprocessing as sp import sklearn.linear_model as lm model = pl.make_pipeline( # 10: 多項式的最高次數 sp.PolynomialFeatures(10), # 多項式特征擴展器 lm.LinearRegression()) # 線性回歸器 過擬合和欠擬合: 1.過擬合:過於復雜的模型,對於訓練數據可以得到較高的預測精度,但對於測試數據通常精度較低,這種現象叫做過擬合。 2.欠擬合:過於簡單的模型,無論對於訓練數據還是測試數據都無法給出足夠高的預測精度,這種現象叫做欠擬合。 3.一個性能可以接受的學習模型應該對訓練數據和測試數據都有接近的預測精度,而且精度不能太低。 訓練集R2 測試集R2 0.3 0.4 欠擬合:過於簡單,無法反映數據的規則 0.9 0.2 過擬合:過於復雜,太特殊,缺乏一般性 0.7 0.6 可接受:復雜度適中,既反映數據的規則,同時又不失一般性 加載single.txt文件中的數據,基於一元多項式回歸算法訓練回歸模型。 步驟: 導包--->讀取數據--->創建多項式回歸模型--->模型訓練及預測--->通過模型預測得到pred_y,繪制多項式函數圖像 ''' import sklearn.pipeline as pl import sklearn.linear_model as lm import sklearn.preprocessing as sp import matplotlib.pyplot as mp import numpy as np import sklearn.metrics as sm # 采集數據 x, y = np.loadtxt('./ml_data/single.txt', delimiter=',', usecols=(0, 1), unpack=True) # 把輸入變為二維數組,一行一樣本,一列一特征 x = x.reshape(-1, 1) # 創建模型 model = pl.make_pipeline( sp.PolynomialFeatures(10), # 多項式特征拓展器 lm.LinearRegression() # 線性回歸器 ) # 訓練模型 model.fit(x, y) # 求預測值y pred_y = model.predict(x) # 模型評估 print('平均絕對值誤差:', sm.mean_absolute_error(y, pred_y)) print('平均平方誤差:', sm.mean_squared_error(y, pred_y)) print('中位絕對值誤差:', sm.median_absolute_error(y, pred_y)) print('R2得分:', sm.r2_score(y, pred_y)) # 繪制多項式回歸線 px = np.linspace(x.min(), x.max(), 1000) px = px.reshape(-1, 1) pred_py = model.predict(px) # 繪制圖像 mp.figure("Poly Regression", facecolor='lightgray') mp.title('Poly Regression', fontsize=16) mp.tick_params(labelsize=10) mp.grid(linestyle=':') mp.xlabel('x') mp.ylabel('y') mp.scatter(x, y, s=60, marker='o', c='dodgerblue', label='Points') mp.plot(px, pred_py, c='orangered', label='PolyFit Line') mp.tight_layout() mp.legend() mp.show() 輸出結果: 平均絕對值誤差: 0.4818952136579405 平均平方誤差: 0.35240714067500095 中位絕對值誤差: 0.47265950409692536 R2得分: 0.7868629092058499