一、基礎理解
- 問題:邏輯回歸算法是用回歸的方式解決分類的問題,而且只可以解決二分類問題;
- 方案:可以通過改造,使得邏輯回歸算法可以解決多分類問題;
- 改造方法:
- OvR(One vs Rest),一對剩余的意思,有時候也稱它為 OvA(One vs All);一般使用 OvR,更標准;
- OvO(One vs One),一對一的意思;
- 改造方法不是指針對邏輯回歸算法,而是在機器學習領域有通用性,所有二分類的機器學習算法都可使用此方法進行改造,解決多分類問題;
二、原理
1)OvR
- 思想:n 種類型的樣本進行分類時,分別取一種樣本作為一類,將剩余的所有類型的樣本看做另一類,這樣就形成了 n 個二分類問題,使用邏輯回歸算法對 n 個數據集訓練出 n 個模型,將待預測的樣本傳入這 n 個模型中,所得概率最高的那個模型對應的樣本類型即認為是該預測樣本的類型;

- 時間復雜度:如果處理一個二分類問題用時 T,此方法需要用時 n.T;
2)OvO
- 思想: n 類樣本中,每次挑出 2 種類型,兩兩結合,一共有 Cn2 種二分類情況,使用 Cn2 種模型預測樣本類型,有 Cn2 個預測結果,種類最多的那種樣本類型,就認為是該樣本最終的預測類型;

- 時間復雜度:如果處理一個二分類問題用時 T,此方法需要用時 Cn2 .T = [n.(n - 1) / 2] . T;
3)區別
- OvO 用時較多,但其分類結果更准確,因為每一次二分類時都用真實的類型進行比較,沒有混淆其它的類別;
三、scikit-learn 中的邏輯回歸
- scikit-learn的LogisticRegression 算法內包含了:正則化、優化損失函數的方法、多分類方法等;
-
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1, penalty='l2', random_state=None, solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
# LogisticRegression() 實例對象,包含了很多參數;
-
不懂的要學會 看文檔、看文檔、看文檔:help(算法、實例對象);
-
C=1.0:正則化的超參數,默認為 1.0;
- multi_class='ovr':scikit-learn中的邏輯回歸默認支持多分類問題,分類方式為 'OvR';
- solver='liblinear'、'lbfgs'、'sag'、'newton-cg':scikit-learn中優化損失函數的方法,不是梯度下降法;
- 多分類中使用 multinomial (OvO)時,只能使用 'lbfgs'、'sag'、'newton-cg' 來優化損失函數;
- 當損失函數使用了 L2 正則項時,優化方法只能使用 'lbfgs'、'sag'、'newton-cg';
- 使用 'liblinear' 優化損失函數時,正則項可以為 L1 和 L2 ;
1)例(3 種樣本類型):LogisticRegression() 默認使用 OvR
-
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets iris = datasets.load_iris() # [:, :2]:所有行,0、1 列,不包含 2 列; X = iris.data[:,:2] y = iris.target from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666) from sklearn.linear_model import LogisticRegression log_reg = LogisticRegression() log_reg.fit(X_train, y_train) log_reg.score(X_test, y_test) # 准確率:0.6578947368421053
- 繪制決策邊界
def plot_decision_boundary(model, axis): x0, x1 = np.meshgrid( np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1,1), np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1,1) ) X_new = np.c_[x0.ravel(), x1.ravel()] y_predict = model.predict(X_new) zz = y_predict.reshape(x0.shape) from matplotlib.colors import ListedColormap custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9']) plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap) plot_decision_boundary(log_reg, axis=[4, 8.5, 1.5, 4.5]) # 可視化時只能在同一個二維平面內體現兩種特征; plt.scatter(X[y==0, 0], X[y==0, 1]) plt.scatter(X[y==1, 0], X[y==1, 1]) plt.scatter(X[y==2, 0], X[y==2, 1]) plt.show()

2)使用 OvO 分類
-
log_reg2 = LogisticRegression(multi_class='multinomial', solver='newton-cg') # 'multinomial':指 OvO 方法; log_reg2.fit(X_train, y_train) log_reg2.score(X_test, y_test) # 准確率:0.7894736842105263 plot_decision_boundary(log_reg2, axis=[4, 8.5, 1.5, 4.5]) plt.scatter(X[y==0, 0], X[y==0, 1]) plt.scatter(X[y==1, 0], X[y==1, 1]) plt.scatter(X[y==2, 0], X[y==2, 1]) plt.show()

3)使用所有分類數據
- OvR
X = iris.data y = iris.target X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666) log_reg_ovr = LogisticRegression() log_reg_ovr.fit(X_train, y_train) log_reg_ovr.score(X_test, y_test) # 准確率:0.9473684210526315
- OvO
log_reg_ovo = LogisticRegression(multi_class='multinomial', solver='newton-cg') log_reg_ovo.fit(X_train, y_train) log_reg_ovo.score(X_test, y_test) # 准確率:1.0
4)分析
- 通過准確率對比可以看出,使用 OvO 方法改造 LogisticRegression() 算法,得到的模型准確率較高;
四、OvR 和 OvO 的封裝
- scikit-learn單獨封裝了實現 OvO 和 OvR 的類,使得任意二分類算法都可以通過使用這兩個類解決多分類問題;
1)OvR 的封裝
- 模塊
from sklearn.multiclass import OneVsRestClassifier
- 使用方法
- ovr = OneVsRestClassifier(二分類算法的實例對象):得到一個可以解決多分類的實例對象;
- ovr.fit(X_train, y_train):擬合多分類實例對象;
- 例
from sklearn.multiclass import OneVsRestClassifier ovr = OneVsRestClassifier(log_reg) ovr.fit(X_train, y_train) ovr.score(X_test, y_test) # 准確率:0.9473684210526315
2)OvO 的封裝
- 模塊
from sklearn.multiclass import OneVsOneClassifier
- 使用方法:同理 OvR;
- 例
from sklearn.multiclass import OneVsOneClassifier ovo = OneVsOneClassifier(log_reg) ovo.fit(X_train, y_train) ovo.score(X_test, y_test) # 准確率:1.0
