RBF網絡的matlab實現


一、用工具箱實現函數擬合

參考:http://blog.csdn.net/zb1165048017/article/details/49407075

(1)newrb()

該函數可以用來設計一個近似徑向基網絡(approximate RBF)。調用格式為:

[net,tr]=newrb(P,T,GOAL,SPREAD,MN,DF)

其中P為Q組輸入向量組成的R*Q位矩陣,T為Q組目標分類向量組成的S*Q維矩陣。GOAL為均方誤差目標(Mean Squard Error Goal),默認為0.0;SPREAD為徑向基函數的擴展速度,默認為1;MN為神經元的最大數目,默認為Q;DF維兩次顯示之間所添加的神經元數目,默認為25;ner為返回值,一個RBF網絡,tr為返回值,訓練記錄。

用newrb()創建RBF網絡是一個不斷嘗試的過程(從程序的運行可以看出來),在創建過程中,需要不斷增加中間層神經元的和個數,知道網絡的輸出誤差滿足預先設定的值為止。

(2)newrbe()

該函數用於設計一個精確徑向基網絡(exact RBF),調用格式為:

net=newrbe(P,T,SPREAD)

其中P為Q組輸入向量組成的R*Q維矩陣,T為Q組目標分類向量組成的S*Q維矩陣;SPREAD為徑向基函數的擴展速度,默認為1

和newrb()不同的是,newrbe()能夠基於設計向量快速,無誤差地設計一個徑向基網絡。

(3)radbas()

該函數為徑向基傳遞函數,調用格式為

A=radbas(N)

info=radbas(code)

其中N為輸入(列)向量的S*Q維矩陣,A為函數返回矩陣,與N一一對應,即N的每個元素通過徑向基函數得到A;info=radbas(code)表示根據code值的不同返回有關函數的不同信息。包括

derive——返回導函數的名稱

name——返回函數全稱

output——返回輸入范圍

active——返回可用輸入范圍

使用exact徑向基網絡來實現非線性的函數回歸:

%%清空環境變量  
clc
clear
%%產生輸入輸出數據  
%設置步長  
interval=0.01;
%產生x1,x2  
x1=-1.5:interval:1.5;
x2=-1.5:interval:1.5;
%按照函數先求的響應的函數值,作為網絡的輸出  
F=20+x1.^2-10*cos(2*pi*x1)+x2.^2-10*cos(2*pi*x2);
%%網絡建立和訓練  
%網絡建立,輸入為[x1;x2],輸出為F。spread使用默認  
net=newrbe([x1;x2],F);
%%網絡的效果驗證  
%將原數據回帶,測試網絡效果  
ty=sim(net,[x1;x2]);
%%使用圖像來看網絡對非線性函數的擬合效果  
figure
plot3(x1,x2,F,'rd');
hold on;
plot3(x1,x2,ty,'b-.');
view(113,36);
title('可視化的方法觀察嚴格的RBF神經網絡的擬合效果');
xlabel('x1')
ylabel('x2')
zlabel('F')
grid on

  結果:

二、自編函數實現擬合

clear;
%X=1:100;
X=[-4*pi:0.07*pi:8*pi];  
P=length(X);
Y=[];
M=10; 
centers=[];
deltas=[]; 
weights=[];
set = {}; 
gap=0.1; 
%**************************************************************************  
%構造訓練樣本X,Y  
X=[-4*pi:0.07*pi:8*pi];  
for i=1:P  
    Y(i)=sin(X(i));
end  
%**************************************************************************  
for i=1:M                          %先隨意初始化M個中心點
    centers(i)= X( i*floor( P/10 ) );  
end  
done=0;  
while(~done)  
    for i=1:M  
       set{i}=[];  
    end  
    for i=1:P  
        distance=100;
        for j=1:M  
            curr=abs(X(i)-centers(j));  
            if curr<distance  
                sets=j;  
                distance=curr;  
            end  
        end  
        set{sets}=[set{sets},X(i)];        %聚類,找出M個中心點,並且樣本分布在這十個點周圍
    end 
    for i=1:M  
        new_centers(i)=sum(set{i})/length(set{i}); %重新計算中心點:M個類里每個類的中心點  
    end  
    done=0;  
     for i=1:M  
          sum1(i)=abs(centers(i)-new_centers(i));
     end  
     if sum(sum1)>gap  
            done=0;      %不斷循環,直到找到最佳的中心點;
            centers=new_centers;  
     else  
            done=1;  
     end     
end

for i=1:M
    curr=abs( centers-centers(i));  
    [curr_2,b]=min(curr);  
    curr(b)=100;  
    curr_2=min(curr);  
    deltas(i)=1*curr_2;  
end
%{
for i=1:M
    sum=0;
    num=length(set{i});
    for j=1:num
        sum=sum+(set{i}(j)-centers(i))^2;
    end
    deltas(i)=(sum)^0.5/num;
end 
%}
for i=1:P  
    for j=1:M  
        curr=abs(X(i)-centers(j));  
        K(i,j)=exp( -curr^2/(2*deltas(j)^2) );  %隱含層的輸出
    end  
end  
%計算權值矩陣  
weights=inv(K'*K)*K'*Y';  
%**************************************************************************  
%測試計算出函數的情況  
x_test=[-4*pi:0.07*pi:8*pi];    
for i=1:length(x_test)  
    sum=0;  
    for j=1:M  
        curr=weights(j)*exp(-abs(x_test(i)-centers(j))^2/(2*deltas(j)^2));  
        sum=sum+curr;  
    end  
    y_test(i)=sum;  
end  
figure(1)  
scatter(X,Y,'k+');  
hold on;  
plot(x_test,y_test,'r.-')

  

結果:

三、工具箱函數的RBF分類

train_data=LowDimFaces(1:10,:);  %train_data是一個10*20維的矩陣,其中行表示樣本數,列數表示特征個數
train_label=[ones(1,5),zeros(1,5)]; %行向量

display('讀入測試數據...');

test_data=LowDimFaces(201:210,:);
test_label=[ones(1,5),zeros(1,5)];

[train_data,minX,maxX] = premnmx(train_data);
test_data = tramnmx(test_data,minX,maxX) ;

%網絡建立,輸入為[x1;x2],輸出為F。spread使用默認  
net=newrbe(train_data',train_label);
%%網絡的效果驗證  
%將原數據回帶,測試網絡效果  
ty=sim(net,test_data');
%%使用圖像來看網絡對非線性函數的擬合效果  
Y=[];
hitnum=0;
for i=1:10
    if ty(i)>0.5
        Y(i)=1;
    else
        Y(i)=0;
    end
    if Y(i)==test_label(i)
        hitnum=hitnum+1;
    end
end
fprintf('訓練集中結果的正確率是%f%%\n',100*hitnum/10);

  

 

四、自編函數實現RBF分類

clear;
clc;
M=10; 
centers=[;];
deltas=[]; 
weights=[];
set = {}; 
gap=0.1; 
%**************************************************************************  
XA=ones(1,500);
YA=ones(1,500);  %初始化A類的輸入數據
XB=ones(1,500);
YB=ones(1,500);  %初始化B類的輸入數據
for i=1:500
    XA(i)=cos(2*pi*(i+8)/25-0.25*pi)*(i+8)/25;
    YA(i)=sin(2*pi*(i+8)/25-0.25*pi)*(i+8)/25-0.25;
    XB(i)=sin(2*pi*(i+8)/25+0.25*pi)*(i+8)/-25;
    YB(i)=cos(2*pi*(i+8)/25+0.25*pi)*(i+8)/25-0.25;
end
scatter(XA,YA,20,'b');
hold on;
scatter(XB,YB,20,'k');
hold off;
X1=cat(1,XA,YA);
X2=cat(1,XB,YB);
X=cat(2,X1,X2);  %得到訓練數據集X,Y
Y=zeros(1,1000);
Y(1,1:500)=1;
k=rand(1,1000);
[m,n]=sort(k);
X=X(:,n(1:1000));
Y=Y(:,n(1:1000));
%**************************************************************************  
[X,minX,maxX] = premnmx(X);
P=length(X);
for i=1:M                          %先隨意初始化M個中心點
    centers(:,i)= X(:,i*floor( P/10 ) );  
end  
done=0;  
while(~done)  
    for i=1:M  
       set{i}=[;];  
    end  
    for i=1:P  
        distance=100;
        for j=1:M  
            curr=norm(X(:,i)-centers(:,j));  
            if curr<distance  
                sets=j;  
                distance=curr;  
            end  
        end  
        set{sets}=[set{sets},X(:,i)];        %聚類,找出M個中心點,並且樣本分布在這十個點周圍
    end 
    for i=1:M  
        new_centers(:,i)=sum(set{i}')'/length(set{i}); %重新計算中心點:M個類里每個類的中心點  
    end  
    done=0;  
     for i=1:M  
          sum1(i)=norm(centers(:,i)-new_centers(:,i));
     end  
     if sum(sum1)>gap  
            done=0;      %不斷循環,直到找到最佳的中心點;
            centers=new_centers;  
     else  
            done=1;  
     end     
end

for i=1:M
    curr=[;];
    curr=abs( bsxfun(@minus,centers,centers(:,i)));
    k=100;
    m=norm(curr(:,j));
    for j=1:M
        if m<k && m~=0
            k=m;
        end
    end
    deltas(i)=k; 
end


for i=1:P  
    for j=1:M  
        curr=norm(X(:,i)-centers(:,j));  
        K(i,j)=exp( -curr^2/(2*deltas(j)^2) );  %隱含層的輸出
    end  
end  
%計算權值矩陣  
weights=inv(K'*K)*K'*Y';  
%**************************************************************************  
%測試計算出函數的情況  
x_test=X;
for i=1:length(x_test)  
    sum=0;  
    for j=1:M  
        curr=weights(j)*exp(-norm(x_test(:,i)-centers(:,j))^2/(2*deltas(j)^2));  
        sum=sum+curr;  
    end  
    y_test(i)=sum;  
end 
y_test(find(y_test<0.5))=0;
y_test(find(y_test>=0.5))=1;
count=0;
for j=1:length(y_test)
    if y_test(j)==Y(j)
        count=count+1;
    end
end
fprintf('分類正確率為:%.2f%%',100*count/length(y_test));
    

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM