數據庫:MNIST http://yann.lecun.com/exdb/mnist/
編寫分類器程序,要求:
1)選用課上講述過的分類器
2)使用交叉驗證法生成訓練集及測試集,並以此為基礎評價模型的泛化誤差。
3)總結影響分類器算法結果的因素。
第一步:利用matlab對MNIST數據進行讀取。
-------------------------------------------------------------*****************************************************---------------------------------------------------------------------------
插入:關於測試集和訓練集
訓練集、驗證集和測試集這三個名詞在機器學習領域極其常見
訓練集
作用:估計模型
學習樣本數據集,通過匹配一些參數來建立一個分類器。建立一種分類的方式,主要是用來訓練模型的。
測試集
作用:檢驗最終選擇最優的模型的性能如何
主要是測試訓練好的模型的分辨能力(識別率等)
為何需要划分
簡而言之,為了防止過度擬合。如果我們把所有數據都用來訓練模型的話,建立的模型自然是最契合這些數據的,測試表現也好。但換了其它數據集測試這個模型效果可能就沒那么好了。總而言之訓練集和測試集相同的話,模型評估結果可能比實際要好。
總結
顯然,training set是用來訓練模型或確定模型參數的,如ANN中權值等; validation set是用來做模型選擇(model selection),即做模型的最終優化及確定的,如ANN的結構;而 test set則純粹是為了測試已經訓練好的模型的推廣能力。當然,test set這並不能保證模型的正確性,他只是說相似的數據用此模型會得出相似的結果。但實際應用中,一般只將數據集分成兩類,即training set 和test set,大多數文章並不涉及validation set。
一個典型的划分是訓練集占總樣本的50%,而其它各占25%,三部分都是從樣本中隨機抽取。
----------------------------------------------------------------------------------********************************************************---------------------------------------------------------------------
mnist數據集原網站下載下來的格式無法直接打開,需要對mnist數據集下載之后解壓,
在 Windows 平台下解壓這些文件時,操作系統會自動修改這些文件的文件名,比如會將倒數第二個短線-修改為.,也即 train-images-idx3-ubyte.gz 解壓為train-images.idx3-ubyte(文件類型就自作主張地變成了idx3-ubyte)。.matlab要想對這些數據進行使用,就需要編寫程序對數據進行讀取。由網上查詢matlab讀取mnist數據集的程序如下:
clear variables
close all
clc
train_x_file=char('train-images.idx3-ubyte');%得到vector形式
test_x_file=char('t10k-images.idx3-ubyte');%得到vector形式
train_y_file=char('train-labels.idx1-ubyte');%得到vector形式
test_y_file=char('t10k-labels.idx1-ubyte');%得到vector形式
train_x=decodefile(train_x_file,'image');
test_x=decodefile(test_x_file,'image');
train_labels=decodefile(train_y_file,'label');
test_labels=decodefile(test_y_file,'label');
其中,decodefile.m函數為:
%MNIST源文件下載地址http://yann.lecun.com/exdb/mnist/index.html
%功能:將下載得到的二進制文件轉換為10進制數據,提取像素數據和標簽數據
%適用:僅適用於MNIST數據集,修改后可適用於其他
function output=decodefile(filename,type)
%數據介紹如下,參考網址http://yann.lecun.com/exdb/mnist/index.html
、fio=fopen(filename,'r');%原始文件中數據是以2進制存儲的。
a = fread(fio,'uint8');%以8進制方式讀取源文件。雖然前幾項是32bit的,但是圖像像素數據是8bit的,所以此處用8bit處理。
if strcmp(type,'image')
output=a(17:end);%提取像素數據
else if strcmp(type,'label')
output=a(9:end);
end
end
第二步:分類器分類
按照北京理工大學高琪老師等講授的《人工智能之模式識別》,一般分類器的實現分為三個步驟:分類器模型的建立、模型的訓練、以及模板分類結果的預測。分別利用templateSVM、 fitcecoc、predict函數實現功能。
補充:
關於交叉驗證法,此處采用折十交叉驗證法,對訓練集數據再次進行划分,分為十份,每次留一份作為驗證集,其余九份為訓練集,十次循環訓練模型。最后用訓練后的模型再次對測試集數據進行驗證。
完整程序如下所示:
% svm.m
clear variables
close all
clc
train_x_file=char('train-images.idx3-ubyte');%得到vector形式
test_x_file=char('t10k-images.idx3-ubyte');%得到vector形式
train_y_file=char('train-labels.idx1-ubyte');%得到vector形式
test_y_file=char('t10k-labels.idx1-ubyte');%得到vector形式
train_x=decodefile(train_x_file,'image');
test_x=decodefile(test_x_file,'image');
train_labels=decodefile(train_y_file,'label');
test_labels=decodefile(test_y_file,'label');
% 如果想檢驗轉化是否正確,可執行以下代碼。
train_images=reshape(train_x,28,28,60000);%reshape后的圖像是放倒的
train_images=permute(train_images,[2 1 3]);%對每張圖像進行行列的轉置處理
test_images=reshape(test_x,28,28,10000);%reshape后的圖像是放倒的
test_images=permute(test_images,[2 1 3]);%對每張圖像進行行列的轉置處理
train_labels=train_labels';
test_labels=test_labels';
%選取部分數據進行訓練
train_num = 500;
test_num = 200;
data_train = mat2vector(train_images(:,:,1:train_num),train_num);%圖像轉向量
data_test = mat2vector(test_images(:,:,1:test_num),test_num);%mnist數據集圖像為28*28
train_labels=train_labels(:,1:train_num)';
test_labels=test_labels(:,1:test_num)';
%定義SVM分類器模板,采用線性核函數, 這里選用最簡單的線性模型做演示;
t = templateSVM('KernelFunction','linear');
%交叉驗證法
[m,n] = size(data_train);
indices = crossvalind('Kfold', m, 10);
for i = 1 : 10
% 獲取第i份測試數據的索引邏輯值
test = (indices == i);
% 取反,獲取第i份訓練數據的索引邏輯值
train = ~test;
%1份測試,9份訓練
test_data = data_train(test,:);
test_label = train_labels(test,:);
train_data = data_train(train, :);
train_label = train_labels(train, :);
% 使用數據的代碼
svm_model = fitcecoc(train_data,train_label,'Learners',t);%訓練模型,由於是多分類,不能直接調用fitcsvm
end
%不使用交叉驗證法時訓練模型
% svm_model = fitcecoc(data_train,train_labels(1:train_num),'Learners',t);
%利用測試集數據測試結果
result = predict(svm_model,data_test);
result = result.';
fprintf('預測結果:');
result(1:20)%取20個打印出來對比
fprintf('真實分布:');
test_labels(1:20)'
acc = 0.;
for i = 1:test_num
if result(i)==test_labels(i)
acc = acc+1;
end
end
fprintf('精確度為:%5.2f%%\n',(acc/test_num)*100);
% mat2vector.m
% 輸入:圖片數據(矩陣),樣本個數
% 函數作用:將圖片組轉化為行向量的組合,每個行向量作為一張圖片的特征
% 輸出:樣本數*圖片像素數量大小的矩陣
function [data_]= mat2vector(data,num)
[row,col,~] = size(data);
data_ = zeros(num,row*col);
for page = 1:num
for rows = 1:row
for cols = 1:col
data_(page,((rows-1)*col+cols)) = im2double(data(rows,cols,page));
end
end
end
end
未使用交叉驗證法時:
使用交叉驗證法時: