SVM的使用train()


注意:數據結構的一致性,在高維度數據一般使用rbf核函數,使用網格搜索思想迭代求出gamma和c。

每行為一個樣本,數據類型都圍繞標黃代碼而定義的。

SVM訓練如下坐標(左邊一列為A類,右邊為B類),然后預測給出的坐標屬於哪一類。

#include<opencv2\opencv.hpp>
#include<iostream>
#include<opencv2\ml.hpp> //引入機器學習
using namespace cv;
using namespace std;
using namespace ml;

int main()
{
    //*1、類別標簽labelsMat,因為其是短整型,所以labels定義成int類型。最后再轉回char
    int labels[14] = { 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B' };
    Mat labelsMat(14, 1, CV_32S);//短整型
    for (int i = 0; i < labelsMat.rows; i++)
    {
        labelsMat.at<int>(i, 0) = labels[i];
    }
    //*2、用於訓練的樣本集trainingDataMat
    int trainingData[14][2] = { { 110, 204 }, { 105, 306 }, { 102, 410 }, { 99, 511 }, { 93, 610 }, { 89, 713 }, { 89, 817 },
    { 173, 208 }, { 175, 313 }, { 167, 415 }, { 163, 514 }, { 160, 612 }, { 156, 716 }, { 152, 819 } };
    Mat trainingDataMat(14, 2, CV_32F); //float類型
    for (int i = 0; i < trainingDataMat.rows; i++)
    {
        for (int j = 0; j < trainingDataMat.cols; j++)
        {
            trainingDataMat.at<float>(i, j) = trainingData[i][j];
        }
    }
    //*3、初始化SVM,參數參考 https://blog.csdn.net/qq_27278957/article/details/88736516
    Ptr<ml::SVM> svm = ml::SVM::create();
    svm->setType(SVM::C_SVC); //svm的類型,
    svm->setKernel(SVM::LINEAR); //核函數
    svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, FLT_EPSILON)); //終止條件
    //*4、訓練模型
    Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);//訓練樣本的數據類型必須是CV_32F,標簽可以是CV_32S或其他。
    svm->train(tData);
    svm->save("svmData.xml");
    //*5、預測
    Mat tmp(1, 2, CV_32F);
    tmp.at<float>(0, 0) = 163;
    tmp.at<float>(0, 1) = 600;

    char label = (char)svm->predict(tmp); //ASCII碼轉字符,預測結果為B
    cout << label << endl;

    waitKey(0);
    return 0;
}

上圖繪制代碼:

Mat plot(900, 900, CV_8U);
vector<Point> myPoint(14);//14個點
for (int i = 0; i < myPoint.size(); i++)
{
    myPoint[i].x = trainingData[i][0];
    myPoint[i].y = trainingData[i][1];
    circle(plot, myPoint[i], 15, Scalar(255), -1);
}
namedWindow("坐標點", 0);
imshow("坐標點", plot);

 【參考】

https://blog.csdn.net/bigFatCat_Tom/article/details/95201903?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM