Commonly used classifiers in MATLAB include the random forest classifier, the support vector machine (SVM), the K-nearest-neighbor classifier, naive Bayes, ensemble learning methods, and the discriminant analysis classifier. The relevant MATLAB functions are used as follows.
First, a unified description of the variables used throughout:
train_data — training samples; each row of the matrix is one sample, each column one feature
train_label — training labels, a column vector
test_data — test samples; each row of the matrix is one sample, each column one feature
test_label — test labels, a column vector
① Random forest classifier (Random Forest)
TB=TreeBagger(nTree,train_data,train_label);
predict_label=predict(TB,test_data);
② Support vector machine (Support Vector Machine, SVM)
SVMmodel=svmtrain(train_data,train_label);
predict_label=svmclassify(SVMmodel,test_data);
③ K-nearest-neighbor classifier (KNN)
KNNmodel=ClassificationKNN.fit(train_data,train_label,'NumNeighbors',1);
predict_label=predict(KNNmodel,test_data);
④ Naive Bayes
Bayesmodel=NaiveBayes.fit(train_data,train_label);
predict_label=predict(Bayesmodel,test_data);
⑤ Ensemble learning (Ensembles for Boosting)
Bmodel=fitensemble(train_data,train_label,'AdaBoostM1',100,'tree','type','classification');
predict_label=predict(Bmodel,test_data);
⑥ Discriminant analysis classifier (Discriminant Analysis Classifier)
DACmodel=ClassificationDiscriminant.fit(train_data,train_label);
predict_label=predict(DACmodel,test_data);
A concrete usage example follows. The practice data is the Fisher iris dataset (http://en.wikipedia.org/wiki/Iris_flower_data_set). Briefly: the dataset contains flowers of three species, which differ in sepal length, sepal width, petal length, and petal width; the task is to predict the species from these four features.
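The snippets below assume train_data/train_label and test_data/test_label already exist. One way to build them is a sketch like the following, using MATLAB's bundled copy of the same iris data (fisheriris); the odd/even row split is an arbitrary choice for illustration:

```matlab
load fisheriris                  % meas: 150x4 feature matrix, species: 150x1 cell array of names
labels = grp2idx(species);       % map the three species names to numeric labels 1..3
train_data  = meas(1:2:end, :);  % odd rows for training
train_label = labels(1:2:end);
test_data   = meas(2:2:end, :);  % even rows for testing
test_label  = labels(2:2:end);
```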
%% Random forest classifier (Random Forest)
nTree=10;
B=TreeBagger(nTree,train_data,train_label,'Method', 'classification');
predictl=predict(B,test_data);               % TreeBagger returns a cell array of char labels
predict_label=str2double(predictl);          % convert back to a numeric column vector
Forest_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% Support vector machine (SVM)
% Note: svmtrain handles only two-class problems (and was removed in newer
% releases; use fitcsvm, or fitcecoc for multiclass), so it cannot be applied
% to the three-class iris data directly. svmclassify returns labels of the
% same type as train_label, so no cell-to-number conversion is needed.
% SVMStruct = svmtrain(train_data, train_label);
% predict_label=svmclassify(SVMStruct,test_data);
% SVM_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% K-nearest-neighbor classifier (KNN)
% mdl = ClassificationKNN.fit(train_data,train_label,'NumNeighbors',1);
% predict_label=predict(mdl, test_data);     % numeric labels, same type as train_label
% KNN_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% Naive Bayes
% nb = NaiveBayes.fit(train_data, train_label);
% predict_label=predict(nb, test_data);
% Bayes_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% Ensemble learning (Ensembles for Boosting, Bagging, or Random Subspace)
% ens = fitensemble(train_data,train_label,'AdaBoostM1',100,'Tree','Type','classification');
% predict_label=predict(ens,test_data);      % numeric labels, same type as train_label,
%                                            % so no str2num/cell2mat conversion is needed
% EB_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
%% Discriminant analysis classifier
% obj = ClassificationDiscriminant.fit(train_data, train_label);
% predict_label=predict(obj,test_data);      % numeric labels, same type as train_label,
%                                            % so no str2num/cell2mat conversion is needed
% DAC_accuracy=length(find(predict_label == test_label))/length(test_label)*100;
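Beyond a single accuracy number, a confusion matrix shows which species are mistaken for which. A sketch, assuming predict_label and test_label are numeric vectors from any of the classifiers above (confusionmat is part of the Statistics and Machine Learning Toolbox):

```matlab
% Rows are true classes, columns are predicted classes.
C = confusionmat(test_label, predict_label);
disp(C)
accuracy = sum(diag(C)) / sum(C(:)) * 100;   % same value as the *_accuracy formulas above
```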
%% Practice: a small two-class toy example
% meas=[0 0;2 0;2 2;0 2;4 4;6 4;6 6;4 6];
% [N n]=size(meas);
% species={'1';'1';'1';'1';'-1';'-1';'-1';'-1'};
% ObjBayes=NaiveBayes.fit(meas,species);
% x=[3 3;5 5];
% result=ObjBayes.predict(x);
Reference: https://blog.csdn.net/jisuanjiguoba/java/article/details/80004568