1 鳶尾花數據集背景
鳶尾花數據集是原則20世紀30年代的經典數據集。它是用統計進行分類的鼻祖。
sklearn包不僅囊括很多機器學習的算法,也自帶了許多經典的數據集,鳶尾花數據集就是其中之一。
導入的方法很簡單,不過我比較好奇它是如何來存儲這些數據的,於是我決定去背后看一看
from sklearn.datasets import load_iris data = load_iris()
找到sklearn包的路徑,發現包可不少,不過現在扔在一邊,以后再來探索,我現在要找到是datasets文件夾。
文件夾里沒有找到load_iris()這個函數在哪,只是在__init__文件里,發現了這么一行
from .base import load_iris
2 數據的內容
不出我料數據沒有存儲在程序文件里,而是用csv格式保存着,單獨放在了data文件夾里
150,4,setosa,versicolor,virginica 5.1,3.5,1.4,0.2,0 #花萼長度,花萼寬度,花瓣長度,花瓣寬度 4.9,3.0,1.4,0.2,0 4.7,3.2,1.3,0.2,0 4.6,3.1,1.5,0.2,0 5.0,3.6,1.4,0.2,0
第一行首先記錄了樣本數目150,特征數目4
現在是時候來詳細介紹一下數據了:
數據包含三種鳶尾花的四個特征,分別是花萼長度(cm)、花萼寬度(cm)、花瓣長度(cm)、花瓣寬度(cm),這些形態特征在過去被用來識別物種。時至今日,我們已經可以通過基因簽名來識別這些分類了。
三種鳶尾花分別是
山鳶尾花(Iris Setosa)、
變色鳶尾花(Iris Versicolor)和
維吉尼亞鳶尾花(Iris Virginica)
3 數據可視化
鳶尾花數據集只有150個樣本,每個樣本只有4個特征,容易將其可視化
上面加載的data變量是一個類似字典的類型,是數據信息的集合,它像字典一樣通過鍵值對來組織信息
值既可以通過data['target']也可以通過data.target來獲取,很明顯這說明data並不是字典類型
data.keys() >>['target_names', 'data', 'target', 'DESCR', 'feature_names'] feature = data['data'] #為numpy.ndarray類型 feature.shape #矩陣的行數和劣勢 >> (150L, 4L) target = data['target'] target.shape >>(150L,)
四個特征是不可能同時在平面圖里畫出來的,只得運用我們的聰明才智,把它兩兩一組
def plot_iris_projection(x_index, y_index):
for t,marker,c in zip(xrange(3),'>ox', 'rgb'):
plt.scatter(data[target==t,x_index],
data[target==t,y_index],
marker=marker,c=c)
plt.xlabel(feature_names[x_index])
plt.ylabel(feature_names[y_index])
pairs = [(0,1),(0,2),(0,3),(1,2),(1,3),(2,3)]
for i,(x_index,y_index) in enumerate(pairs):
plt.subplot(2,3,i)
plot_iris_projection(x_index, y_index)
plt.show()
不難發現的是,不論在那兩個特征下,山鳶尾花都能很好的和其他兩種鳶尾花區分,但是另外兩種鳶尾花的特征比較焦灼,如果只有這四個特征,有時人都難以區分。
數據可視化最高只能是三維,matplotlib也能勝任此工作
from mpl_toolkits.mplot3d import Axes3D
def plot_iris_projection3d(x_index, y_index, z_index):
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
for t,marker,c in zip(xrange(3),'>ox', 'rgb'):
ax.scatter(data[target==t,x_index],
data[target==t,y_index],
data[target==t,z_index],
marker=marker,c=c)
ax.set_xlabel(feature_names[x_index])
ax.set_ylabel(feature_names[y_index])
ax.set_zlabel(feature_names[z_index])
plot_iris_projection3d(1, 2, 3)
plt.show()
