此案例用於二分類問題(鼠標左鍵、右鍵點出兩類點,會實時畫出分界線),最終得到一條分界線(直線):f(x)=weights*x+shift
源碼不再貼出,只講解最核心的doTrain()里的內容。參數含義翻譯自ml.hpp文件。
與SVM不同,SVMSGD不需要設置核函數。
【參數】默認值見下述代碼
模型類型:SGD、ASGD(推薦)。隨機梯度下降、平均隨機梯度下降。
邊界類型:HARD_MARGIN、SOFT_MARGIN(推薦),前者用於線性可分,后者用於非線性可分
邊界規范化 lambda:推薦設為0.0001(對於SGD),0.00001(對於ASGD)。越小,異類被拋棄的越少。
步長 gamma_0
步長降低力度 c:推薦設置為1(對於SGD),0.75(對於ASGD)
終止條件:TermCriteria::COUNT、TermCriteria::EPS、TermCriteria::COUNT + TermCriteria::EPS
參數設置函數:
setSvmsgdType()
setMarginType()
setMarginRegularization()
setInitialStepSize()
setStepDecreasingPower()
【使用方式】
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();//創建對象
svmsgd->train(trainData);//訓練
svmsgd->save("MySvmsgd.xml");//保存模型
svmsgd->load("MySvmsgd.xml");//加載模型
svmsgd->predict(samples, responses);//預測,結果保存到responses標簽中
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift) { //*創建SVMSGD對象 cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); //創建SVMSGD對象 //*設置參數,以下全是默認參數 //svmsgd->setSvmsgdType(SVMSGD::ASGD); //模型類型 //svmsgd->setMarginType(SVMSGD::SOFT_MARGIN); //邊界類型 //svmsgd->setMarginRegularization(0.00001); //邊界規范化 //svmsgd->setInitialStepSize(0.05);//步長 //svmsgd->setStepDecreasingPower(0.75); //步長減弱力度 //svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT,1000,1e-3));//終止條件,1000次迭代,0.001每次迭代的精度 //*訓練集 cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses); //*訓練 svmsgd->train(trainData); if (svmsgd->isTrained()) //獲取分界線的系數,f(x)=weights*x+shift { weights = svmsgd->getWeights();//x系數 shift = svmsgd->getShift();//常數項 //*保存模型 svmsgd->save("svmsgd.xml"); //保存訓練好的模型 return true; } return false; }
得到的xml中,weights有兩個數,shift有一個數。
f(x)=weights*x+shift,不可以理解為y=kx+b,應該理解為Ax+By+C=0。weights的兩個數就是A、B,shift是C。
Mat weights(1, 2, CV_32FC1); weights是一個1*2的向量,x也是1*2的向量(xi,xj)也就是(x,y)坐標。
公式寫全了就是:f(x)=weights1*xi+weights2*xj+shift,其實就是weights與x這兩個向量的內積(對應相乘在求和)
f(x)如果等於0,說明點在此直線上,大於0就在線的一邊,小於0在線的另一邊。