基於SGD、ASGD算法的SVM分類器(OpenCV案例源碼train_svmsgd.cpp解讀)


此案例用於二分類問題(鼠標左鍵、右鍵點出兩類點,會實時畫出分界線),最終得到一條分界線(直線):f(x)=weights*x+shift

源碼不再貼出,只講解最核心的doTrain()里的內容。參數含義翻譯自ml.hpp文件。

與SVM不同,SVMSGD不需要設置核函數。

【參數】默認值見下述代碼

模型類型:SGD、ASGD(推薦)。隨機梯度下降、平均隨機梯度下降。
邊界類型:HARD_MARGIN、SOFT_MARGIN(推薦),前者用於線性可分,后者用於非線性可分
邊界規范化 lambda:推薦設為0.0001(對於SGD),0.00001(對於ASGD)。越小,異類被拋棄的越少。
步長 gamma_0
步長降低力度 c:推薦設置為1(對於SGD),0.75(對於ASGD)
終止條件:TermCriteria::COUNT、TermCriteria::EPS、TermCriteria::COUNT + TermCriteria::EPS

參數設置函數:

setSvmsgdType()
setMarginType()
setMarginRegularization()
setInitialStepSize()
setStepDecreasingPower()

【使用方式】

cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();//創建對象
svmsgd->train(trainData);//訓練
svmsgd->save("MySvmsgd.xml");//保存模型
svmsgd->load("MySvmsgd.xml");//加載模型
svmsgd->predict(samples, responses);//預測,結果保存到responses標簽中

bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift)
{
    //*創建SVMSGD對象
    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); //創建SVMSGD對象
    //*設置參數,以下全是默認參數
    //svmsgd->setSvmsgdType(SVMSGD::ASGD); //模型類型
    //svmsgd->setMarginType(SVMSGD::SOFT_MARGIN); //邊界類型
    //svmsgd->setMarginRegularization(0.00001); //邊界規范化
    //svmsgd->setInitialStepSize(0.05);//步長
    //svmsgd->setStepDecreasingPower(0.75); //步長減弱力度
    //svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT,1000,1e-3));//終止條件,1000次迭代,0.001每次迭代的精度
    //*訓練集
    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
    //*訓練
    svmsgd->train(trainData);

    if (svmsgd->isTrained()) //獲取分界線的系數,f(x)=weights*x+shift
    {
        weights = svmsgd->getWeights();//x系數
        shift = svmsgd->getShift();//常數項
        //*保存模型
        svmsgd->save("svmsgd.xml"); //保存訓練好的模型
        
        return true;
    }
    return false;
}

得到的xml中,weights有兩個數,shift有一個數。

 

 f(x)=weights*x+shift,不可以理解為y=kx+b,應該理解為Ax+By+C=0。weights的兩個數就是A、B,shift是C。

Mat weights(1, 2, CV_32FC1); weights是一個1*2的向量,x也是1*2的向量(xi,xj)也就是(x,y)坐標。

公式寫全了就是:f(x)=weights1*xi+weights2*xj+shift,其實就是weights與x這兩個向量的內積(對應相乘在求和)

f(x)如果等於0,說明點在此直線上,大於0就在線的一邊,小於0在線的另一邊。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM