三分法求凸函数的极值

时间:2021-10-06 19:06:47

作者:jostree 转载请注明出处 http://www.cnblogs.com/jostree/p/4397990.html

在机器学习中,求凸函数的极值是一个常见的问题,常见的方法如梯度下降法,牛顿法等,今天我们介绍一种三分法来求一个凸函数的极值问题。

对于如下图的一个凸函数$f(x),x\in [left,right]$,其中lm和rm分别为区间[left,right]的三等分点,我们发现如果f(lm)<f(rm),那么函数值最小的点的横坐标x一定在[left,rm]之间。如果x在[rm,right]之间,就会出现在rm左右都有比他低的点,这显然是不可能的。 同理,当f(lm)>f(rm)时,最值的横坐标x一定在[lm,right]的区间内。

利用这个性质,我们就可以在缩小区间的同时向目标点逼近,从而得到极值。

三分法求凸函数的极值


举一个例子,题目源自http://hihocoder.com/contest/hiho40/problem/1,如下图在直角坐标系中有一条抛物线y=ax^2+bx+c和一个点P(x,y),求点P到抛物线的最短距离d,其中-200≤a,b,c,x,y≤200。我们另pivot代表抛物线的对称抽,可以发现当X>pivot,我们可以取left = pivot,right = inf, 反之left = -inf , right = pivot, 其距离恰好满足凸形函数。而我们要求的最短距离d,正好就是这个凸形函数的极值。

三分法求凸函数的极值

 

代码如下:

#include <stdlib.h>
#include
<stdio.h>
#include
<string.h>
#include
<limits.h>
#include
<iostream>
#include
<cmath>

using namespace std;
double a, b, c, x, y;
const double MAX = 100000;
double dis(double X)
{
double Y = a*X*X+b*X+c;
return sqrt((x-X)*(x-X)+(y-Y)*(y-Y));
}

double solve(double l, double r)
{
double lm = l + (r-l)/3;
double rm = r - (r-l)/3;
double lmd = dis(lm);
double rmd = dis(rm);
if( fabs(lmd - rmd) < 0.0001 )
{
return lmd;
}
if( lmd > rmd )
{
return solve(lm, r);
}
else
{
return solve(l, rm);
}
}

int main(int argc, char *argv[])
{
while( cin>>a>>b>>c>>x>>y )
{
double pivot = -b/(2*a);
double l = 0, r = 0;
if( pivot < x )
{
l
= pivot + 0.0001;
r
= MAX;
}
else
{
l
= -MAX;
r
= pivot - 0.0001;
}
double res = solve(l, r);
printf(
"%.3lf\n", res);
}
}