[OpenCV隨筆]-OpenCV3.x中SVM多分類使用(代碼篇)


1. SVM介紹

占個坑,以后再說

2. OpenCV3.x下SVM接口介紹

官方文檔
OpenCV3.x與OpenCV2.x中SVM的接口有了很大變化,在接口上使用了虛函數取代以前的定義。
下面介紹幾個常用的接口,及其參數意義。

2.1 初始化函數

定義如下:

CV_WRAP static Ptr<SVM> create();

2.2 參數設置函數

然后是一些設置SVM參數的函數:

CV_WRAP virtual int getType() const = 0;
CV_WRAP virtual void setType(int val) = 0;

CV_WRAP virtual double getGamma() const = 0;
CV_WRAP virtual void setGamma(double val) = 0;

CV_WRAP virtual double getDegree() const = 0;
CV_WRAP virtual void setDegree(double val) = 0;

CV_WRAP virtual double getC() const = 0;
CV_WRAP virtual void setC(double val) = 0;

CV_WRAP virtual double getNu() const = 0;
CV_WRAP virtual void setNu(double val) = 0;

CV_WRAP virtual double getP() const = 0;
CV_WRAP virtual void setP(double val) = 0;

CV_WRAP virtual cv::Mat getClassWeights() const = 0;
CV_WRAP virtual void setClassWeights(const cv::Mat &val) = 0;

CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0;
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;

CV_WRAP virtual int getKernelType() const = 0;
CV_WRAP virtual void setKernel(int kernelType) = 0;

具體的作用可以參考OpenCV文檔,這里只介紹兩個常用的函數:

//設置SVM類型
CV_WRAP virtual int getType() const = 0;

這個函數用於設置SVM類型,OpenCV提供了五種類型:

Types { 
    //C類支持向量分類機。 n類分組 (n≥2),容許用異常值處罰因子C進行不完全分類。
    C_SVC =100, 

    //$v$類支持向量機
    NU_SVC =101, 

    //單分類器,所有的練習數據提取自同一個類里,
    //然后SVM建樹了一個分界線以分別該類在特點空間
    //中所占區域和其它類在特點空間中所占區域。
    ONE_CLASS =102, 

    EPS_SVR =103, 

    NU_SVR =104 
}

一般我們使用SVM進行二分類或者多分類任務,選擇第一種SVM::C_SVC即可。
還有一個函數就是:

CV_WRAP virtual void setKernel(int kernelType) = 0;

這個函數用於設置SVM的核函數類型,我們知道,通過選擇SVM的核函數可以使SVM處理高階、非線性問題。OpenCV提供幾種核函數:

enum KernelTypes {
    /** Returned by SVM::getKernelType in case when custom kernel has been set */
    CUSTOM=-1,
    
    //線性核
    LINEAR=0,
    
    //多項式核
    POLY=1,
    
    //徑向基核(高斯核)
    RBF=2,
    
    //sigmoid核
    SIGMOID=3,
    
    //指數核,與高斯核類似
    CHI2=4,
    
    //直方圖核
    INTER=5
};

一般情況下使用徑向基核可以很好處理大部分情況。

2.3 訓練函數

OpenCV3.x中SVM的提供了訓練函數也與2.x不同,如下:

virtual bool trainAuto( const Ptr<TrainData>& data, int kFold = 10,
                ParamGrid Cgrid = getDefaultGrid(C),
                ParamGrid gammaGrid  = getDefaultGrid(GAMMA),
                ParamGrid pGrid      = getDefaultGrid(P),
                ParamGrid nuGrid     = getDefaultGrid(NU),
                ParamGrid coeffGrid  = getDefaultGrid(COEF),
                ParamGrid degreeGrid = getDefaultGrid(DEGREE),
                bool balanced=false) = 0;

bool trainAuto (InputArray samples, int layout, InputArray responses, 
                int kFold=10, Ptr< ParamGrid > Cgrid=SVM::getDefaultGridPtr(SVM::C), 
                Ptr< ParamGrid > gammaGrid=SVM::getDefaultGridPtr(SVM::GAMMA), 
                Ptr< ParamGrid > pGrid=SVM::getDefaultGridPtr(SVM::P), 
                Ptr< ParamGrid > nuGrid=SVM::getDefaultGridPtr(SVM::NU), 
                Ptr< ParamGrid > coeffGrid=SVM::getDefaultGridPtr(SVM::COEF), 
                Ptr< ParamGrid > degreeGrid=SVM::getDefaultGridPtr(SVM::DEGREE), 
                bool balanced=false)

trainAuto可以在訓練過程中自動優化2.2中的那些參數,而使用train函數時,參數被固定,所以推薦使用trainAuto函數。
在准備訓練數據的時候,有下面幾點需要注意,否則函數會報錯

  1. SVM的訓練函數是ROW_SAMPLE類型的,也就是說,送入SVM訓練的特征需要reshape成一個行向量,所有訓練數據全部保存在一個Mat中,一個訓練樣本就是Mat中的一行,最后還要講這個Mat轉換成CV_32F類型,例如,如果有\(k\)個樣本,每個樣本原本維度是\((h, w)\),則轉換后Mat的維度為\((k, h * w)\)
  2. 對於多分類問題,label矩陣的行數要與樣本數量一致,也就是每個樣本要在label矩陣中有一個對應的標簽,label的列數為1,因為對於一個樣本,SVM輸出一個值,我們在訓練前需要做的就是設計這個值與樣本的對應關系。對於有\(k\)個樣本的情況,label的維度是\((k, 1)\)

2.4 預測函數

函數定義如下:

float predict(cv::InputArrat samples, cv::OutputArray results = noArray(), int flags = 0) const;

其中samples就是需要預測的樣本,這里樣本同樣要轉換成ROW_SAMPLE和CV_32F格式,對於單個測試樣本的情況,預測結果直接通過函數返回值返回,而如果samples中有多個樣本,就需要穿進result參數,預測結果以列向量的方式保存在result數組中。假如有\(k\)個樣本,每個樣本原本的維度為\((h, w)\),則samples的維度為\((k, h * w)\),最終預測結果result維度為\((k, 1)\)

3. 例程

下面上代碼:

/*
* 把圖片從vector<Mat>格式轉換成SVM的RAW_SAMPLE格式
*/
void transform(const vector<Mat> &split, Mat &testData)
{
    for (auto it = split.begin(); it != split.end(); it++){
        Mat tmp;
        resize(*it, tmp, Size(28, 28));
        testData.push_back(tmp.reshape(0, 1));
    }

    testData.convertTo(testData, CV_32F);
}

/*
* 從文件list.txt中讀取測試數據和標簽,輸出SVM的Mat格式
*/
void get_data(string path, Mat &trainData, Mat &trainLabels)
{
    fstream io(path, ios::in);
    if (!io.is_open()){
        cout << "file open error in path : " << path << endl;
        exit(0);
    }

    while (!io.eof())
    {
        string msg;
        io >> msg;

        trainData.push_back(imread(msg, 0).reshape(0, 1));

        io >> msg;
        int idx = msg[0] - '0';
        //trainLabels.push_back(Mat_<int>(1, 1) << idx);  //用這種方式會報錯,原因尚且不明
        trainLabels.push_back(Mat(1, 1, CV_32S, &idx));
    }

    trainData.convertTo(trainData, CV_32F);
}

/*
* 訓練SVM
*/
void svm_train(Ptr<SVM> &model, Mat &trainData, Mat &trainLabels)
{
    model->setType(SVM::C_SVC);     //SVM類型
    model->setKernel(SVM::LINEAR);  //核函數,這里使用線性核

    Ptr<TrainData> tData = TrainData::create(trainData, ROW_SAMPLE, trainLabels);

    cout << "SVM: start train ..." << endl;
    model->trainAuto(tData);
    cout << "SVM: train success ..." << endl;
}

/*
* 利用訓練好的SVM預測
*/
void svm_pridect(Ptr<SVM> &model, Mat test)
{
    Mat result;
    float rst = model->predict(test, result);
    for (auto i = 0; i < result.rows; i++){
        cout << result.at<float>(i, 0);
    }
}

int main(int argc, const char** argv)
{
    fstream io;
    io.open("test_list.txt", ios::in);

    string train_path = "train_list.txt";
        
    vector<Mat> test_set;
    get_test(io, test_set);

    Ptr<SVM> model = SVM::create();
    Mat trainData, trainLabels;
    get_data(train_path, trainData, trainLabels);
    svm_train(model, trainData, trainLabels);

    Mat testData;
    transform(test_set, testData);
    svm_pridect(model, testData);
}

trian_list.txt文件格式是這樣的:

D:\ImgPro\Project\for\char\code\beta00\train_data\0\0-1.jpg		0
D:\ImgPro\Project\for\char\code\beta00\train_data\0\0-2.jpg		0

每行前一段表示訓練圖片地址,最后的數字表示這個圖片對應標簽
test_list.txt中格式與train_list.txt差不多,只是沒有了標簽。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM