【cs231n】圖像分類 k-Nearest Neighbor Classifier(K最近鄰分類器)【python3實現】


 【學習自CS231n課程】

 轉載請注明出處:http://www.cnblogs.com/GraceSkyer/p/8763616.html

 

 

k-Nearest Neighbor(KNN)分類器

  與其只找最相近的那1個圖片的標簽,我們找最相似的k個圖片的標簽,然后讓他們針對測試圖片進行投票,最后把票數最高的標簽作為對測試圖片的預測。所以當k=1的時候,k-Nearest Neighbor分類器就是Nearest Neighbor分類器。從直觀感受上就可以看到,更高的k值可以讓分類的效果更平滑,使得分類器對於異常值更有抵抗力。

  【或者說,用這種方法來檢索相鄰數據時,會對噪音產生更大的魯棒性,即噪音產生不確定的評價值對評價結果影響很小。】

 

  上面例子使用了訓練集中包含的2維平面的點來表示,點的顏色代表不同的類別或不同的類標簽,這里有五個類別。不同顏色區域代表分類器的決策邊界。這里我們用同樣的數據集,使用K=1的最近鄰分類器,以及K=3、K=5。K=1時(即最近鄰分類器),是根據相鄰的點來切割空間並進行着色;K=3時,綠色點簇中的黃色噪點不再會導致周圍的區域被划分為黃色,由於使用多數投票,中間的整個綠色區域,都將被分類為綠色;k=5時,藍色和紅色區域間的這些決策邊界變得更加平滑好看,它針對測試數據的泛化能力更好。

【注:白色區域表示 這個區域中沒有獲得K-最近鄰的投票,你或許可以假設將其歸為一個類別,表示這些點在這個區域沒有最近的其他點】

建議:去網站上嘗試用KNN分類你自己的數據,改變K值,改變距離度量,來培養決策邊界的直覺。

網址:http://vision.stanford.edu/teaching/cs231n-demos/knn/

 

超參數的選擇

  在實際中,大多使用k-NN分類器,但是K值如何確定?距離度量如何選擇?向K值和距離度量這樣的選擇,被稱為超參數(hyperparameter)。問題在於,在時間中該如何設置這些超參數。

  × 你首先可能想到的是,選擇能對你的訓練集給出最高准確率,表現最佳的超參數,這是糟糕的做法!例如:在之前的k-最近鄰分類算法中,假設k=1,我們總能完美分類訓練集數據,我們便采用了這個策略...但之前實踐中讓K取更大的值,盡管在訓練集中分錯個別數據,但對於在訓練集中未出現過的數據分類性能更佳,可見K=1並不合適。【在機器學習中,我們關心的不是要盡可能擬合訓練集,而是要讓我們的分類器以及方法,在訓練集以外的未知數據上表現得更好。】

   × 你或許又會想,把所有的數據分成兩部分:訓練集和測試集。然后在訓練集上用不同的超參數來訓練算法,然后將訓練好的分類器用在測試集上,再選擇一組在測試集上表現最好的超參數。這也很糟糕!如果采用這種方法,那么很可能我們選擇了一組超參,只是讓我們的算法在這組測試集上表現良好,但是這組測試集的表現無法代表在全新的數據上的表現。所以,不要用測試集去調整參數,容易使得你的模型過擬合。【機器學習系統的目的,是讓我們了解算法表現究竟如何,所以,測試集的目的是給我們一種預估方法,即在沒遇到的數據上算法表現將會如何。】

   更常見的做法,就是將數據分為三組:訓練集(大部分數據),驗證集(從訓練集中取出一小部分數據用來調優),測試集。我們在訓練集上用不同超參來訓練算法,在驗證集上進行評估,然后用一組超參(選擇在驗證集上表現最好的)。然后,當完成了這些步驟以后,再把這組在驗證集上表現最佳的分類器拿出來,在測試集上跑一跑。這個數據才是告訴你,你的算法在未見的新數據上表現如何。非常重要的一點是,必須分割驗證集和測試集,所以當我們做研究報告時,往往只是在最后一刻才會接觸到測試集。

  以CIFAR-10為例,我們可以用49000個圖像作為訓練集,用1000個圖像作為驗證集。驗證集其實就是作為假的測試集來調優。

代碼實現:

 1 import numpy as np
 2 import pickle
 3 import os
 4 
 5 
 6 class KNearestNeighbor(object):
 7     def __init__(self):
 8         pass
 9 
10     def train(self, X, y):
11         """ X is N x D where each row is an example. Y is 1-dimension of size N """
12         # the nearest neighbor classifier simply remembers all the training data
13         self.Xtr = X
14         self.ytr = y
15 
16     def predict(self, X, k=1):
17         """ X is N x D where each row is an example we wish to predict label for """
18         """ k is the number of nearest neighbors that vote for the predicted labels."""
19         num_test = X.shape[0]
20         # lets make sure that the output type matches the input type
21         Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
22 
23         # loop over all test rows
24         for i in range(num_test):
25             # using the L1 distance (sum of absolute value differences)
26             distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
27             # L2 distance:
28             # distances = np.sqrt(np.sum(np.square(self.Xtr - X[i, :]), axis=1))
29             indexes = np.argsort(distances)
30             Yclosest = self.ytr[indexes[:k]]
31             cnt = np.bincount(Yclosest)
32             Ypred[i] = np.argmax(cnt)
33 
34         return Ypred
35 
36 
37 def load_CIFAR_batch(file):
38     """ load single batch of cifar """
39     with open(file, 'rb') as f:
40         datadict = pickle.load(f, encoding='latin1')
41         X = datadict['data']
42         Y = datadict['labels']
43         X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
44         Y = np.array(Y)
45     return X, Y
46 
47 
48 def load_CIFAR10(ROOT):
49     """ load all of cifar """
50     xs = []
51     ys = []
52     for b in range(1, 6):
53         f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
54         X, Y = load_CIFAR_batch(f)
55         xs.append(X)
56         ys.append(Y)
57     Xtr = np.concatenate(xs)  # 使變成行向量
58     Ytr = np.concatenate(ys)
59     del X, Y
60     Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
61     return Xtr, Ytr, Xte, Yte
62 
63 
64 Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar-10-batches-py/')  # a magic function we provide
65 # flatten out all images to be one-dimensional
66 Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)  # Xtr_rows becomes 50000 x 3072
67 Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)  # Xte_rows becomes 10000 x 3072
68 
69 
70 # assume we have Xtr_rows, Ytr, Xte_rows, Yte as before
71 # recall Xtr_rows is 50,000 x 3072 matrix
72 Xval_rows = Xtr_rows[:1000, :]  # take first 1000 for validation
73 Yval = Ytr[:1000]
74 Xtr_rows = Xtr_rows[1000:, :]  # keep last 49,000 for train
75 Ytr = Ytr[1000:]
76 
77 # find hyperparameters that work best on the validation set
78 validation_accuracies = []
79 for k in [1, 3, 5, 10, 20, 50, 100]:
80     # use a particular value of k and evaluation on validation data
81     nn = KNearestNeighbor()
82     nn.train(Xtr_rows, Ytr)
83     # here we assume a modified NearestNeighbor class that can take a k as input
84     Yval_predict = nn.predict(Xval_rows, k=k)
85     acc = np.mean(Yval_predict == Yval)
86     print('accuracy: %f' % (acc,))
87 
88     # keep track of what works on the validation set
89     validation_accuracies.append((k, acc))
View Code

這代碼我有空的話再完善吧......沒空(

 

程序結束后,我們會作圖分析出哪個k值表現最好,然后用這個k值來跑真正的測試集,並作出對算法的評價。

把訓練集分成訓練集和驗證集。使用驗證集來對所有超參數調優。最后只在測試集上跑一次並報告結果。

 

交叉驗證

  設定超參數的另一個策略是交叉驗證。這在小數據集中更常用一些,在深度學習中不那么常用。當我們訓練大型模型時,訓練本身非常消耗計算能力,因此這個方法實際上不常用。

  還是用剛才的例子,如果是交叉驗證集,我們就不是取1000個圖像,而是將訓練集平均分成5份,其中4份用來訓練,1份用來驗證。然后我們循環着取其中4份來訓練,其中1份來驗證,最后取所有5次驗證結果的平均值作為算法驗證結果。

下圖是5份交叉驗證對k值調優的例子。針對每個k值,得到5個准確率結果,取其平均值,然后對不同k值的平均表現畫線連接。本例中,當k=7的時算法表現最好(對應圖中的准確率峰值)。如果我們將訓練集分成更多份數,直線一般會更加平滑(噪音更少)。

實際應用:

  在實際情況下,人們不是很喜歡用交叉驗證,主要是因為它會耗費較多的計算資源。一般直接把訓練集按照50%-90%的比例分成訓練集和驗證集。但這也是根據具體情況來定的:如果超參數數量多,你可能就想用更大的驗證集,而驗證集的數量不夠,那么最好還是用交叉驗證吧。至於分成幾份比較好,一般都是分成3、5和10份。

  常用的數據分割模式。給出訓練集和測試集后,訓練集一般會被均分。這里是分成5份。前面4份用來訓練,黃色那份用作驗證集調優。如果采取交叉驗證,那就各份輪流作為驗證集。最后模型訓練完畢,超參數都定好了,讓模型跑一次(而且只跑一次)測試集,以此測試結果評價算法。

 

KNN的劣勢:

其實,KNN在圖像分類中很少用到,原因如下:

  • 它在測試時運算時間很長,這和我們剛才提到的需求不符,
  • 像歐幾里得距離或者L1距離這樣的衡量標准用在比較圖像上不太合適,這種向量化的距離函數,不太適合表示圖像之間視覺的相似度。
  • 它並不能體現圖像之間的語義差別,更多的是圖像的背景,色彩的差異。
  • K-近鄰算法還有另一個問題,我們稱之為“維度災難”。 K-近鄰分類器,它有點像是用訓練數據 把樣本空間分成幾塊,這意味着,如果我們希望分類器有好的效果,我們需要訓練數據能密集地分布在空間中,否則最近鄰點的實際距離可能很遠,也就是說,和待測樣本的相似性沒有那么高。而問題在於,想要密集地分布在空間中,我們需要指數倍地訓練數據,這很糟糕,我們根本不可能拿到那么多的圖片去密布這樣的高維空間里的像素。

 

 

參考:

https://www.bilibili.com/video/av17204303/?from=search&seid=6625954842411789830

https://zhuanlan.zhihu.com/p/20900216?refer=intelligentunit

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2026 CODEPRJ.COM