編程實現線性判別分析,並給出西瓜數據集3.0α上的運行結果


1.題目理解

 

將西瓜數據集的樣例投影到一條直線上,使得好瓜、壞瓜各自的投影點盡可能接近,好瓜與壞瓜之間的投影點盡可能遠離。

 

2.算法原理

 

 

3.算法設計

 

①  根據LDA原理求解得到w,結合數據集得到LDA直線;

 

②  將每個樣本映射到LDA直線上,觀察分析結果。

 

4.關鍵代碼

 

 1 # 加載數據集
 2 dataset = np.loadtxt('C:/Users/86185/PycharmProjects/ML1/watermelon_3a.csv', delimiter=",")
 3 
 4 # 分離屬性值和標簽
 5 X = dataset[:,1:3]
 6 y = dataset[:,3]
 7 u = []
 8 for i in range(2):
 9     u.append(np.mean(X[y==i],axis=0))
10 
11 m,n = np.shape(X)
12 Sw = np.zeros((n,n))
13 for i in range(m):
14     x_temp = X[i].reshape(n, 1)     # 行向量變為列向量
15     if y[i]==0: u_temp = u[0].reshape(n, 1)
16     if y[i]==1: u_temp = u[1].reshape(n, 1)
17     Sw +=np.dot(x_temp-u_temp, (x_temp-u_temp).T)
18 
19 Sw = np.mat(Sw)
20 # print(Sw)
21 Sw_inv = np.linalg.inv (Sw)
22 # print(Sw_inv)
23 w = np.dot(Sw_inv, (u[0]-u[1]).reshape(n,1))
24 print(w)

 

先根據公式求得w

 1 def GetPoint(point0, w):
 2     k0 = w[1, 0]/w[0, 0]
 3     k1 = w[0, 0]/w[1, 0]
 4     x0 = point0[0]
 5     y0 = point0[1]
 6     x1 = (k0 * x0 - y0) / (k0 + k1)
 7     y1 = -k1 * x1
 8     return x0, x1, y0, y1
 9 
10 f1 = plt.figure('first')
11 plt.xlim( -0.2, 1 )         # 設定坐標軸的范圍
12 plt.ylim( -0.2, 0.6 )
13 
14 
15 x = np.arange(-1, 3)
16 yy = -(w[0,0]/w[1,0])*x

做LDA直線yy;GetPoint()函數用來計算點到直線yy的投影

 

5.結果展示

根據運行結果顯示,沒有很明確的將好瓜與壞瓜區分開來,好瓜與壞瓜的投影點不夠遠離,壞瓜與壞瓜之間的投影點不夠聚集。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2026 CODEPRJ.COM