算法原理
to-do
Matlab代碼
clc; clear;
f = @(x) x(1).^2+2*x(1)*x(2)+3*x(2).^2; %待求函數,x1,x2,x3...
% f = @(x) x(1).^2+2*x(2).^2;
paraNum = 2; %函數參數的個數,x1,x2,x3...的個數
x0 = [3,3]; %初始值
tol = 1e-5; %迭代容忍度
flag = inf; %結束條件
error = []; %函數變化
while flag > tol
p = g(f,x0,paraNum); %列向量
f2 = @(a) f(x0-a*p');
buChang = argmin(f2); %求步長,line search:argmin function
x1 = x0-buChang*p';
flag = norm(x1-x0);
error = [error,flag];
x0 = x1;
end
plot(0:length(error)-1,error)
function [f_grad] = g(f,x0,paraNum)
temp = sym('x',[1,paraNum]);
f1=f(temp);
Z = gradient(f1);
f_grad = double(subs(Z,temp,x0));
end
function [x] = argmin(f)
%求步長
t = 0;
options = optimset('Display','off');
[x,~] = fminunc(f,t,options);
end
代碼問題
- Matlab符號運算,耗時
- 最速下降法的步長使用line-search,耗時
代碼改進
clc; clear;
f = @(x) x(1).^2+2*x(1)*x(2)+3*x(2).^2; %待求函數,x1,x2,x3...
% f = @(x) x(1).^2+2*x(2).^2;
paraNum = 2; %函數參數的個數,x1,x2,x3...的個數
x0 = [3,3]; %初始值
tol = 1e-3; %迭代容忍度
flag = inf; %結束條件
error = []; %函數變化
while flag > tol
% for i =1:1
p = g(f,x0,paraNum); %列向量
if norm(p) < tol
buChang = 0;
else
buChang = argmin(f,x0,p,paraNum); %求步長,line search:argmin function
end
x1 = x0-buChang.*p';
flag = norm(x1-x0);
error = [error,flag];
x0 = x1;
end
plot(0:length(error)-1,error)
function [f_grad] = g(f,x0,paraNum)
temp = sym('x',[1,paraNum]);
f1=f(temp);
Z = gradient(f1);
f_grad = double(subs(Z,temp,x0));
end
% function [x] = argmin(f,paraNum)
% %求步長
% t = zeros(1,paraNum);
% options = optimset('Display','off');
% [x,~] = fminunc(f,t,options);
% end
function [x] = argmin(f,x0,p,num)
% 求步長
% for i=1:paraNum
% syms(['x',num2str(i)]);
% end
temp = sym('x',[1,num]);
f1=f(x0 - temp.*p');
for i = 1:num
temp(i) = diff(f1,temp(i));
end
jieGuo = solve(temp);
jieGuo = struct2cell(jieGuo);
x = zeros(1,num);
for i = 1:num
x(i) = double(jieGuo{i});
end
end