線性判別分析之python代碼分析


前幾天主要更新了一下機器學習的相關理論,主要介紹了感知機,SVM以及線性判別分析。現在用代碼來實現一下其中的模型,一方面對存粹理論的理解,另一方面也提升一下代碼的能力。本文就先從線性判別分析開始講起,不熟悉的可以先移步至線性判別分析(Linear Discriminant Analysis, LDA) - ZhiboZhao - 博客園 (cnblogs.com)對基礎知識做一個大概的了解。在代碼分析過程中,本文重點從應用入手,只講API中最常用的參數,能夠完成任務即可。
本文代碼參考鏈接:https://github.com/han1057578619/MachineLearning_Zhouzhihua_ProblemSets

一、數據准備

數據集部分我采用周志華《機器學習》書中的 watermelon數據集,數據集前5行如下:

編號 色澤 根蒂 敲聲 紋理 臍部 觸感 密度 含糖率 好瓜
1 青綠 蜷縮 濁響 清晰 凹陷 硬滑 0.697 0.46
2 烏黑 蜷縮 沉悶 清晰 凹陷 硬滑 0.774 0.376
3 烏黑 蜷縮 濁響 清晰 凹陷 硬滑 0.634 0.264
4 青綠 蜷縮 沉悶 清晰 凹陷 硬滑 0.608 0.318
5 淺白 蜷縮 濁響 清晰 凹陷 硬滑 0.556 0.215

1.1 讀取數據:

import pandas as pd
data_path = './watermelon3_0_ch.csv'
data = pd.read_csv(data_path).values	# 讀取數據並轉為np.array類型

這里主要運用 pd.read_csv() 進行 .csv 文件的讀取,該模塊主要用到的參數如下:

pd.read_csv(file_path, sep, header)

其中:file_path 是目標文件的路徑;sep 是目標文件中的分隔符,默認 .csv 文件以 ‘,’ 分隔;header 是整數類型的,它的數值決定了讀取 .csv 文件時從第幾行開始。舉個例子:

# header = 0, 默認第0行為表頭,從表頭往下開始讀取
head_0 = pd.read_csv(data_path, header = 0)
# header = 1, 默認第1行為表頭,從表頭往下開始讀取
head_0 = pd.read_csv(data_path, header = 1)

header_0的結果為:

編號 色澤 根蒂 敲聲 紋理 臍部 觸感 密度 含糖率 好瓜
1 青綠 蜷縮 濁響 清晰 凹陷 硬滑 0.697 0.46
2 烏黑 蜷縮 沉悶 清晰 凹陷 硬滑 0.774 0.376

header_1的結果為:

1 青綠 蜷縮 濁響 清晰 凹陷 硬滑 0.697 0.46
2 烏黑 蜷縮 沉悶 清晰 凹陷 硬滑 0.774 0.376
3 烏黑 蜷縮 濁響 清晰 凹陷 硬滑 0.634 0.264

1.2 對數據進行 "one-hot" 編碼

我們以二維線性判別分析為例,只根據 "密度" 和 "含糖量" 來確定是否是好瓜

X = data[:, 7:9].astype(float)	# 提取密度和含糖量的數據作為輸入特征
y = data[:, 9]	# 提取最后一列作為判別類型

y[y == '是'] = 1	# 需要進行one-hot編碼,將瓜分類
y[y == '否'] = 0
y = y.astype(int)

'''
以好瓜/壞瓜 來對樣本進行分類
'''
pos = y == 1, neg = y == 0 	# 分別找到正負樣本的位置
X0 = X[neg], X1 = X[pos]   # 以提取正負樣本的輸入特征

二、線性判別分析

2.1 根據對應模型進行求解

從上一講中我們得到,線性分類判別模型的最優解為:

\[w = S_{w}^{-1}(u_{0}-u_{1}) \]

其中,

\[u_{0} = \dfrac{1}{m} \sum_{i=1}^{m}x_{i},\quad u_{1} = \dfrac{1}{n} \sum_{i=1}^{n}x_{i}\\ S_{w} = \dfrac{1}{m} \sum_{i=1}^{m}(x_{i}-u_{0})(x_{i}-u_{0})^{T} +\dfrac{1}{n} \sum_{i=1}^{n}(x_{i}-u_{1})(x_{i}-u_{1})^{T}\\ \]

