無論是深度學習還是機器學習,大多情況下訓練中都會遇到這幾個參數,今天依據我自己的理解具體的總結一下,可能會存在錯誤,還請指正.
learning_rate , weight_decay , momentum這三個參數的含義. 並附上demo.
我們會使用一個例子來說明一下:
比如我們有一堆數據
,我們只知道這對數據是從一個
黑盒中得到的,我們現在要尋找到那個具體的函數f(x),我們定義為目標函數T.


我們現在假定有存在這個函數並且這個函數為:

我們現在要使用這對數據來訓練目標函數. 我們可以設想如果存在一個這個函數,必定滿足{x,y}所有的關系,也就是說:

那么最理想的情況下 :
,那么我們不妨定義這樣一個優化目標函數:


對於這堆數據,我們認為當Loss(W)對於所有的pair{x,y}都滿足 Loss(W)趨近於或者等於0時,我們認為我們找到這個理想的目標函數T. 也就是此時
.

以上,我們發現尋找的目標函數的問題,已經成功的轉移為求解:

也就是Loss 越小, f(x)越接近我們尋找的目標函數T.
那么說了這么多,這個和我們說的學習率learning_rate有什么關系呢?
既然我們知道了我們當前的f(x)和目標函數的T的誤差,那么我們可以將這個誤差轉移到每一個參數上,也就是變成每一個參數w和目標函數T的參數w_t的誤差. 然后我們就以一定的幅度stride來縮小和真實值的距離,我們稱這個stride為學習率learning_rate 而且我們就是這么做的.
我們用公式表述就是:
我們的誤差(損失)Loss:

我們這一個凸函數. 我們先對這個函數進行各個分量求偏導.

對於w0的偏導數:

那么對於分量w0承擔的誤差為:

那么我們需要使我們當前的w0更加接近目標函數的T的w0_t參數.我們需要做運算:

來更新wo的值. 同理其他參數w,而這個學習率就是來控制我們每次靠近真實值的幅度,為什么要這么做呢?
因為我們表述的誤差只是一種空間表述形式我們可以使用均方差也可以使用絕對值,還可以使用對數,以及交叉熵等等,所以只能大致的反映,並不精確,就想我們問路一樣,別人告訴我們直走五分鍾,有的人走的快,有的人走的慢,所以如果走的快的話,當再次問路的時候,就會發現走多了,而折回來,這就是我們訓練過程中的loss曲線震盪嚴重的原因之一. 所以學習率要設置在合理的大小.
好了說了這么多,這是學習率. 那么什么是權重衰減weight_decay呢? 有什么作用呢?
我們接着看上面的這個Loss(w),我們發現如果參數過多的話,對於高位的w3,我們對其求偏導:

我們發現w3開始大於1的時候,w3會調節的很快,幅度很大,從而使得特征x3變為異常敏感.從而出現過擬合(overfitting).
這個時候,我們需要約束一下w2,w3等高階參數的大小,於是我們對Loss增加一個懲罰項,使得Loss的正反方向,不應該只由f(x) -y 決定,而還應該加上一個
;於是Loss變成了:


我們繼續對Loss求解偏導數:
對wo求偏導:

對w3求偏導:

我們發現當x3值過大時,會改變Loss的導數的方向.而來抑制w2,w3等高階函數的繼續增長. 然而這樣抑制並不是很靈活,所以我們在前面加入一個系數
,這個系數在數學上稱之為拉格朗日乘子系數,也就是我們用到的weight_decay. 這樣我們可以通過調節weight_decay系數,來調節w3,w2等高階的增長程度。加入weight_decay后的公式:


從公式可以看出 ,weight_decay越大,抑制越大,w2,w3等系數越小,weight_decay越小,抑制越小,w2,w3等系數越大
那么沖量momentum又是啥?
我們在使用梯度下降法,來調整w時公式是這樣的:

我們每一次都是計算當前的梯度:

這樣會發現對於那些梯度比較小的地方,參數w更新的幅度比較小,訓練變得漫長,或者收斂慢.有時候遇到非最優的凸點,會出現沖不出去的現象.
而沖量加進來是一種快速效應.借助上一次的勢能來和當前的梯度來調節當前的參數w.
公式表達為:

這樣可以有效的避免掉入凸點無法沖出來,而且收斂速度也快很多.
附上demo: 使用mxnet編碼.

1 // 2 // Created by xijun1 on 2017/12/14. 3 // 4 5 #include <iostream> 6 #include <vector> 7 #include <string> 8 #include <mxnet/mxnet-cpp/MxNetCpp.h> 9 #include <mxnet/mxnet-cpp/op.h> 10 11 namespace mlp{ 12 class MlpNet{ 13 public: 14 static mx_float OutputAccuracy(mx_float* pred, mx_float* target) { 15 int right = 0; 16 for (int i = 0; i < 128; ++i) { 17 float mx_p = pred[i * 10 + 0]; 18 float p_y = 0; 19 for (int j = 0; j < 10; ++j) { 20 if (pred[i * 10 + j] > mx_p) { 21 mx_p = pred[i * 10 + j]; 22 p_y = j; 23 } 24 } 25 if (p_y == target[i]) right++; 26 } 27 return right / 128.0; 28 } 29 30 static void net(){ 31 using mxnet::cpp::Symbol; 32 using mxnet::cpp::NDArray; 33 34 Symbol x = Symbol::Variable("X"); 35 Symbol y = Symbol::Variable("label"); 36 37 std::vector<std::int32_t> shapes({512 , 10}); 38 //定義一個兩層的網絡. wx + b 39 Symbol weight_0 = Symbol::Variable("weight_0"); 40 Symbol biases_0 = Symbol::Variable("biases_0"); 41 42 Symbol fc_0 = mxnet::cpp::FullyConnected("fc_0",x,weight_0,biases_0 43 ,512); 44 45 Symbol output_0 = mxnet::cpp::LeakyReLU("relu_0",fc_0,mxnet::cpp::LeakyReLUActType::kLeaky); 46 47 Symbol weight_1 = Symbol::Variable("weight_1"); 48 Symbol biases_1 = Symbol::Variable("biases_1"); 49 Symbol fc_1 = mxnet::cpp::FullyConnected("fc_1",output_0,weight_1,biases_1,10); 50 Symbol output_1 = mxnet::cpp::LeakyReLU("relu_1",fc_1,mxnet::cpp::LeakyReLUActType::kLeaky); 51 Symbol pred = mxnet::cpp::SoftmaxOutput("softmax",output_1,y); //目標函數,loss函數 52 mxnet::cpp::Context ctx = mxnet::cpp::Context::cpu( 0); 53 54 //定義輸入數據 55 std::shared_ptr< mx_float > aptr_x(new mx_float[128*28] , [](mx_float* aptr_x){ delete [] aptr_x ;}); 56 std::shared_ptr< mx_float > aptr_y(new mx_float[128] , [](mx_float * aptr_y){ delete [] aptr_y ;}); 57 58 //初始化數據 59 for(int i=0 ; i<128 ; i++){ 60 for(int j=0;j<28 ; j++){ 61 //定義x 62 aptr_x.get()[i*28+j]= i % 10 +0.1f; 63 } 64 65 //定義y 66 aptr_y.get()[i]= i % 10; 67 } 68 std::map<std::string, mxnet::cpp::NDArray> args_map; 69 //導入數據 70 NDArray arr_x(mxnet::cpp::Shape(128,28),ctx, false); 71 NDArray arr_y(mxnet::cpp::Shape( 128 ),ctx,false); 72 //將數據轉換到NDArray中 73 arr_x.SyncCopyFromCPU(aptr_x.get(),128*28); 74 arr_x.WaitToRead(); 75 76 arr_y.SyncCopyFromCPU(aptr_y.get(),128); 77 arr_y.WaitToRead(); 78 79 args_map["X"]=arr_x.Slice(0,128).Copy(ctx) ; 80 args_map["label"]=arr_y.Slice(0,128).Copy(ctx); 81 NDArray::WaitAll(); 82 //綁定網絡 83 mxnet::cpp::Executor *executor = pred.SimpleBind(ctx,args_map); 84 //選擇優化器 85 mxnet::cpp::Optimizer *opt = mxnet::cpp::OptimizerRegistry::Find("sgd"); 86 mx_float learning_rate = 0.0001; //學習率 87 mx_float weight_decay = 1e-4; //權重 88 opt->SetParam("momentum", 0.9) 89 ->SetParam("lr", learning_rate) 90 ->SetParam("wd", weight_decay); 91 //定義各個層參數的數組 92 NDArray arr_w_0(mxnet::cpp::Shape(512,28),ctx, false); 93 NDArray arr_b_0(mxnet::cpp::Shape( 512 ),ctx,false); 94 NDArray arr_w_1(mxnet::cpp::Shape(10 , 512 ) , ctx , false); 95 NDArray arr_b_1(mxnet::cpp::Shape( 10 ) , ctx , false); 96 97 //初始化權重參數 98 arr_w_0 = 0.01f; 99 arr_b_1 = 0.01f; 100 arr_w_1 = 0.01f; 101 arr_b_1 = 0.01f; 102 103 //初始化參數 104 executor->arg_dict()["weight_0"]=arr_w_0; 105 executor->arg_dict()["biases_0"]=arr_b_0; 106 executor->arg_dict()["weight_1"]=arr_w_1; 107 executor->arg_dict()["biases_1"]=arr_b_1; 108 109 mxnet::cpp::NDArray::WaitAll(); 110 //訓練 111 std::cout<<" Training "<<std::endl; 112 113 int max_iters = 20000; //最大迭代次數 114 //獲取訓練網絡的參數列表 115 std::vector<std::string> args_name = pred.ListArguments(); 116 for (int iter = 0; iter < max_iters ; ++iter) { 117 executor->Forward(true); 118 executor->Backward(); 119 120 if(iter % 100 == 0){ 121 std::vector<NDArray> & out = executor->outputs; 122 std::shared_ptr<mx_float> tp_x( new mx_float[128*28] , 123 [](mx_float * tp_x){ delete [] tp_x ;}); 124 out[0].SyncCopyToCPU(tp_x.get(),128*10); 125 NDArray::WaitAll(); 126 std::cout<<"epoch "<<iter<<" "<<"Accuracy: "<< OutputAccuracy(tp_x.get() , aptr_y.get())<<std::endl; 127 } 128 //args_name. 129 for(size_t arg_ind=0; arg_ind<args_name.size(); ++arg_ind){ 130 //執行 131 if(args_name[arg_ind]=="X" || args_name[arg_ind]=="label") 132 continue; 133 134 opt->Update(arg_ind,executor->arg_arrays[arg_ind],executor->grad_arrays[arg_ind]); 135 } 136 NDArray::WaitAll(); 137 138 } 139 140 141 } 142 }; 143 } 144 145 int main(int argc , char * argv[]){ 146 mlp::MlpNet::net(); 147 MXNotifyShutdown(); 148 return EXIT_SUCCESS; 149 }
結果:
Training epoch 0 Accuracy: 0.09375 epoch 100 Accuracy: 0.304688 epoch 200 Accuracy: 0.195312 epoch 300 Accuracy: 0.203125 epoch 400 Accuracy: 0.304688 epoch 500 Accuracy: 0.296875 epoch 600 Accuracy: 0.304688 epoch 700 Accuracy: 0.304688 epoch 800 Accuracy: 0.398438 epoch 900 Accuracy: 0.5 epoch 1000 Accuracy: 0.5 epoch 1100 Accuracy: 0.40625 epoch 1200 Accuracy: 0.5 epoch 1300 Accuracy: 0.398438 epoch 1400 Accuracy: 0.40625 epoch 1500 Accuracy: 0.703125 epoch 1600 Accuracy: 0.609375 epoch 1700 Accuracy: 0.507812 epoch 1800 Accuracy: 0.703125 epoch 1900 Accuracy: 0.703125 epoch 2000 Accuracy: 0.804688 epoch 2100 Accuracy: 0.703125 epoch 2200 Accuracy: 0.804688 epoch 2300 Accuracy: 0.804688 epoch 2400 Accuracy: 0.804688 epoch 2500 Accuracy: 0.90625 epoch 2600 Accuracy: 0.90625 epoch 2700 Accuracy: 0.90625 epoch 2800 Accuracy: 1 epoch 2900 Accuracy: 1