MATLAB實例:BP神經網絡用於回歸(非線性擬合)任務
作者:凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/
問題描述
給定多元(多維)數據X,有真實結果Y,對這些數據進行擬合(回歸),得到擬合函數的參數,進而得到擬合函數,現在進來一些新樣本,對這些新樣本進行預測出相應地Y值。通常的最小二乘法進行線性擬合並不適用於所有數據,對於大多數數據而言,他們的擬合函數是非線性的,人為構造擬合函數相當困難,沒有一定的經驗積累很難完美的構造出符合條件的擬合函數。因此神經網絡在這里被應用來做回歸(擬合)任務,進一步用來預測。神經網絡是很強大的擬合工具,雖然數學可解釋性差,但擬合效果好,因而得到廣泛應用。BP神經網絡是最基礎的網絡結構,輸入層,隱層,輸出層,三層結構。如下圖所示。
整體的目標函數就是均方誤差
$L=||f(X)-Y||_{2}^{2}$
其中(激活函數可以自行設定)
$f(X)=purelin\left( {{W}_{2}}\cdot \tan sig({{W}_{1}}\cdot X+{{b}_{1}})+{{b}_{2}} \right)$
$N$: 輸入數據的個數
$D$: 輸入數據的維度
${{D}_{1}}$: 隱層節點的個數
$X$: 輸入數據($D$*$N$)
$Y$: 真實輸出(1*$N$)
${{W}_{1}}$: 輸入層到隱層的權值(${{D}_{1}}$*$D$)
${{b}_{1}}$: 隱層的偏置(${{D}_{1}}$*1)
${{W}_{2}}$: 輸入層到隱層的權值(1*${{D}_{1}}$)
${{b}_{2}}$: 隱層的偏置(1*1)
通過給定訓練數據與訓練標簽來訓練網絡的權值與偏置,進一步得到擬合函數$f(X)$。這樣,來了新數據后,直接將新數據X代入函數$f(X)$,即可得到預測的結果。
y = tansig(x) = 2/(1+exp(-2*x))-1;
y = purelin(x) = x;
MATLAB程序
用到的數據為UCI數據庫的housing數據:https://archive.ics.uci.edu/ml/machine-learning-databases/housing/
輸入數據,最后一列是真實的輸出結果,將數據打亂順序,95%的作為訓練集,剩下的作為測試集。這里隱層節點數為20。
BP_kailugaji.m
function errorsum=BP_kailugaji(data_load, NodeNum, ratio) % Author:凱魯嘎吉 https://www.cnblogs.com/kailugaji/ % Input: % data_load: 最后一列真實輸出結果 % NodeNum: 隱層節點個數 % ratio: 訓練集占總體樣本的比率 [Num, ~]=size(data_load); data=data_load(:, 1:end-1); real_label=data_load(:, end); k=rand(1,Num); [~,n]=sort(k); kk=floor(Num*ratio); %找出訓練數據和預測數據 input_train=data(n(1:kk),:)'; output_train=real_label(n(1:kk))'; input_test=data(n(kk+1:Num),:)'; output_test=real_label(n(kk+1:Num))'; %選連樣本輸入輸出數據歸一化 [inputn,inputps]=mapminmax(input_train); [outputn,outputps]=mapminmax(output_train); %% BP網絡訓練 % %初始化網絡結構 net=newff(inputn, outputn, NodeNum); net.trainParam.epochs=100; % 最大迭代次數 net.trainParam.lr=0.01; % 步長 net.trainParam.goal=1e-5; % 迭代終止條件 % net.divideFcn = ''; %網絡訓練 net=train(net,inputn,outputn); W1=net.iw{1, 1}; b1=net.b{1}; W2=net.lw{2, 1}; b2=net.b{2}; fun1=net.layers{1}.transferFcn; fun2=net.layers{2}.transferFcn; %% BP網絡預測 %預測數據歸一化 inputn_test=mapminmax('apply',input_test,inputps); %網絡預測輸出 an=sim(net,inputn_test); %網絡輸出反歸一化 BPoutput=mapminmax('reverse',an,outputps); %% 結果分析 figure(1) plot(BPoutput,'-.or') hold on plot(output_test,'-*b'); legend('預測輸出','期望輸出') xlim([1 (Num-kk)]); title('BP網絡預測輸出','fontsize',12) ylabel('函數輸出','fontsize',12) xlabel('樣本','fontsize',12) saveas(gcf,sprintf('BP網絡預測輸出.jpg'),'bmp'); %預測誤差 error=BPoutput-output_test; errorsum=sum(mse(error)); % 保留參數 save BP_parameter W1 b1 W2 b2 fun1 fun2 net inputps outputps
demo.m
clear;clc;close all data_load=dlmread('housing.data'); NodeNum=20; ratio=0.95; errorsum=BP_kailugaji(data_load, NodeNum, ratio); fprintf('測試集總體均方誤差為:%f\n', errorsum); %% % 驗證原來的或者預測新的數據 num=1; % 驗證第num行數據 load('BP_parameter.mat'); data=data_load(:, 1:end-1); real_label=data_load(:, end); X=data(num, :); X=X'; Y=real_label(num, :); %% BP網絡預測 %預測數據歸一化 X=mapminmax('apply',X,inputps); %網絡預測輸出 Y_pre=sim(net,X); %網絡輸出反歸一化 Y_pre=mapminmax('reverse',Y_pre,outputps); error=Y_pre-Y'; errorsum=sum(mse(error)); fprintf('第%d行數據的均方誤差為:%f\n', num, errorsum);
結果
測試集總體均方誤差為:5.184424 第1行數據的均方誤差為:3.258243
注意:隱層節點個數,激活函數,迭代終止條件等等參數需要根據具體數據進行調整。