K Nearest Neighbors
這個算法首先存儲所有的訓練樣本,然後通過分析(包括投票、計算加權和等方式)一個新樣本周圍K個最近鄰以給出該樣本的相應值。這種方法有時候被稱作“基於樣本的學習”,即為了預測,我們在已知響應的特徵向量中搜索與給定輸入最接近的樣本。
// K-nearest-neighbours classifier/regressor (OpenCV 1.x ML module).
class CvKNearest : public CvStatModel // derives from the ML library's statistical-model base class
{
public:
    CvKNearest(); // default constructor (model must be trained via train() before use)
    virtual ~CvKNearest(); // virtual destructor
    // convenience constructor: builds the object and trains it in one step
    CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
    // trains (or, with _update_base=true, incrementally updates) the model
    virtual bool train( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx=0, bool is_regression=false,
                        int _max_k=32, bool _update_base=false );
    // finds the k nearest neighbours of every row of _samples; see docs below
    virtual float find_nearest( const CvMat* _samples, int k, CvMat* results,
                                const float** neighbors=0, CvMat* neighbor_responses=0,
                                CvMat* dist=0 ) const;
    virtual void clear();
    int get_max_k() const;
    int get_var_count() const;
    int get_sample_count() const;
    bool is_regression() const;
protected:
    ...
};
CvKNearest::train
訓練KNN模型
bool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx=0, bool is_regression=false,
                        int _max_k=32, bool _update_base=false );
這個類的方法訓練K近鄰模型。它遵循一般訓練方法約定的限制:只支持CV_ROW_SAMPLE數據格式,輸入向量必須都是有序的,而輸出可以是無序的(當is_regression=false),也可以是有序的(is_regression=true)。此外,變量子集和缺失測量值是不被支持的。
參數_max_k 指定了最大鄰居的個數,它將被傳給方法find_nearest。 參數 _update_base 指定模型是由原來的數據訓練(_update_base=false),還是被新訓練數據更新后再訓練(_update_base=true)。在后一種情況下_max_k 不能大於原值, 否則它會被忽略.
CvKNearest::find_nearest
尋找輸入向量的最近鄰
float CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* results=0,
                                const float** neighbors=0,
                                CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
對每個輸入向量(即矩陣 _samples 的每一行),該方法找到k(k≤get_max_k() )個最近鄰。在回歸中,預測結果將是指定向量的近鄰的響應的均值。在分類中,類別將由投票決定。
對傳統分類和回歸預測來說,該方法可以有選擇的返回近鄰向量本身的指針(neighbors, array of k*_samples->rows pointers),它們相對應的輸出值(neighbor_responses, a vector of k*_samples->rows elements) ,和輸入向量與近鄰之間的距離(dist, also a vector of k*_samples->rows elements)。
對每個輸入向量來說,近鄰將按照它們到該向量的距離排序。
對單個輸入向量,所有的輸出矩陣是可選的,而且預測值將由該方法返回。
例程:使用kNN進行2維樣本集的分類,樣本集的分布為混合高斯分布
#include "ml.h"
#include "highgui.h"

// Example: classify every pixel of a 500x500 plane with kNN (K = 10).
// Training data is a mixture of two 2-D Gaussians: class 1 around (200,200),
// class 2 around (300,300), both with sigma 50. Each pixel is coloured by the
// predicted class; brightness encodes the voting confidence.
int main( int argc, char** argv )
{
    const int K = 10;                 // number of neighbours consulted per query
    int i, j, k, accuracy;
    float response;
    int train_sample_count = 100;
    CvRNG rng_state = cvRNG(-1);
    CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );
    CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );
    IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
    float _sample[2];
    CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );  // 1x2 query row vector
    cvZero( img );

    CvMat trainData1, trainData2, trainClasses1, trainClasses2;

    // form the training samples: first half -> class 1, second half -> class 2
    cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );
    cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL,
               cvScalar(200,200), cvScalar(50,50) );
    cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );
    cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL,
               cvScalar(300,300), cvScalar(50,50) );
    cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );
    cvSet( &trainClasses1, cvScalar(1) );
    cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );
    cvSet( &trainClasses2, cvScalar(2) );

    // learn classifier
    CvKNearest knn( trainData, trainClasses, 0, false, K );
    CvMat* nearests = cvCreateMat( 1, K, CV_32FC1 );

    for( i = 0; i < img->height; i++ )
    {
        for( j = 0; j < img->width; j++ )
        {
            sample.data.fl[0] = (float)j;
            sample.data.fl[1] = (float)i;

            // estimate the response and get the neighbours' labels
            response = knn.find_nearest( &sample, K, 0, 0, nearests, 0 );

            // count the neighbours that voted for the winning class
            for( k = 0, accuracy = 0; k < K; k++ )
            {
                if( nearests->data.fl[k] == response )
                    accuracy++;
            }

            // highlight the pixel depending on the accuracy (or confidence)
            cvSet2D( img, i, j, response == 1 ?
                (accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :
                (accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );
        }
    }

    // display the original training samples (red = class 1, green = class 2)
    for( i = 0; i < train_sample_count/2; i++ )
    {
        CvPoint pt;
        pt.x = cvRound(trainData1.data.fl[i*2]);
        pt.y = cvRound(trainData1.data.fl[i*2+1]);
        cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );
        pt.x = cvRound(trainData2.data.fl[i*2]);
        pt.y = cvRound(trainData2.data.fl[i*2+1]);
        cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );
    }

    cvNamedWindow( "classifier result", 1 );
    cvShowImage( "classifier result", img );
    cvWaitKey(0);

    // release everything we allocated; the original version leaked
    // `nearests` and `img`
    cvReleaseMat( &nearests );
    cvReleaseImage( &img );
    cvReleaseMat( &trainClasses );
    cvReleaseMat( &trainData );
    return 0;
}
結果: