KNN 實現mnist數據集分類


一 數據預處理

訓練數據集和驗證數據集分別為train.csv和test.csv。數據集下載地址:http://pan.baidu.com/s/1eQyIvZG

要分別對訓練數據集和驗證數據集進行分析,分析其內部數據的特征,下面分別對兩個數據集進行處理:

1.1 訓練數據集處理

train.csv 里面結構為42001 * 785。其中第一行為文字說明,應該去掉,其余每一行均表示一個圖像,大小為28*28,共784個像素值;第一列為類標簽,每一個標簽表示一個圖像所代表的數字,范圍為0-9;所以處理的步驟為:把所有數據存入列表中;刪除第一行,得到42000*785;分離開第一列和剩余數據,分別得到42000*1和42000*784兩個矩陣。

 

具體代碼如下:

def loadtraindata(trainfile):#傳參為所讀文件名
    l = list()#創建序列,要保存文件內容
    with open(trainfile,'rb') as filename:        
        lines = csv.reader(filename)
        for line in lines:
            l.append(line)
        del l[0]#刪除第一行
        l = np.array(l)#轉換為數組
        label = l[:,0]#取數組內所有行第一列元素
        data = l[:,1:]#取數組內所有行,從第二列至最后列元素
        label = np.int32(label)#int32為numpy 內部函數,進行數據類型轉換
        data = nomalizing(np.int32(data))#nomalizing 為自定義函數,進行數據標准化
        return data,label
        

標准化函數代碼如下:

def nomalizing(array):
    m,n = np.shape(array)#shape函數為得到數組的各個維度
    for i in xrange(m):
        for j in xrange(n):
            if array[i,j] != 0:
                array[i,j] = 1
    return array

二 KNN實現分類

現已知訓練集中有42000組元素和對應每組的類別,現給出一個未知類別的一組元素,要求預測其類別。KNN的做法是:找到與該組元素最近的k組;找到這k組元素里類別相同數最多的一個類別;認為該類別就是該未知類別元素的類別;

2.1 KNN具體代碼如下:

第一種代碼:

找到與該組元素最近的k組:這里面涉及到幾個點,1、如何判斷最近?有歐幾里得距離,曼哈頓距離。2、如何找到k組?即既要找到k個距離最小的組,同時要知道這些組的索引;因為這些組的索引用於知道他們的類別。所以我采用字典這個數據結構,既儲存距離,同時存儲索引。

def KNN(X,traindata,trainlabel,k):#X為未知類別數據,k為最近鄰個數
    find = dict()
    listkey = 0
    aa = [0,1,2,3,4,5,6,7,8,9]
    bb = [0,0,0,0,0,0,0,0,0,0]    
    m,n = np.shape(traindata)           
    for i in xrange(30000):#訓練數據集中前三萬組用於訓練
        sum = 0
        for j in xrange(n):
            sum += math.fabs(X[j]-traindata[i,j])
        if i < k:
            find[i]=sum
        else:
            for key in find.keys():
                if sum < find[key]:
                    del find[key]
                    find[i] = sum
                    break
    for key in find.keys():#找到這k個點類別相同數最多的類別
        for i in xrange(10):
            if trainlabel[key] == aa[i]:
                bb[i] += 1
                            
    for i in xrange(10):
        if bb[i] == max(bb):
            listkey = i
            break
            
    return aa[listkey]#返回該未知類型數據的預測類別

第二種代碼:

由於訓練集數據量較大,且都是數組之間的操作,可以使用numpy庫中array和mat函數進行處理,提高計算速度。

def KNN(X,traindata,trainlabel,k):
    X = np.mat(X)#這三行實現的是將序列轉換成矩陣,mat是numpy里轉換成矩陣的函數
    traindata = np.mat(traindata)
    trainlabel = np.mat(trainlabel)
    trainsize = traindata.shape[0]#得到traindata的第一維大小   
    distance = np.sum(np.array(np.tile(X,(trainsize,1))-traindata)**2,1)
    distancesort = distance.argsort()
    
    countdict = dict()
    for i in xrange(k):
        Xlabel= trainlabel[0,distancesort[i]]#distancesort存儲的是訓練數據集的索引,前k個為距離最小的k個點的索引,通過trainlabel得到k個點的類別
        countdict[Xlabel] = countdict.get(Xlabel,0) + 1#通過字典,存儲k個點上每個類別和對應的類別數量
    countlist = sorted(countdict.iteritems(),key=lambda x:x[1],reverse = True)#對字典的值,按降序排列,得到降序排列的存儲各元祖的序列
    return countlist[0][0]#其序列的第一個元組為類別數最多的元組,第一個元素為其類別。將該類別賦值給未知元組

 

2.2 計算召回率

def compute(traindata,trainlabel,k):
    error = 0
    for i in xrange(41800,42000):#用200組進行驗證
        X = traindata[i]    
        if KNN(X,traindata,trainlabel,k) != trainlabel[i]:
            error += 1
    
    return 1 - error / 200.00

 

三 補充:(涉及到的Python和Numpy語法細節)

3.1 Numpy

Numpy(一個用Python實現的科學計算包),包括:1、一個強大的N維數組對象Array;2、用於整合C/C++和Fortran代碼的工具包;3、實用的線性代數、傅里葉變換和隨機數生成函數。

3.1.1 生成數組

創建數組采用array函數,它接受一切序列型的對象,產生一個新的含有傳入數據的Numpy數組。

import numpy as np

a = [[2,3,4,5,6,7],[1,0,2,6,5,3]]
aa = np.array(a)

b = [1,3,5]
bb = np.array(b)

np.zeros(3)
np.zeros((4,5))

np.ones(3)
np.ones((4,5))

 

3.1.2 索引與切片

import numpy as np

a = [[2,3,4,5,6,7],[1,0,2,6,5,3],[2,4,5,6,4,3]]
aa = np.array(a)

aa[1]#索引
aa[1,2]

aa[:]#切片,沒有逗號默認只行切片
aa[:,:]
aa[1:,:1]#行是從第二行到最后,列是從開始到第二行(不包括)

aa[1:] = 2 #切片本質上不是復制,所以對它的修改會影響原數組
bb = aa[1:].copy() #復制切片,再修改b ,不會影響原數組

 

3.1.3 數組/矩陣轉置

import numpy as np

a = [[2,3,4,5,6,7],[1,0,2,6,5,3],[2,4,5,6,4,3]]
aa = np.mat(a)

aa.T#矩陣轉置,只有求T時aa可以是數組,其他都必須是矩陣
aa.H#矩陣共軛轉置
aa.I#矩陣的逆矩陣
aa.A#矩陣的二維視圖

 

3.1.4 數組與矩陣

 Matrix 類型繼承於ndarray類型,因此含有ndarray的所有屬性和方法。Matrix類型和ndarray類型常用的不同有:

a . Matrix對象是二維的。例子中mat之后bb為二維的矩陣

import numpy as np

b = [3,5,74,6]
bb = np.mat(b)
print bb[0,3]

b . Matrix類型的乘法覆蓋了array的乘法,使用的是矩陣的乘法運算。

import numpy as np

b = [3,5,74,6]
bb = np.array(b)
cc = np.mat(b)

bb*bb#數組乘法,為元素間相乘,即點乘
cc*cc.T#矩陣乘法,遵守前一個矩陣的列等於后一個矩陣的行這樣的矩陣運算規則

 

c . Matrix 類型的冪運算覆蓋了array的冪運算。

import numpy as np

b = [[3,5],[74,6]]
bb = np.array(b)
cc = np.mat(b)

print bb**2#數組的冪運算,是對每一個元素進行冪運算,bb不必須是行列相同
print cc**2#矩陣的冪運算,要求矩陣cc為方陣,然后進行方陣之間的矩陣運算

 

d . 矩陣具有轉置、共軛轉置、逆矩陣等特有屬性。

 

3.1.5 數組排序

a .  sorted 方法

python 的內置函數(built-in functions)

 sorted(...)
    sorted(iterable, cmp=None, key=None, reverse=False) --> new sorted list


iterable:是可迭代類型;
cmp:用於比較的函數,比較什么由key決定,有默認值,迭代集合中的一項;
key:用列表元素的某個屬性和函數進行作為關鍵字,有默認值,迭代集合中的一項;
reverse:排序規則. reverse = True 或者 reverse = False,有默認值。
返回值:是一個經過排序的可迭代類型,與iterable一樣。

import numpy as np

b = [[0,3],[2,2],[4,2]]

bb = np.array(b)

print sorted(bb,key=lambda x:x[1],reverse=False)#key指把bb代入x,對x第二維進行比較,reverse=True為降序,sorted排序不影響b
print sorted(bb,key = lambda x:(x[1],x[0]),reverse = True)#這里的key是先按第二維排序,再按第一維排序

b. sort方法

list的內置函數

sort(...)
 |      L.sort(cmp=None, key=None, reverse=False) -- stable sort *IN PLACE*;
 |      cmp(x, y) -> -1, 0, 1

b = [[0,3],[1,4],[4,0]]

b.sort(key=lambda x:x[1],reverse=False)
print b #對b調用sort函數會導致b的變化

 

c . argsort 方法

 argsort(a, axis=-1, kind='quicksort', order=None)
    Returns the indices that would sort an array.

a : 要排序的數組
axis : int or None, optional
        Axis along which to sort.  The default is -1 (the last axis). If None,
        the flattened array is used.

   axis = 0 按列排序(每列之間排序),axis= 1 按行排序(每行之間排序),默認行排序

kind : {'quicksort', 'mergesort', 'heapsort'}, optional
        Sorting algorithm.

import numpy as np

a = [1,3,5]
aa = np.array(a)

b = [[0,3],[2,2],[4,2]]
bb = np.array(b)

print np.argsort(aa)
print np.argsort(bb,0)
print np.argsort(bb,1)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM