注:在上一篇的一般線性回歸中,使用的假設函數是一元一次方程,也就是二維平面上的一條直線。但是很多時候可能會遇到直線方程無法很好的擬合數據的情況,這個時候可以嘗試使用多項式回歸。多項式回歸中,加入了特征的更高次方(例如平方項或立方項),也相當於增加了模型的自由度,用來捕獲數據中非線性的變化。添加高階項的時候,也增加了模型的復雜度。隨着模型復雜度的升高,模型的容量以及擬合數據的能力增加,可以進一步降低訓練誤差,但導致過擬合的風險也隨之增加。
圖A,模型復雜度與訓練誤差及測試誤差之間的關系
0. 多項式回歸的一般形式
在多項式回歸中,最重要的參數是最高次方的次數。設最高次方的次數為$n$,且只有一個特征時,其多項式回歸的方程為:
$$ \hat{h} = \theta_0 + \theta_1 x^1 + \ ... \ + \theta_{n-1} x^{n-1} + \theta_n x^n $$
如果令$x_0 = 1$,在多樣本的情況下,可以寫成向量化的形式:
$$\hat{h} = X \cdot \theta$$
其中$X$是大小為$m \cdot (n+1)$的矩陣,$\theta$是大小為$(n+1) \cdot 1$的矩陣。在這里雖然只有一個特征$x$以及$x$的不同次方,但是也可以將$x$的高次方當做一個新特征。與多元回歸分析唯一不同的是,這些特征之間是高度相關的,而不是通常要求的那樣是相互對立的。
在這里有個問題在剛開始學習線性回歸的時候困擾了自己很久:如果假設中出現了高階項,那么這個模型還是線性模型嗎?此時看待問題的角度不同,得到的結果也不同。如果把上面的假設看成是特征$x$的方程,那么該方程就是非線性方程;如果看成是參數$\theta$的方程,那么$x$的高階項都可以看做是對應$\theta$的參數,那么該方程就是線性方程。很明顯,在線性回歸中采用了后一種解釋方式。因此多項式回歸仍然是參數的線性模型。
1. 多項式回歸的實現
下面主要使用了numpy、scipy、matplotlib和scikit-learn,所有使用到的函數的導入如下:
1 import numpy as np 2 from scipy import stats 3 import matplotlib.pyplot as plt 4 from sklearn.preprocessing import PolynomialFeatures 5 from sklearn.linear_model import LinearRegression 6 from sklearn.metrics import mean_squared_error
下是使用的數據是使用$y = x^2 + 2$並加入一些隨機誤差生成的,只取了10個數據點:
1 data = np.array([[ -2.95507616, 10.94533252], 2 [ -0.44226119, 2.96705822], 3 [ -2.13294087, 6.57336839], 4 [ 1.84990823, 5.44244467], 5 [ 0.35139795, 2.83533936], 6 [ -1.77443098, 5.6800407 ], 7 [ -1.8657203 , 6.34470814], 8 [ 1.61526823, 4.77833358], 9 [ -2.38043687, 8.51887713], 10 [ -1.40513866, 4.18262786]]) 11 m = data.shape[0] # 樣本大小 12 X = data[:, 0].reshape(-1, 1) # 將array轉換成矩陣 13 y = data[:, 1].reshape(-1, 1) 14 plt.plot(X, y, "b.") 15 plt.xlabel('X') 16 plt.ylabel('y') 17 plt.show()
這些數據點plot出來,如下圖:
圖1-1,原始數據
1.1 直線方程擬合
下面先用直線方程擬合上面的數據點:
1 lin_reg = LinearRegression() 2 lin_reg.fit(X, y) 3 print(lin_reg.intercept_, lin_reg.coef_) # [ 4.97857827] [[-0.92810463]] 4 5 X_plot = np.linspace(-3, 3, 1000).reshape(-1, 1) 6 y_plot = np.dot(X_plot, lin_reg.coef_.T) + lin_reg.intercept_ 7 plt.plot(X_plot, y_plot, 'r-') 8 plt.plot(X, y, 'b.') 9 plt.xlabel('X') 10 plt.ylabel('y') 11 plt.savefig('regu-2.png', dpi=200)
圖1-2,直線擬合的效果
可以使用函數"mean_squared_error"來計算誤差(使用前面介紹過的Mean squared error, MSE):
h = np.dot(X.reshape(-1, 1), lin_reg.coef_.T) + lin_reg.intercept_ print(mean_squared_error(h, y)) # 3.34
1.2 使用多項式方程
為了擬合2次方程,需要有特征$x^2$的數據,這里可以使用函數"PolynomialFeatures"來獲得:
1 poly_features = PolynomialFeatures(degree=2, include_bias=False) 2 X_poly = poly_features.fit_transform(X) 3 print(X_poly)
結果如下:
[[-2.95507616 8.73247511] [-0.44226119 0.19559496] [-2.13294087 4.54943675] [ 1.84990823 3.42216046] [ 0.35139795 0.12348052] [-1.77443098 3.1486053 ] [-1.8657203 3.48091224] [ 1.61526823 2.60909145] [-2.38043687 5.66647969] [-1.40513866 1.97441465]]
利用上面的數據做線性回歸分析:
1 lin_reg = LinearRegression() 2 lin_reg.fit(X_poly, y) 3 print(lin_reg.intercept_, lin_reg.coef_) # [ 2.60996757] [[-0.12759678 0.9144504 ]] 4 5 X_plot = np.linspace(-3, 3, 1000).reshape(-1, 1) 6 X_plot_poly = poly_features.fit_transform(X_plot) 7 y_plot = np.dot(X_plot_poly, lin_reg.coef_.T) + lin_reg.intercept_ 8 plt.plot(X_plot, y_plot, 'r-') 9 plt.plot(X, y, 'b.') 10 plt.show()
第3行得到了訓練后的參數,即多項式方程為$h = -0.13x + 0.91x^2 + 2.61$ (結果中系數的順序與$X$中特征的順序一致),如下圖所示:
圖1-3:2次多項式方程與原始數據的比較
利用多項式回歸,代價函數MSE的值下降到了0.07。通過觀察代碼,可以發現訓練多項式方程與直線方程唯一的差別是輸入的訓練集$X$的差別。在訓練直線方程時直接輸入了$X$的值,在訓練多項式方程的時候,還添加了我們計算出來的$x^2$這個“新特征”的值(由於$x^2$完全是由$x$的值確定的,因此嚴格意義上來講此時該模型只有一個特征$x$)。
此時有個非常有趣的問題:假如一開始得到的數據就是上面代碼中"X_poly"的樣子,且不知道$x_1$與$x_2$之間的關系。此時相當於我們有10個樣本,每個樣本具有$x_1, x_2$兩個不同的特征。這時假設函數為:$$\hat{h} = \theta_0 + \theta_1 x_1 + \theta_2 x_2$$
直接按照二元線性回歸方程來訓練,也可以得到上面同樣的結果($\theta$的值)。如果在相同情況下,收集到了新的數據,可以直接帶入上面的方程進行預測。唯一不同的是,我們不知道$x_2 = x_1^2$這個隱含在數據內部的關系,所有也就無法畫出圖1-3中的這條曲線。一旦了解到了這兩個特征之間的關系,數據的維度就從3維下降到了2維(包含截距項$\theta_0$)。
2. 持續降低訓練誤差與過擬合
在上面實現多項式回歸的過程中,通過引入高階項$x^2$,訓練誤差從3.34下降到了0.07,減小了將近50倍。那么訓練誤差是否還有進一步下降的空間呢?答案是肯定的,通過繼續增加更高階的項,訓練誤差可以進一步降低。通過嘗試,當最高階項為$x^{11}$時,訓練誤差為3.11e-23,幾乎等於0了。
下面是測試不同degree的過程:
1 # test different degree and return loss 2 def try_degree(degree, X, y): 3 poly_features_d = PolynomialFeatures(degree=degree, include_bias=False) 4 X_poly_d = poly_features_d.fit_transform(X) 5 lin_reg_d = LinearRegression() 6 lin_reg_d.fit(X_poly_d, y) 7 return {'X_poly': X_poly_d, 'intercept': lin_reg_d.intercept_, 'coef': lin_reg_d.coef_} 8 9 degree2loss_paras = [] 10 for i in range(2, 20): 11 paras = try_degree(i, X, y) 12 h = np.dot(paras['X_poly'], paras['coef'].T) + paras['intercept'] 13 _loss = mean_squared_error(h, y) 14 degree2loss_paras.append({'d': i, 'loss': _loss, 'coef': paras['coef'], 'intercept': paras['intercept']}) 15 16 min_index = np.argmin(np.array([i['loss'] for i in degree2loss_paras])) 17 min_loss_para = degree2loss_paras[min_index] 18 print(min_loss_para) # 19 X_plot = np.linspace(-3, 1.9, 1000).reshape(-1, 1) 20 poly_features_d = PolynomialFeatures(degree=min_loss_para['d'], include_bias=False) 21 X_plot_poly = poly_features_d.fit_transform(X_plot) 22 y_plot = np.dot(X_plot_poly, min_loss_para['coef'].T) + min_loss_para['intercept'] 23 fig, ax = plt.subplots(1, 1) 24 ax.plot(X_plot, y_plot, 'r-', label='degree=11') 25 ax.plot(X, y, 'b.', label='X') 26 plt.xlabel('X') 27 plt.ylabel('y') 28 ax.legend(loc='best', frameon=False) 29 plt.savefig('regu-4-overfitting.png', dpi=200)
輸出為:
{'coef': array([[ 0.7900162 , 26.72083627, 4.33062978, -7.65908434, 24.62696711, 12.33754429, -15.72302536, -9.54076366, 1.42221981, 1.74521649, 0.27877112]]), 'd': 11, 'intercept': array([-0.95562816]), 'loss': 3.1080267005676934e-23}
畫出的函數圖像如下:
圖2-1:degree=11時的函數圖像
由圖2-1可以看到,此時函數圖像穿過了每一個樣本點,所有的訓練樣本都落在了擬合的曲線上,訓練誤差接近與0。 可以說是近乎完美的模型了。但是,這樣的曲線與我們最開始數據的來源(一個二次方程加上一些隨機誤差)差異非常大。如果從相同來源再取一些樣本點,使用該模型預測會出現非常大的誤差。類似這種訓練誤差非常小,但是新數據點的測試誤差非常大的情況,就叫做模型的過擬合。過擬合出現時,表示模型過於復雜,過多考慮了當前樣本的特殊情況以及噪音(模型學習到了當前訓練樣本非全局的特性),使得模型的泛化能力下降。
出現過擬合一般有以下幾種解決方式:
- 降低模型復雜度,例如減小上面例子中的degree;
- 降維,減小特征的數量;
- 增加訓練樣本;
- 添加正則化項.
防止模型過擬合是機器學習領域里最重要的問題之一。鑒於該問題的普遍性和重要性,在滿足要求的情況下,能選擇簡單模型時應該盡量選擇簡單的模型。
Reference
http://scikit-learn.org/stable/modules/linear_model.html
Géron A. Hands-on machine learning with Scikit-Learn and TensorFlow: concepts, tools, and techniques to build intelligent systems[M]. " O'Reilly Media, Inc.", 2017. github
https://www.arxiv-vanity.com/papers/1803.09820/