一、噪音
- 噪音產生的因素:可能是測量儀器的誤差、也可能是人為誤差、或者測試方法有問題等;
- 降噪作用:方便數據的可視化,使用樣本特征更清晰;便於算法操作數據;
- 具體操作:從 n 維降到 k 維,再講降維后的數據集升到 n 維,得到的新的數據集為去燥后的數據集;
- 降維:X_reduction = pca.transform ( X )
- 升維:X_restore = pca.inverse_transform ( X_reduction ),數據集 X_restore 為去燥后的數據集;
二、實例
1)例一
-
模擬並繪制樣本信息
import numpy as np import matplotlib.pyplot as plt X = np.empty((100, 2)) X[:, 0] = np.random.uniform(0., 100, size=100) X[:, 1] = 0.75 * X[:, 0] + 3. + np.random.normal(0, 5, size=100) plt.scatter(X[:, 0], X[:, 1]) plt.show()
- 實際上,樣本的狀態看似在直線上下抖動式的分布,其實抖動的距離就是噪音;
-
使用 PCA 降維,達到降噪的效果
- 操作:數據降維后,再升到原來維度;
- inverse_transform(低維數據):將低維數據升為高維數據
-
from sklearn.decomposition import PCA pca = PCA(n_components=1) pca.fit(X) X_reduction = pca.transform(X) # inverse_transform(低維數據):將低維數據升為高維數據 X_restore = pca.inverse_transform(X_reduction) plt.scatter(X_restore[:,0], X_restore[:,1]) plt.show()
2)例二(手寫識別數字數據集)
-
加載數據集(人為加載噪音:noisy_digits)
from sklearn import datasets digits = datasets.load_digits() X = digits.data y = digits.target # 在數據集 X 的基礎上創建一個帶噪音的數據集 noisy_digits = X + np.random.normal(0, 4, size=X.shape)
-
從帶有噪音的數據集 noisy_digits 中提出示例數據集 example_digits
example_digits = noisy_digits[y==0,:][:10] for num in range(1, 10): X_num = noisy_digits[y==num,:][:10] # np.vstack([array1, array2]):將兩個矩陣在水平方向相加,增加列數; # np.hstack([array1, array2]):將兩矩陣垂直相加,增加行數; example_digits = np.vstack([example_digits, X_num]) example_digits.shape # 輸出:(100, 64)
-
繪制示例數據集 example_digits(帶噪音)
def plot_digits(data): fig, axes = plt.subplots(10, 10, figsize=(10,10), subplot_kw = {'xticks':[], 'yticks':[]}, gridspec_kw=dict(hspace=0.1, wspace=0.1)) for i, ax in enumerate(axes.flat): ax.imshow(data[i].reshape(8, 8), cmap='binary', interpoltion='nearest', clim=(0, 16)) plt.show() plot_digits(example_digits)
-
降噪數據集 example_digits
# 如果噪音比較多,保留較少信息(此例中只保留 50% 的信息) pca = PCA(0.5) pca.fit(noisy_digits) # 查看最終的樣本維度 pca.n_components_ # 輸出:12 # 1)降維:將數據集 example_digits 降維,得到數據集 components components = pca.transform(example_digits) # 2)升維:將數據集升到原來維度(100, 64) filtered_digits = pca.inverse_transform(components) # 繪制去燥后的數據集 filtered_digits plot_digits(filtered_digits)