這里面注意一點,為了更符合人的理解習慣,我們在公式 (3) 中,定義的 \(S_w\) 是單個向量相乘之后求和;但是矩陣形式則更方便被計算機描述,設 $ X_{0} = {x_{1},x_{2},...,x_{m} }^{T}, X_{1} = {x_{1},x_{2},...,x_{n} }^{T}$,由於 \(x_{i} \in R^{p \times 1}\),因此\(X_{0}, X_{1} \in R^{m \times p}\),改寫成矩陣形式:

\[S_{w} = \dfrac{1}{m} (X_{0}-u_{0})^{T}(X_{0}-u_{0}) + \dfrac{1}{n}(X_{1}-u_{1})^{T}(X_{1}-u_{1}) \]

於是,對應代碼為:

u0 = X0.mean(0, keepdims=True)  # (1, p)
u1 = X1.mean(0, keepdims=True)

sw = np.dot((X0 - u0).T, X0 - u0) + np.dot((X1 - u1).T, X1 - u1)
w = np.dot(np.linalg.inv(sw), (u0 - u1).T).reshape(1, -1)  # (1, p)

說明:

mean() 函數在指定維度上求均值,由於 \(X_{0} \in R^{m \times p}\),所有指定維度為0之后相當於對所有 \(m\) 個樣本進行求平均,得到 \(u_{0} \in R^{1\times p}\)

2.2 模型可視化

這一部分代碼主要是繪圖的一些格式,本文就不多做解釋了。

fig, ax = plt.subplots()
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.spines['left'].set_position(('data', 0))
ax.spines['bottom'].set_position(('data', 0))

plt.scatter(X1[:, 0], X1[:, 1], c='k', marker='o', label='good')
plt.scatter(X0[:, 0], X0[:, 1], c='r', marker='x', label='bad')

plt.xlabel('密度', labelpad=1)
plt.ylabel('含糖量')
plt.legend(loc='upper right')

x_tmp = np.linspace(-0.05, 0.15)
y_tmp = x_tmp * w[0, 1] / w[0, 0]
plt.plot(x_tmp, y_tmp, '#808080', linewidth=1)

wu = w / np.linalg.norm(w)

# 正負樣板店
X0_project = np.dot(X0, np.dot(wu.T, wu))
plt.scatter(X0_project[:, 0], X0_project[:, 1], c='r', s=15)
for i in range(X0.shape[0]):
plt.plot([X0[i, 0], X0_project[i, 0]], [X0[i, 1], X0_project[i, 1]], '--r', linewidth=1)

X1_project = np.dot(X1, np.dot(wu.T, wu))
plt.scatter(X1_project[:, 0], X1_project[:, 1], c='k', s=15)
for i in range(X1.shape[0]):
plt.plot([X1[i, 0], X1_project[i, 0]], [X1[i, 1], X1_project[i, 1]], '--k', linewidth=1)

# 中心點的投影
u0_project = np.dot(u0, np.dot(wu.T, wu))
plt.scatter(u0_project[:, 0], u0_project[:, 1], c='#FF4500', s=60)
u1_project = np.dot(u1, np.dot(wu.T, wu))
plt.scatter(u1_project[:, 0], u1_project[:, 1], c='#696969', s=60)

ax.annotate(r'u0 投影點',
xy=(u0_project[:, 0], u0_project[:, 1]),
xytext=(u0_project[:, 0] - 0.2, u0_project[:, 1] - 0.1),
size=13,
va="center", ha="left",
arrowprops=dict(arrowstyle="->",
color="k",
)
)

ax.annotate(r'u1 投影點',
xy=(u1_project[:, 0], u1_project[:, 1]),
xytext=(u1_project[:, 0] - 0.1, u1_project[:, 1] + 0.1),
size=13,
va="center", ha="left",
arrowprops=dict(arrowstyle="->",
color="k",
)
)
plt.axis("equal")  # 兩坐標軸的單位刻度長度保存一致
plt.show()

self.w = w
self.u0 = u0
self.u1 = u1
return self

最終得到的分類結果圖如下:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM