三分法求凸函數的極值


作者:jostree 轉載請注明出處 http://www.cnblogs.com/jostree/p/4397990.html

在機器學習中,求凸函數的極值是一個常見的問題,常見的方法如梯度下降法,牛頓法等,今天我們介紹一種三分法來求一個凸函數的極值問題。

對於如下圖的一個凸函數$f(x),x\in [left,right]$,其中lm和rm分別為區間[left,right]的三等分點,我們發現如果f(lm)<f(rm),那么函數值最小的點的橫坐標x一定在[left,rm]之間。如果x在[rm,right]之間,就會出現在rm左右都有比他低的點,這顯然是不可能的。 同理,當f(lm)>f(rm)時,最值的橫坐標x一定在[lm,right]的區間內。

利用這個性質,我們就可以在縮小區間的同時向目標點逼近,從而得到極值。


舉一個例子,題目源自http://hihocoder.com/contest/hiho40/problem/1,如下圖在直角坐標系中有一條拋物線y=ax^2+bx+c和一個點P(x,y),求點P到拋物線的最短距離d,其中-200≤a,b,c,x,y≤200。我們另pivot代表拋物線的對稱抽,可以發現當X>pivot,我們可以取left = pivot,right = inf, 反之left = -inf , right = pivot, 其距離恰好滿足凸形函數。而我們要求的最短距離d,正好就是這個凸形函數的極值。

 

代碼如下:

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <limits.h>
#include <iostream>
#include <cmath>

using namespace std;
double a, b, c, x, y;
const double MAX = 100000;
double dis(double X)
{
    double Y = a*X*X+b*X+c;
    return sqrt((x-X)*(x-X)+(y-Y)*(y-Y));
}

double solve(double l, double r)
{
    double lm = l + (r-l)/3;
    double rm = r - (r-l)/3;
    double lmd = dis(lm);
    double rmd = dis(rm);
    if( fabs(lmd - rmd) < 0.0001 )
    {
        return lmd;
    }
    if( lmd > rmd )
    {
        return solve(lm, r);
    }
    else
    {
        return solve(l, rm);
    }
}

int main(int argc, char *argv[])
{
    while( cin>>a>>b>>c>>x>>y )
    {
        double pivot = -b/(2*a);
        double l = 0, r = 0;
        if( pivot < x )
        {
            l = pivot + 0.0001;
            r = MAX;
        }
        else
        {
            l = -MAX;
            r = pivot - 0.0001;
        }
        double res = solve(l, r);
        printf("%.3lf\n", res);
    }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM