Understanding the dropout strategy in deep learning


Now that I have some free time, I want to write up how to add dropout to a deep network. Dropout is applied during training to reduce overfitting, i.e., the gap between training accuracy and test accuracy.

First, a quick look at the principle behind dropout:

Theoretical explanations are easy to find online. In short, during training each hidden unit is dropped (set to zero) with some probability, so the network cannot come to rely on any single unit and co-adaptation between units is reduced; at test time all units are kept and their activations are scaled down to compensate.

Here I focus on how to actually implement the technique.

This post draws mainly on another blog: http://www.cnblogs.com/dupuleng/articles/4340293.html
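Before turning to the toolbox, the core mechanism can be sketched in a few lines of Matlab. This is an illustration only, assuming a single sigmoid hidden layer; the variable names, the dummy data, and the inline `sigm` are my own, not toolbox code:

```matlab
% Minimal sketch of the dropout idea (illustration, not toolbox code).
% Training: zero each hidden unit with probability p.
% Testing:  keep all units and scale activations by (1-p) so their
%           expected value matches the training-time statistics.
sigm = @(z) 1 ./ (1 + exp(-z));   % logistic activation
p = 0.5;                          % dropout fraction
x = rand(1, 784);                 % a dummy input row
W = randn(100, 784) * 0.01;       % dummy hidden-layer weights

a = sigm(x * W');                 % hidden activations, 1x100

% -- training pass: apply a fresh Bernoulli mask --
mask = rand(size(a)) > p;         % keep each unit with probability 1-p
a_train = a .* mask;

% -- test pass: no mask, rescale instead --
a_test = a * (1 - p);
```

The same mask must also be applied to the gradients in the backward pass, so dropped units receive no weight update for that mini-batch.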

Below is how to implement dropout with a Matlab deep-learning toolbox.

First, get the toolbox. DeepLearnToolbox is a very useful Matlab deep-learning package; download it from https://github.com/rasmusbergpalm/DeepLearnToolbox

To use it, first add the package to Matlab's search path:

1. Copy the package into Matlab's toolbox directory; on my machine that is D:\program Files\matlab\toolbox\

2. In the Matlab command window, enter:

cd D:\program Files\matlab\toolbox\deepLearnToolbox-master\
addpath(genpath('D:\program Files\matlab\toolbox\deepLearnToolbox-master\'))
savepath   % save the path so you do not have to add it again every session

3. Verify the path was added by entering:

which saesetup

If it succeeded, Matlab prints the location of saesetup.m, e.g. D:\program Files\matlab\toolbox\deepLearnToolbox-master\SAE\saesetup.m

4. Use the toolbox to run a simple demo: pretrain an autoencoder, then train a classifier with and without dropout and compare the results.

load mnist_uint8;                          % MNIST data shipped with the toolbox (uint8)
train_x = double(train_x(1:2000,:)) / 255; % scale pixel values to [0,1]
test_x  = double(test_x(1:1000,:))  / 255;
train_y = double(train_y(1:2000,:));
test_y  = double(test_y(1:1000,:));

%% Experiment 1: without dropout
rand('state',0)                         % fix the RNG seed for reproducibility
sae = saesetup([784 100]);              % single-hidden-layer autoencoder, 784 -> 100
sae.ae{1}.activation_function = 'sigm';
sae.ae{1}.learningRate        = 1;
opts.numepochs = 10;
opts.batchsize = 100;
sae = saetrain(sae, train_x, opts);
visualize(sae.ae{1}.W{1}(:,2:end)');    % plot the learned features (drop the bias column)

nn = nnsetup([784 100 10]); % build an input-hidden-output network; nnsetup
                            % initializes the weights, learning rate, momentum,
                            % activation function, weight penalty, dropout, etc.

nn.W{1} = sae.ae{1}.W{1};  % initialize the hidden layer with the pretrained weights
opts.numepochs = 10;       % number of full sweeps through the data
opts.batchsize = 100;      % take a mean gradient step over this many samples
[nn, ~] = nntrain(nn, train_x, train_y, opts);
[er, ~] = nntest(nn, test_x, test_y);
fprintf('testing error rate is: %f\n', er);

%% Experiment 2: with dropout
rand('state',0)
sae = saesetup([784 100]);
sae.ae{1}.activation_function = 'sigm';
sae.ae{1}.learningRate        = 1;

opts.numepochs = 10;
opts.batchsize = 100;
sae = saetrain(sae, train_x, opts);
figure;
visualize(sae.ae{1}.W{1}(:,2:end)');

nn = nnsetup([784 100 10]); % same network as in experiment 1
nn.dropoutFraction = 0.5;   % drop half of the hidden units on each training pass
nn.W{1} = sae.ae{1}.W{1};
opts.numepochs = 10;        % number of full sweeps through the data
opts.batchsize = 100;       % take a mean gradient step over this many samples
[nn, L] = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
fprintf('testing error rate is: %f\n', er);
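Setting nn.dropoutFraction is the only change needed because the toolbox handles the masking internally during its feed-forward pass. The sketch below is a paraphrase from memory of how that logic looks; consult NN/nnff.m in the toolbox for the authoritative code:

```matlab
% Paraphrased sketch of how the toolbox applies nn.dropoutFraction to the
% activations nn.a{i} of hidden layer i (see NN/nnff.m for the real code):
if nn.dropoutFraction > 0
    if nn.testing
        % test time: keep every unit, scale activations down to compensate
        nn.a{i} = nn.a{i} * (1 - nn.dropoutFraction);
    else
        % training: draw a fresh Bernoulli mask each pass and store it,
        % so the backward pass can apply the same mask to the gradients
        nn.dropOutMask{i} = rand(size(nn.a{i})) > nn.dropoutFraction;
        nn.a{i} = nn.a{i} .* nn.dropOutMask{i};
    end
end
```

With identical seeds, pretraining, and training options, the two experiments differ only in this masking step, so any difference in the reported test error rates can be attributed to dropout.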
