一、簡介
邏輯回歸(Logistic Regression),與它的名字恰恰相反,它是一個分類器而非回歸方法,在一些文獻里它也被稱為logit回歸、最大熵分類器(MaxEnt)、對數線性分類器等;我們都知道可以用回歸模型來進行回歸任務,但如果要利用回歸模型來進行分類該怎么辦呢?本文介紹的邏輯回歸就基於廣義線性模型(generalized linear model),下面我們簡單介紹一下廣義線性模型:
我們都知道普通線性回歸模型的形式:
如果等號右邊的輸出值與左邊y經過某個函數變換后得到的值比較貼切,如下面常見的“對數線性回歸”(log-linear regression):
這里對數函數ln(y)起到的作用便是將y轉換為其對數值,且這個對數值與右邊的線性模型的預測值更為貼切接近,我們管類似這里對數函數的套在y外面的單調可微函數(因為只有單調可微函數才存在反函數)叫做“聯系函數”(link function),引出下面更一般的形式:
我們在這里使用一個單調可微函數將分類任務的真實標記y與線性回歸模型的預測值聯系起來;
考慮二分類任務,其輸出標記:
而線性回歸模型產出的預測值:
是連續域上的實值,因此我們需要把實值z轉換為0/1值,最理想的是“單位階躍函數”(unit-step function)
這里規定預測值z大於零就判為正例,小於零則判為反例,預測值為臨界值零則可任意判別(事實上這種情況的樣本本就存在一些問題而無法通過邏輯回歸進行分類),下圖展示了單位階躍函數(紅色)與對數幾率函數(黑色):
從上圖可以看出,單位階躍函數不連續,數學性質差,因此不能直接用作廣義線性模型中的link function,於是我們的目的是找到在一定程度上近似單位階躍函數的“替代函數”(surrogate function),並希望它單調可微。對數幾率函數(logistic function)正是這樣一個常用的替代函數:
對數幾率函數是一種“Sigmoid”函數(即形似S的函數,在神經網絡的激勵函數中有廣泛應用),它將z值轉化為一個接近0或1的y值,並且其輸出值在z=0附近變化很陡。將該對數幾率函數作為聯系函數代入廣義線性模型,可得:
我們對其進行如下推導變換:
若將y視為樣本x作為正例的可能性,則1-y是其反例可能性,兩者的比值:
稱為“幾率”(odds),反映了x作為正例的相對可能性。對幾率取對數則得到“對數幾率”(log odds,亦稱logit):
由此可看出,這實際上是用線性回歸模型的預測結果去逼近真實標記的對數幾率,因此其對應的模型稱為“對數幾率回歸”(logistic regression,亦稱logit regression),這種方法具有諸多優點:
1.直接針對分類可能性進行建模,無需事先假設數據分布,這樣就避免了假設分布不准確所帶來的問題;
2.不僅輸出預測類別,還輸出了近似的預測概率,這對許多需要利用預測概率進行輔助決策的任務很有用;
3.對率函數是任意階可導的凸函數,有很好的數學性質,現有的許多數值優化算法都可直接用於求取最優解
二、訓練方法
根據一個樣本集訓練邏輯回歸模型,實際上是要得到參數w與截距b,下面我們來仔細推導訓練的思想:
前面我們得到了:
將其中的y視為類后驗概率估計:
則前面的式子可改寫為:
下面根據上式對正例和反例的后驗概率估計進行推導:
因此,我們可以通過“極大似然法”(maximum likelihood method)來估計w與b,給定數據集:
對率回歸模型最大化“對數似然”(log-likelihood):
即令每個樣本屬於其真實標記的概率越大越好。令:
則:
再令:
則:
則最大化“對數似然”轉換為:
因為上式為關於β的高階可導連續凸函數,由凸優化理論,使用經典的數值優化算法如梯度下降法(gradient decent method)、牛頓法(Newton method)等均可求得其最優解,即得到:
則我們的邏輯回歸模型訓練完成。
三、Python實現
我們使用sklearn.linear_model中的LogisticRegression方法來訓練邏輯回歸分類器,其主要參數如下:
class_weight:用於處理類別不平衡問題,即這時的閾值不再是0.5,而是一個再縮放后的值;
fit_intercept:bool型參數,設置是否求解截距項,即b,默認True;
random_state:設置隨機數種子;
solver:選擇用於求解最大化“對數似然”的算法,有以下幾種及其適用場景:
1.對於較小的數據集,使用"liblinear"更佳;
2.對於較大的數據集,"sag"、"saga"更佳;
3.對於多分類問題,應使用"newton-cg"、"sag"、"saga"、"lbfgs";
max_iter:設置求解算法的迭代次數,僅適用於solver設置為"newton-cg"、"lbfgs"、"sag"的情況;
multi_class:為多分類問題選擇訓練策略,有"ovr"、"multinomial" ,后者不支持"liblinear";
n_jobs:當處理多分類問題訓練策略為'ovr'時,在訓練時並行運算使用的CPU核心數量。當solver被設置為“liblinear”時,不管是否指定了multi_class,這個參數都會被忽略。如果給定值-1,則所有的核心都被使用,所以推薦-1,默認項為1,即只使用1個核心。
下面我們以威斯康辛州乳腺癌數據為例進行演示:
from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.metrics import f1_score as f1 from sklearn.metrics import recall_score as recall from sklearn.metrics import confusion_matrix as cm '''導入威斯康辛州乳腺癌數據''' X,y = datasets.load_breast_cancer(return_X_y=True) '''分割訓練集與驗證集''' X_train,X_test,y_train,y_test = train_test_split(X,y,train_size=0.7,test_size=0.3) '''初始化邏輯回歸分類器,這里對類別不平衡問題做了處理''' cl = LogisticRegression(class_weight='balanced') '''利用訓練數據進行邏輯回歸分類器的訓練''' cl = cl.fit(X_train,y_train) '''打印訓練的模型在驗證集上的正確率''' print('邏輯回歸的測試准確率:'+str(cl.score(X_test,y_test))+'\n') '''打印f1得分''' print('F1得分:'+str(f1(y_test,cl.predict(X_test)))+'\n') '''打印召回得分''' print('召回得分(越接近1越好):'+str(recall(y_test,cl.predict(X_test)))+'\n') '''打印混淆矩陣''' print('混淆矩陣:'+'\n'+str(cm(y_test,cl.predict(X_test)))+'\n')
四、R實現
在R中實現邏輯回歸的過程比較細致,也比較貼近於統計學思想,我們使用glm()來訓練邏輯回歸模型,這是一個訓練廣義線性模型的函數,注意,這種方法不像sklearn中那樣主要在乎的是輸出的分類結果,而是更加注重模型的思想以及可解釋性(即每個變量對結果的影響程度),下面對glm()的主要參數進行介紹:
formula:這里和R中常見的算法格式一樣,傳遞一個因變量~自變量的形式;
family:這個參數可以傳遞一個字符串或family函數形式的輸入,默認為gaussian,表示擬合出的函數的誤差項服從正態分布,若使用family則可同時定義誤差服從的分布和廣義線性模型中的聯系函數,例如本文所需的邏輯回歸函數,就可以有兩種設定方式:
1.傳入gaussian
2.傳入binomial(link='logit')
data:指定變量所屬的數據框名稱;
weights:傳入一個numeric型向量,用於類別不平衡問題的再縮放,默認無,即將1與0類視作平衡;
model:邏輯型變量,用於控制是否輸出最終訓練的模型;
下面我們對威斯康辛州乳腺癌數據集進行邏輯回歸分類訓練,該數據集下載自https://archive.ics.uci.edu/ml/datasets.html
> rm(list=ls()) > setwd('C:\\Users\\windows\\Desktop')> > #read data > data <- read.table('breast.csv',sep=',')[,-1] > data[,1] = as.numeric(data[,1])-1 > > #spilt the datasets into train-dataset and test-dataset with a proportion of 7:3 > sam <- sample(1:dim(data)[1],0.7*dim(data)[1]) > train <- data[sam,] > test <- data[-sam,] > > #method 1 to train the logistic regression model > cl1 <- glm(V2~.,data=train,family=gaussian,model = T) > summary(cl1) Call: glm(formula = V2 ~ ., family = gaussian, data = train, model = T) Deviance Residuals: Min 1Q Median 3Q Max -0.5508 -0.1495 -0.0289 0.1313 0.8827 Coefficients: Estimate Std. Error t value Pr(>|t|) (Intercept) -1.5799840 0.5155747 -3.065 0.002342 ** V3 -0.2599482 0.2154418 -1.207 0.228370 V4 -0.0079273 0.0095174 -0.833 0.405427 V5 0.0081528 0.0305549 0.267 0.789754 V6 0.0013896 0.0006602 2.105 0.035995 * V7 1.4228253 2.5709862 0.553 0.580315 V8 -3.6635584 1.5680864 -2.336 0.020012 * V9 0.7080702 1.2295130 0.576 0.565039 V10 3.1537615 2.3451202 1.345 0.179514 V11 0.1732304 0.8911542 0.194 0.845979 V12 -1.9387489 6.7295898 -0.288 0.773438 V13 0.7726866 0.4147097 1.863 0.063233 . V14 -0.0279811 0.0410619 -0.681 0.496025 V15 -0.1233500 0.0552559 -2.232 0.026195 * V16 0.0004280 0.0017295 0.247 0.804693 V17 12.4604521 9.3370259 1.335 0.182861 V18 3.4109051 2.5982124 1.313 0.190074 V19 -3.2029680 1.4955409 -2.142 0.032877 * V20 5.8114764 6.7403058 0.862 0.389142 V21 6.3245471 3.5132301 1.800 0.072649 . V22 -3.8548911 13.9312795 -0.277 0.782160 V23 0.2085364 0.0725441 2.875 0.004281 ** V24 0.0163053 0.0082144 1.985 0.047892 * V25 0.0103947 0.0073762 1.409 0.159613 V26 -0.0015405 0.0004029 -3.823 0.000155 *** V27 0.3199127 1.7892352 0.179 0.858195 V28 -0.2059715 0.5003384 -0.412 0.680826 V29 0.3228923 0.3129524 1.032 0.302863 V30 0.4349610 1.1458239 0.380 0.704458 V31 0.0973530 0.6204197 0.157 0.875398 V32 3.6376056 2.9501115 1.233 0.218350 --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 (Dispersion parameter for gaussian family taken to be 0.0537472) Null deviance: 91.048 on 397 degrees of freedom Residual deviance: 19.725 on 367 degrees of freedom AIC: -2.3374 Number of Fisher Scoring iterations: 2 > pre <- predict(cl1,test[,2:dim(test)[2]]) > predict <- data.frame(true=test[,1],predict=ifelse(pre > 0.5,1,0)) > #print the confusion matrix > table(predict) predict true 0 1 0 99 1 1 10 61 > #print the accuracy > cat('Accuracy:',sum(diag(prop.table(table(predict)))),'\n') Accuracy: 0.9356725 > > #method 2 to train the logistic regression model > cl2 <- glm(V2~.,data=train,family=binomial(link='logit'),model = T) Warning messages: 1: glm.fit:算法沒有聚合 2: glm.fit:擬合機率算出來是數值零或一 > summary(cl2) Call: glm(formula = V2 ~ ., family = binomial(link = "logit"), data = train, model = T) Deviance Residuals: Min 1Q Median 3Q Max -2.220e-04 -2.100e-08 -2.100e-08 2.100e-08 2.033e-04 Coefficients: Estimate Std. Error z value Pr(>|z|) (Intercept) -8.758e+02 8.910e+05 -0.001 0.999 V3 -2.024e+02 1.990e+05 -0.001 0.999 V4 3.032e+00 9.058e+03 0.000 1.000 V5 4.306e+01 2.243e+04 0.002 0.998 V6 2.976e-01 8.733e+02 0.000 1.000 V7 5.400e+03 1.824e+06 0.003 0.998 V8 -6.629e+03 9.677e+05 -0.007 0.995 V9 4.722e+03 1.151e+06 0.004 0.997 V10 -1.708e+03 2.175e+06 -0.001 0.999 V11 -1.363e+03 6.115e+05 -0.002 0.998 V12 6.584e+03 7.757e+06 0.001 0.999 V13 -2.640e+01 6.282e+05 0.000 1.000 V14 -1.098e+02 7.760e+04 -0.001 0.999 V15 1.270e+02 1.013e+05 0.001 0.999 V16 4.578e+00 5.046e+03 0.001 0.999 V17 -1.474e+04 6.903e+06 -0.002 0.998 V18 1.068e+04 3.621e+06 0.003 0.998 V19 -5.032e+03 8.410e+05 -0.006 0.995 V20 4.981e+02 6.488e+06 0.000 1.000 V21 -1.025e+04 3.006e+06 -0.003 0.997 V22 -6.005e+04 2.600e+07 -0.002 0.998 V23 -2.170e+01 5.623e+04 0.000 1.000 V24 1.412e+01 1.030e+04 0.001 0.999 V25 -1.807e+01 7.861e+03 -0.002 0.998 V26 4.411e-01 4.646e+02 0.001 0.999 V27 4.996e+02 9.163e+05 0.001 1.000 V28 -4.239e+02 4.489e+05 -0.001 0.999 V29 1.159e+02 1.815e+05 0.001 0.999 V30 1.786e+03 1.018e+06 0.002 0.999 V31 2.349e+03 3.750e+05 0.006 0.995 V32 -8.331e+02 3.861e+06 0.000 1.000 (Dispersion parameter for binomial family taken to be 1) Null deviance: 5.1744e+02 on 397 degrees of freedom Residual deviance: 4.1036e-07 on 367 degrees of freedom AIC: 62 Number of Fisher Scoring iterations: 25 > pre <- predict(cl2,test[,2:dim(test)[2]]) > predict <- data.frame(true=test[,1],predict=ifelse(pre > 0.5,1,0)) > #print the confusion matrix > table(predict) predict true 0 1 0 94 6 1 7 64 > #print the accuracy > cat('Accuracy:',sum(diag(prop.table(table(predict)))),'\n') Accuracy: 0.9239766
可以看出,方法1效果更佳;
以上就是關於邏輯回歸的基本內容,今后也會不定時地在本文中增加更多內容,敬請期待。