探索sklearn | 鳶尾花數據集


1 鳶尾花數據集背景

鳶尾花數據集是原則20世紀30年代的經典數據集。它是用統計進行分類的鼻祖。

sklearn包不僅囊括很多機器學習的算法,也自帶了許多經典的數據集,鳶尾花數據集就是其中之一。

導入的方法很簡單,不過我比較好奇它是如何來存儲這些數據的,於是我決定去背后看一看

from sklearn.datasets import load_iris

data = load_iris()

 找到sklearn包的路徑,發現包可不少,不過現在扔在一邊,以后再來探索,我現在要找到是datasets文件夾。

文件夾里沒有找到load_iris()這個函數在哪,只是在__init__文件里,發現了這么一行

from .base import load_iris

 

2 數據的內容

不出我料數據沒有存儲在程序文件里,而是用csv格式保存着,單獨放在了data文件夾里

150,4,setosa,versicolor,virginica
5.1,3.5,1.4,0.2,0 #花萼長度,花萼寬度,花瓣長度,花瓣寬度
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
4.6,3.1,1.5,0.2,0
5.0,3.6,1.4,0.2,0

 第一行首先記錄了樣本數目150,特征數目4

現在是時候來詳細介紹一下數據了:

數據包含三種鳶尾花的四個特征,分別是花萼長度(cm)、花萼寬度(cm)、花瓣長度(cm)、花瓣寬度(cm),這些形態特征在過去被用來識別物種。時至今日,我們已經可以通過基因簽名來識別這些分類了。

三種鳶尾花分別是

山鳶尾花(Iris Setosa)、

變色鳶尾花(Iris Versicolor)和

維吉尼亞鳶尾花(Iris Virginica)

 

3 數據可視化

鳶尾花數據集只有150個樣本,每個樣本只有4個特征,容易將其可視化

上面加載的data變量是一個類似字典的類型,是數據信息的集合,它像字典一樣通過鍵值對來組織信息

值既可以通過data['target']也可以通過data.target來獲取,很明顯這說明data並不是字典類型

data.keys()
>>['target_names', 'data', 'target', 'DESCR', 'feature_names']
feature = data['data'] #為numpy.ndarray類型
feature.shape #矩陣的行數和劣勢
>> (150L, 4L)
target = data['target']
target.shape
>>(150L,)

 

 四個特征是不可能同時在平面圖里畫出來的,只得運用我們的聰明才智,把它兩兩一組

def plot_iris_projection(x_index, y_index):
    for t,marker,c in zip(xrange(3),'>ox', 'rgb'):
        plt.scatter(data[target==t,x_index],
                    data[target==t,y_index],
                    marker=marker,c=c)
        plt.xlabel(feature_names[x_index])
        plt.ylabel(feature_names[y_index])

pairs = [(0,1),(0,2),(0,3),(1,2),(1,3),(2,3)] for i,(x_index,y_index) in enumerate(pairs): plt.subplot(2,3,i) plot_iris_projection(x_index, y_index) plt.show()

 

 

不難發現的是,不論在那兩個特征下,山鳶尾花都能很好的和其他兩種鳶尾花區分,但是另外兩種鳶尾花的特征比較焦灼,如果只有這四個特征,有時人都難以區分。

數據可視化最高只能是三維,matplotlib也能勝任此工作

from mpl_toolkits.mplot3d import Axes3D

def plot_iris_projection3d(x_index, y_index, z_index):
    fig = plt.figure()
    ax = fig.add_subplot(111,projection='3d')
    for t,marker,c in zip(xrange(3),'>ox', 'rgb'):
        ax.scatter(data[target==t,x_index],
                    data[target==t,y_index],
                    data[target==t,z_index],
                    marker=marker,c=c)
        ax.set_xlabel(feature_names[x_index])
        ax.set_ylabel(feature_names[y_index])
        ax.set_zlabel(feature_names[z_index])
        
plot_iris_projection3d(1, 2, 3)
plt.show()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM