【轉】OpenCV實現KNN算法


K Nearest Neighbors

這個算法首先貯藏所有的訓練樣本,然后通過分析(包括選舉,計算加權和等方式)一個新樣本周圍K個最近鄰以給出該樣本的相應值。這種方法有時候被稱作“基於樣本的學習”,即為了預測,我們對於給定的輸入搜索最近的已知其相應的特征向量。

class CvKNearest : public CvStatModel //繼承自ML庫中的統計模型基類
{
public:
 
    CvKNearest();//無參構造函數
    virtual ~CvKNearest();  //虛函數定義
 
    CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );//有參構造函數
 
    virtual bool train( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx=0, bool is_regression=false,
                        int _max_k=32, bool _update_base=false );
 
    virtual float find_nearest( const CvMat* _samples, int k, CvMat* results,
        const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
 
    virtual void clear();
    int get_max_k() const;
    int get_var_count() const;
    int get_sample_count() const;
    bool is_regression() const;
 
protected:
    ...
};

  

CvKNearest::train

訓練KNN模型

?
bool  CvKNearest::train( const  CvMat* _train_data, const  CvMat* _responses,
                         const  CvMat* _sample_idx=0, bool  is_regression= false ,
                         int  _max_k=32, bool  _update_base= false  );

這個類的方法訓練K近鄰模型。它遵循一個一般訓練方法約定的限制:只支持CV_ROW_SAMPLE數據格式,輸入向量必須都是有序的,而輸出可以 是 無序的(當is_regression=false),可以是有序的(is_regression=true)。並且變量子集和省略度量是不被支持的。

參數_max_k 指定了最大鄰居的個數,它將被傳給方法find_nearest。 參數 _update_base 指定模型是由原來的數據訓練(_update_base=false),還是被新訓練數據更新后再訓練(_update_base=true)。在后一種情況下_max_k 不能大於原值, 否則它會被忽略.

CvKNearest::find_nearest

尋找輸入向量的最近鄰

?
float  CvKNearest::find_nearest( const  CvMat* _samples, int  k, CvMat* results=0,
         const  float ** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const ;

對每個輸入向量(表示為matrix_sample的每一行),該方法找到k(k≤get_max_k() )個最近鄰。在回歸中,預測結果將是指定向量的近鄰的響應的均值。在分類中,類別將由投票決定。

對傳統分類和回歸預測來說,該方法可以有選擇的返回近鄰向量本身的指針(neighbors, array of k*_samples->rows pointers),它們相對應的輸出值(neighbor_responses, a vector of k*_samples->rows elements) ,和輸入向量與近鄰之間的距離(dist, also a vector of k*_samples->rows elements)。

對每個輸入向量來說,近鄰將按照它們到該向量的距離排序。

對單個輸入向量,所有的輸出矩陣是可選的,而且預測值將由該方法返回。

例程:使用kNN進行2維樣本集的分類,樣本集的分布為混合高斯分布

#include "ml.h"
#include "highgui.h"
 
int main( int argc, char** argv )
{
    const int K = 10;
    int i, j, k, accuracy;
    float response;
    int train_sample_count = 100;
    CvRNG rng_state = cvRNG(-1);
    CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );
    CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );
    IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
    float _sample[2];
    CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );
    cvZero( img );
 
    CvMat trainData1, trainData2, trainClasses1, trainClasses2;
 
    // form the training samples
    cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );
    cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL, cvScalar(200,200), cvScalar(50,50) );
 
    cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );
    cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL, cvScalar(300,300), cvScalar(50,50) );
 
    cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );
    cvSet( &trainClasses1, cvScalar(1) );
 
    cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );
    cvSet( &trainClasses2, cvScalar(2) );
 
    // learn classifier
    CvKNearest knn( trainData, trainClasses, 0, false, K );
    CvMat* nearests = cvCreateMat( 1, K, CV_32FC1);
 
    for( i = 0; i < img->height; i++ )
    {
        for( j = 0; j < img->width; j++ )
        {
            sample.data.fl[0] = (float)j;
            sample.data.fl[1] = (float)i;
 
            // estimates the response and get the neighbors' labels
            response = knn.find_nearest(&sample,K,0,0,nearests,0);
 
            // compute the number of neighbors representing the majority
            for( k = 0, accuracy = 0; k < K; k++ )
            {
                if( nearests->data.fl[k] == response)
                    accuracy++;
            }
            // highlight the pixel depending on the accuracy (or confidence)
            cvSet2D( img, i, j, response == 1 ?
                (accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :
                (accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );
        }
    }
 
    // display the original training samples
    for( i = 0; i < train_sample_count/2; i++ )
    {
        CvPoint pt;
        pt.x = cvRound(trainData1.data.fl[i*2]);
        pt.y = cvRound(trainData1.data.fl[i*2+1]);
        cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );
        pt.x = cvRound(trainData2.data.fl[i*2]);
        pt.y = cvRound(trainData2.data.fl[i*2+1]);
        cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );
    }
 
    cvNamedWindow( "classifier result", 1 );
    cvShowImage( "classifier result", img );
    cvWaitKey(0);
 
    cvReleaseMat( &trainClasses );
    cvReleaseMat( &trainData );
    return 0;
}

  結果:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM