邏輯回歸是一個分類器,其基本思想可以概括為:對於一個二分類(0~1)問題,若P(Y=1/X)>0.5則歸為1類,若P(Y=1/X)<0.5,則歸為0類。
一、模型概述
1、Sigmoid函數
為了具象化前文的基本思想,這里介紹Sigmoid函數:
函數圖像如下:
紅色的線條,即x=0處將Sigmoid曲線分成了兩部分:當 x < 0,y < 0.5 ;
當x > 0時,y > 0.5 。
實際分類問題中,往往根據多個預測變量來對響應變量進行分類。因此Sigmoid函數要與一個多元線性函數進行復合,才能應用於邏輯回歸。
2、邏輯斯諦模型
其中θx=θ1x1+θ2x2+……+θnxn 是一個多元線性模型。
上式可轉化為:
公式左側稱為發生比(odd)。當p(X)接近於0時,發生比就趨近於0;當p(X)接近於1時,發生比就趨近於∞。
兩邊取對數有:
公式左側稱為對數發生比(log-odd)或分對數(logit),上式就變成了一個線性模型。
不過相對於最小二乘擬合,極大似然法有更好的統計性質。邏輯回歸一般用極大似然法來擬合,擬合過程這里略過,下面只介紹如何用R應用邏輯回歸算法。
二、邏輯回歸應用
1、數據集
應用ISLR
包里的Smarket
數據集。先來看一下數據集的結構:
> summary(Smarket) Year Lag1 Lag2 Min. :2001 Min. :-4.922000 Min. :-4.922000 1st Qu.:2002 1st Qu.:-0.639500 1st Qu.:-0.639500 Median :2003 Median : 0.039000 Median : 0.039000 Mean :2003 Mean : 0.003834 Mean : 0.003919 3rd Qu.:2004 3rd Qu.: 0.596750 3rd Qu.: 0.596750 Max. :2005 Max. : 5.733000 Max. : 5.733000 Lag3 Lag4 Lag5 Min. :-4.922000 Min. :-4.922000 Min. :-4.92200 1st Qu.:-0.640000 1st Qu.:-0.640000 1st Qu.:-0.64000 Median : 0.038500 Median : 0.038500 Median : 0.03850 Mean : 0.001716 Mean : 0.001636 Mean : 0.00561 3rd Qu.: 0.596750 3rd Qu.: 0.596750 3rd Qu.: 0.59700 Max. : 5.733000 Max. : 5.733000 Max. : 5.73300 Volume Today Direction Min. :0.3561 Min. :-4.922000 Down:602 1st Qu.:1.2574 1st Qu.:-0.639500 Up :648 Median :1.4229 Median : 0.038500 Mean :1.4783 Mean : 0.003138 3rd Qu.:1.6417 3rd Qu.: 0.596750 Max. :3.1525 Max. : 5.733000
Smarket
是2001年到2005年間1250天的股票投資回報率數據,Year
是年份,Lag1
~Lag5
分別指過去5個交易日的投資回報率,Today
是今日投資回報率,Direction
是市場走勢,或Up
(漲)或Down
(跌)。
先看一下各變量的相關系數:
library(corrplot) corrplot(corr = cor(Smarket[,-9]),order = "AOE",type = "upper",tl.pos = "d") corrplot(corr = cor(Smarket[,-9]),add=TRUE,type = "lower",method = "number",order = "AOE",diag = FALSE,tl.pos = "n",cl.pos = "n")
可見除了Volume
和Year
之間相關系數比較大,說明交易量基本隨年份在增長,其他變量間基本沒多大的相關性。說明股票的歷史數據與未來的數據相關性很小,利用監督式學習方法很難准確預測未來股市的情況,這也是符合常識的。不過作為算法的應用教程,我們還是試一下。
2、訓練並測試邏輯回歸模型
邏輯回歸模型是廣義線性回歸模型的一種,因此函數是glm()
,但必須加上參數family=binomial
。
> attach(Smarket) > # 2005年前的數據用作訓練集,2005年的數據用作測試集 > train = Year<2005 > # 對訓練集構建邏輯斯諦模型 > glm.fit=glm(Direction~Lag1+Lag2+Lag3+Lag4+Lag5+Volume, + data=Smarket,family=binomial, subset=train) > # 對訓練好的模型在測試集中進行預測,type="response"表示只返回概率值 > glm.probs=predict(glm.fit,newdata=Smarket[!train,],type="response") > # 根據概率值進行漲跌分類 > glm.pred=ifelse(glm.probs >0.5,"Up","Down") > # 2005年實際的漲跌狀況 > Direction.2005=Smarket$Direction[!train] > # 預測值和實際值作對比 > table(glm.pred,Direction.2005) Direction.2005 glm.pred Down Up Down 77 97 Up 34 44 > # 求預測的准確率 > mean(glm.pred==Direction.2005) [1] 0.4801587
預測准確率只有0.48,還不如瞎猜。下面嘗試着調整模型。
#檢查一下模型概況 > summary(glm.fit) Call: glm(formula = Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + Lag5 + Volume, family = binomial, data = Smarket, subset = train) Deviance Residuals: Min 1Q Median 3Q Max -1.302 -1.190 1.079 1.160 1.350 Coefficients: Estimate Std. Error z value Pr(>|z|) (Intercept) 0.191213 0.333690 0.573 0.567 Lag1 -0.054178 0.051785 -1.046 0.295 Lag2 -0.045805 0.051797 -0.884 0.377 Lag3 0.007200 0.051644 0.139 0.889 Lag4 0.006441 0.051706 0.125 0.901 Lag5 -0.004223 0.051138 -0.083 0.934 Volume -0.116257 0.239618 -0.485 0.628 (Dispersion parameter for binomial family taken to be 1) Null deviance: 1383.3 on 997 degrees of freedom Residual deviance: 1381.1 on 991 degrees of freedom AIC: 1395.1 Number of Fisher Scoring iterations: 3
可以發現所有變量的p值都比較大,都不顯著。前面線性回歸章節中提到AIC越小,模型越優,這里的AIC還是比較大的。
加入與響應變量無關的預測變量會造成測試錯誤率的增大(因為這樣的預測變量會增大模型方差,但不會相應地降低模型偏差),所以去除這樣的預測變量可能會優化模型。
上面模型中Lag1和Lag2的p值明顯比其他變量要小,因此只選這兩個變量再次進行訓練。
> glm.fit=glm(Direction~Lag1+Lag2, + data=Smarket,family=binomial, subset=train) > glm.probs=predict(glm.fit,newdata=Smarket[!train,],type="response") > glm.pred=ifelse(glm.probs >0.5,"Up","Down") > table(glm.pred,Direction.2005) Direction.2005 glm.pred Down Up Down 35 35 Up 76 106 > mean(glm.pred==Direction.2005) [1] 0.5595238 > 106/(76+106) [1] 0.5824176
這次模型的總體准確率達到了56%,總算說明統計模型的預測准確度比瞎猜要好(雖然只有一點點)。根據混淆矩陣,當邏輯回歸模型預測下跌時,有50%的准確率;當邏輯回歸模型預測上漲時,有58%的准確率。(矩陣的行名表預測值,列名表實際值)
應用這個模型來預測2組新的數據:
> predict(glm.fit,newdata = data.frame(Lag1=c(1.2,1.5),Lag2=c(1.1,-0.8)),type="response") 1 2 0.4791462 0.4960939
可見對於(Lag1,Lag2)=(1.2,1.1)和(1.5,-0.8)這兩點來說,模型預測的都是股票會跌。需要注意的是,邏輯回歸的預測結果並不能像線性回歸一樣提供置信區間(或預測區間),因此加上interval參數也沒用。
參考文獻
[1]R語言數據分析系列之九 - 邏輯回歸
[2] Gareth James et al. An Introduction to Statistical Learning.