吳裕雄 Python Machine Learning: Linear Discriminant Analysis (LinearDiscriminantAnalysis)


import numpy as np
import matplotlib.pyplot as plt

from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import train_test_split
from sklearn import datasets, linear_model,discriminant_analysis

def load_data():
    '''
    Load the iris dataset bundled with scikit-learn and return a stratified
    75% / 25% train/test split.
    '''
    iris=datasets.load_iris()
    X_train=iris.data
    y_train=iris.target
    return train_test_split(X_train, y_train,test_size=0.25,random_state=0,stratify=y_train)

# Linear Discriminant Analysis (LinearDiscriminantAnalysis)
def test_LinearDiscriminantAnalysis(*data):
    '''
    Fit LinearDiscriminantAnalysis on the training set, then print the fitted
    coefficients and intercepts and the mean accuracy on the test set.
    '''
    X_train,X_test,y_train,y_test=data
    lda = discriminant_analysis.LinearDiscriminantAnalysis()
    lda.fit(X_train, y_train)
    print('Coefficients:%s, intercept %s'%(lda.coef_,lda.intercept_))
    print('Score: %.2f' % lda.score(X_test, y_test))
    
# load the dataset used for classification
X_train,X_test,y_train,y_test=load_data()
# call test_LinearDiscriminantAnalysis
test_LinearDiscriminantAnalysis(X_train,X_test,y_train,y_test)
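
For reference, the score printed above is the mean accuracy on the held-out test set. The short sketch below is an addition to the original listing (the variable name lda_check is ours) and reproduces that number with sklearn.metrics.accuracy_score:

from sklearn.metrics import accuracy_score

# a minimal check: accuracy_score over the predictions equals lda.score(X_test, y_test)
lda_check = discriminant_analysis.LinearDiscriminantAnalysis()
lda_check.fit(X_train, y_train)
print('Accuracy: %.2f' % accuracy_score(y_test, lda_check.predict(X_test)))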

def plot_LDA(converted_X,y):
    '''
    Plot the data after the LDA projection.
    :param converted_X: the samples after the LDA projection
    :param y: the labels of the samples
    '''
    fig=plt.figure()
    ax=fig.add_subplot(projection='3d')    # Axes3D(fig) no longer attaches to the figure on recent matplotlib
    colors='rgb'
    markers='o*s'
    for target,color,marker in zip([0,1,2],colors,markers):
        pos=(y==target).ravel()
        X=converted_X[pos,:]
        ax.scatter(X[:,0], X[:,1], X[:,2],color=color,marker=marker,label="Label %d"%target)
    ax.legend(loc="best")
    fig.suptitle("Iris After LDA")
    plt.show()
    
def run_plot_LDA():
    '''
    Run plot_LDA on the dataset returned by load_data().
    '''
    X_train,X_test,y_train,y_test=load_data()
    X=np.vstack((X_train,X_test))
    Y=np.hstack((y_train,y_test))    # keep the labels as a 1-D array, as fit() expects
    lda = discriminant_analysis.LinearDiscriminantAnalysis()
    lda.fit(X, Y)
    # project the samples with the fitted coefficients (the decision-function values)
    converted_X=np.dot(X,np.transpose(lda.coef_))+lda.intercept_
    plot_LDA(converted_X,Y)
    
# call run_plot_LDA
run_plot_LDA()
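
The projection above is built by hand from coef_ and intercept_, which places the samples in the three-dimensional decision-function space of the three classes. scikit-learn also exposes the canonical LDA projection through lda.transform, which for the 3-class iris data yields at most n_classes - 1 = 2 components. The sketch below is an addition (the function name run_plot_LDA_transform is ours) that plots this 2-D projection instead:

def run_plot_LDA_transform():
    # an alternative to run_plot_LDA using the built-in transform (at most 2 components for iris)
    X_train,X_test,y_train,y_test=load_data()
    X=np.vstack((X_train,X_test))
    Y=np.hstack((y_train,y_test))
    lda = discriminant_analysis.LinearDiscriminantAnalysis(n_components=2)
    Z = lda.fit(X, Y).transform(X)    # shape (n_samples, 2)
    fig=plt.figure()
    ax=fig.add_subplot(1,1,1)
    for target,color,marker in zip([0,1,2],'rgb','o*s'):
        pos=(Y==target)
        ax.scatter(Z[pos,0], Z[pos,1],color=color,marker=marker,label="Label %d"%target)
    ax.legend(loc="best")
    fig.suptitle("Iris After lda.transform")
    plt.show()

run_plot_LDA_transform()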

def test_LinearDiscriminantAnalysis_solver(*data):
    '''
    Test how the predictive performance of LinearDiscriminantAnalysis
    varies with the solver parameter.
    '''
    X_train,X_test,y_train,y_test=data
    solvers=['svd','lsqr','eigen']
    for solver in solvers:
        if(solver=='svd'):
            # the 'svd' solver does not accept a shrinkage argument
            lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver)
        else:
            lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver,shrinkage=None)
        lda.fit(X_train, y_train)
        print('Score at solver=%s: %.2f' %(solver, lda.score(X_test, y_test)))
        
# call test_LinearDiscriminantAnalysis_solver
test_LinearDiscriminantAnalysis_solver(X_train,X_test,y_train,y_test)

def test_LinearDiscriminantAnalysis_shrinkage(*data):
    '''
    Test how the predictive performance of LinearDiscriminantAnalysis
    varies with the shrinkage parameter.
    '''
    X_train,X_test,y_train,y_test=data
    shrinkages=np.linspace(0.0,1.0,num=20)
    scores=[]
    for shrinkage in shrinkages:
        lda = discriminant_analysis.LinearDiscriminantAnalysis(solver='lsqr',shrinkage=shrinkage)
        lda.fit(X_train, y_train)
        scores.append(lda.score(X_test, y_test))
    # plot the score against the shrinkage value
    fig=plt.figure()
    ax=fig.add_subplot(1,1,1)
    ax.plot(shrinkages,scores)
    ax.set_xlabel(r"shrinkage")
    ax.set_ylabel(r"score")
    ax.set_ylim(0,1.05)
    ax.set_title("LinearDiscriminantAnalysis")
    plt.show()
# call test_LinearDiscriminantAnalysis_shrinkage
test_LinearDiscriminantAnalysis_shrinkage(X_train,X_test,y_train,y_test)
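
Besides a fixed value in [0, 1], the lsqr and eigen solvers also accept shrinkage='auto', which selects the shrinkage intensity via the Ledoit-Wolf lemma. The following minimal sketch is an addition to the original listing and reuses the split from load_data():

lda_auto = discriminant_analysis.LinearDiscriminantAnalysis(solver='lsqr',shrinkage='auto')
lda_auto.fit(X_train, y_train)
print("Score at shrinkage='auto': %.2f" % lda_auto.score(X_test, y_test))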

 

