注意:數據結構的一致性,在高維度數據一般使用rbf核函數,使用網格搜索思想迭代求出gamma和c。
每行為一個樣本,數據類型都圍繞標黃代碼而定義的。
SVM訓練如下坐標(左邊一列為A類,右邊為B類),然后預測給出的坐標屬於哪一類。
#include<opencv2\opencv.hpp> #include<iostream> #include<opencv2\ml.hpp> //引入機器學習 using namespace cv; using namespace std; using namespace ml; int main() { //*1、類別標簽labelsMat,因為其是短整型,所以labels定義成int類型。最后再轉回char int labels[14] = { 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B' }; Mat labelsMat(14, 1, CV_32S);//短整型 for (int i = 0; i < labelsMat.rows; i++) { labelsMat.at<int>(i, 0) = labels[i]; } //*2、用於訓練的樣本集trainingDataMat int trainingData[14][2] = { { 110, 204 }, { 105, 306 }, { 102, 410 }, { 99, 511 }, { 93, 610 }, { 89, 713 }, { 89, 817 }, { 173, 208 }, { 175, 313 }, { 167, 415 }, { 163, 514 }, { 160, 612 }, { 156, 716 }, { 152, 819 } }; Mat trainingDataMat(14, 2, CV_32F); //float類型 for (int i = 0; i < trainingDataMat.rows; i++) { for (int j = 0; j < trainingDataMat.cols; j++) { trainingDataMat.at<float>(i, j) = trainingData[i][j]; } } //*3、初始化SVM,參數參考 https://blog.csdn.net/qq_27278957/article/details/88736516 Ptr<ml::SVM> svm = ml::SVM::create(); svm->setType(SVM::C_SVC); //svm的類型, svm->setKernel(SVM::LINEAR); //核函數 svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, FLT_EPSILON)); //終止條件 //*4、訓練模型 Ptr<TrainData> tData = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);//訓練樣本的數據類型必須是CV_32F,標簽可以是CV_32S或其他。 svm->train(tData); svm->save("svmData.xml"); //*5、預測 Mat tmp(1, 2, CV_32F); tmp.at<float>(0, 0) = 163; tmp.at<float>(0, 1) = 600; char label = (char)svm->predict(tmp); //ASCII碼轉字符,預測結果為B cout << label << endl; waitKey(0); return 0; }
上圖繪制代碼:
Mat plot(900, 900, CV_8U); vector<Point> myPoint(14);//14個點 for (int i = 0; i < myPoint.size(); i++) { myPoint[i].x = trainingData[i][0]; myPoint[i].y = trainingData[i][1]; circle(plot, myPoint[i], 15, Scalar(255), -1); } namedWindow("坐標點", 0); imshow("坐標點", plot);
【參考】