訓練數據來源:http://archive.ics.uci.edu/ml/machine-learning-databases/balance-scale/balance-scale.data
數據簡介
本訓練數據共有625個訓練樣例,每個樣例有4個屬性x1,x2,x3,x4,每個屬性值可以取值{1,2,3,4,5}。
數據集中的每個樣例都有標簽"L","B"或"R"。
我們在這里序號末尾為1的樣本當作測試集,共有63個,其他的作為訓練集,共有562個。
下面我們使用朴素貝葉斯算法來進行訓練。
第一步,實現類的標簽"L","B","R"轉換成數字1,2,3
matlab代碼如下:
clear;
clc;
ex=importdata('balance-scale.data.txt'); %讀入文件
X=ex.data;
m=size(ex.textdata); %數據大小
Y=zeros(m);
for i=1:m
if strcmp(ex.textdata(i),'L')==1
Y(i)=1;
elseif strcmp(ex.textdata(i),'B')==1
Y(i)=2;
else Y(i)=3;
end
end
第二步,參數估計,詳情參見鏈接生成學習算法之朴素貝葉斯算法
matlab代碼如下:(注意運行一下程序之前,先把上一步我們得到的X,Y數據load到內存里)
%朴素貝葉斯算法實現分類問題(三類y=1,y=2,y=3)
%我們把所有數字序號末尾為1的留作測試集,其他未訓練集
m=625; %樣本總數
m1=562; %訓練集樣本數量
m2=63; %測試集樣本數量
%三類樣本數量分別記為count1,count2,count3
count1=0;
count2=0;
count3=0;
%count_1(i,j)表示在第一類(y=1)的情況下,第i個屬性是j的樣本個數
count_1=zeros(4,5);
%count_2(i,j)表示在第二類(y=2)的情況下,第i個屬性是j的樣本個數
count_2=zeros(4,5);
%count_3(i,j)表示在第三類(y=3)的情況下,第i個屬性是j的樣本個數
count_3=zeros(4,5);
ii=1;%用來標識測試集的序號
for i=1:m
if mod(i,10)==1
XX(ii,:)=X(i,:);
YY(ii)=Y(i);
ii=ii+1;
else
x=X(i,:);
if Y(i)==1
count1=count1+1;
for j=1:4 %指示第j個屬性
for k=1:5 %第j個屬性為哪個值
if x(j)==k
count_1(j,k)=count_1(j,k)+1;
break;
end
end
end
elseif Y(i)==2
count2=count2+1;
for j=1:4 %指示第j個屬性
for k=1:5 %第j個屬性為哪個值
if x(j)==k
count_2(j,k)=count_2(j,k)+1;
break;
end
end
end
else count3=count3+1;
for j=1:4 %指示第j個屬性
for k=1:5 %第j個屬性為哪個值
if x(j)==k
count_3(j,k)=count_3(j,k)+1;
break;
end
end
end
end
end
%分別計算三類概率y1=p(y=1)、y2=p(y=2)、y3=p(y=3)的估計值
y1=count1/m1;
y2=count2/m1;
y3=count3/m1;
%y_1(i,j)表示在第一類(y=1)的情況下,第i個屬性取值為j的概率估計值
%y_2(i,j)表示在第二類(y=2)的情況下,第i個屬性取值為j的概率估計值
%y_3(i,j)表示在第三類(y=3)的情況下,第i個屬性取值為j的概率估計值
for i=1:4
for j=1:5
y_1(i,j)=count_1(i,j)/count1;
y_2(i,j)=count_2(i,j)/count2;
y_3(i,j)=count_3(i,j)/count3;
end
end
end
%做預測,p1,p2,p3分別表示輸入值xx為第1,2,3類的概率
cc=0; %用cc表示正確分類的樣本
for i=1:m2
xx=XX(i,:);
yy=YY(i);
p1=y1*y_1(1,xx(1))*y_1(2,xx(2))*y_1(3,xx(3))*y_1(4,xx(4));
p2=y2*y_2(1,xx(1))*y_2(2,xx(2))*y_2(3,xx(3))*y_2(4,xx(4));
p3=y3*y_3(1,xx(1))*y_3(2,xx(2))*y_3(3,xx(3))*y_3(4,xx(4));
%下面分別輸出三類的概率
%ans1=p1/(p1+p2+p3)
%ans2=p2/(p1+p2+p3)
%ans3=p3/(p1+p2+p3)
if p1>p2&&p1>p3
if yy==1 cc=cc+1;
end
end
if p2>p1&&p2>p3
if yy==2 cc=cc+1;
end
end
if p3>p1&&p3>p2
if yy==3 cc=cc+1;
end
end
end
%拿訓練集做測試集,得到的正確率
cc/m2
轉載自https://blog.csdn.net/zhulf0804/article/details/52424809