調參過程中的參數 學習率,權重衰減,沖量(learning_rate , weight_decay , momentum)


無論是深度學習還是機器學習,大多情況下訓練中都會遇到這幾個參數,今天依據我自己的理解具體的總結一下,可能會存在錯誤,還請指正.
learning_rate , weight_decay , momentum這三個參數的含義. 並附上demo.
 
我們會使用一個例子來說明一下:
            比如我們有一堆數據 ,我們只知道這對數據是從一個 黑盒中得到的,我們現在要尋找到那個具體的函數f(x),我們定義為目標函數T.
          我們現在假定有存在這個函數並且這個函數為:
                                        
         我們現在要使用這對數據來訓練目標函數. 我們可以設想如果存在一個這個函數,必定滿足{x,y}所有的關系,也就是說:     
                                        
         那么最理想的情況下 :    ,那么我們不妨定義這樣一個優化目標函數:
                                    
        對於這堆數據,我們認為當Loss(W)對於所有的pair{x,y}都滿足 Loss(W)趨近於或者等於0時,我們認為我們找到這個理想的目標函數T. 也就是此時  .
      以上,我們發現尋找的目標函數的問題,已經成功的轉移為求解: 
                                        
      也就是Loss 越小, f(x)越接近我們尋找的目標函數T.
那么說了這么多,這個和我們說的學習率learning_rate有什么關系呢?
                既然我們知道了我們當前的f(x)和目標函數的T的誤差,那么我們可以將這個誤差轉移到每一個參數上,也就是變成每一個參數w和目標函數T的參數w_t的誤差. 然后我們就以一定的幅度stride來縮小和真實值的距離,我們稱這個stride為學習率learning_rate 而且我們就是這么做的.
                我們用公式表述就是:
                        我們的誤差(損失)Loss:    
                                        
 
 
 
 
                我們這一個凸函數. 我們先對這個函數進行各個分量求偏導.
            
對於w0的偏導數:
                
那么對於分量w0承擔的誤差為:
                         並且這個誤差帶方向.
那么我們需要使我們當前的w0更加接近目標函數的T的w0_t參數.我們需要做運算:
                 (梯度下降算法)
來更新wo的值. 同理其他參數w,而這個學習率就是來控制我們每次靠近真實值的幅度,為什么要這么做呢?
因為我們表述的誤差只是一種空間表述形式我們可以使用均方差也可以使用絕對值,還可以使用對數,以及交叉熵等等,所以只能大致的反映,並不精確,就想我們問路一樣,別人告訴我們直走五分鍾,有的人走的快,有的人走的慢,所以如果走的快的話,當再次問路的時候,就會發現走多了,而折回來,這就是我們訓練過程中的loss曲線震盪嚴重的原因之一. 所以學習率要設置在合理的大小.

好了說了這么多,這是學習率. 那么什么是權重衰減weight_decay呢? 有什么作用呢?
          我們接着看上面的這個Loss(w),我們發現如果參數過多的話,對於高位的w3,我們對其求偏導:
            
我們發現w3開始大於1的時候,w3會調節的很快,幅度很大,從而使得特征x3變為異常敏感.從而出現過擬合(overfitting).
       這個時候,我們需要約束一下w2,w3等高階參數的大小,於是我們對Loss增加一個懲罰項,使得Loss的正反方向,不應該只由f(x) -y 決定,而還應該加上一個 ;於是Loss變成了:
   
我們繼續對Loss求解偏導數:
對wo求偏導:
                
 
 


對w3求偏導:
            
 
我們發現當x3值過大時,會改變Loss的導數的方向.而來抑制w2,w3等高階函數的繼續增長. 然而這樣抑制並不是很靈活,所以我們在前面加入一個系數 ,這個系數在數學上稱之為拉格朗日乘子系數,也就是我們用到的weight_decay. 這樣我們可以通過調節weight_decay系數,來調節w3,w2等高階的增長程度。加入weight_decay后的公式:
 
 
從公式可以看出 ,weight_decay越大,抑制越大,w2,w3等系數越小,weight_decay越小,抑制越小,w2,w3等系數越大


那么沖量momentum又是啥?
     我們在使用梯度下降法,來調整w時公式是這樣的:
        
我們每一次都是計算當前的梯度:
                
這樣會發現對於那些梯度比較小的地方,參數w更新的幅度比較小,訓練變得漫長,或者收斂慢.有時候遇到非最優的凸點,會出現沖不出去的現象.
而沖量加進來是一種快速效應.借助上一次的勢能來和當前的梯度來調節當前的參數w.
公式表達為:
            
這樣可以有效的避免掉入凸點無法沖出來,而且收斂速度也快很多.
 
附上demo: 使用mxnet編碼.
  1 //
  2 // Created by xijun1 on 2017/12/14.
  3 //
  4 
  5 #include <iostream>
  6 #include <vector>
  7 #include <string>
  8 #include <mxnet/mxnet-cpp/MxNetCpp.h>
  9 #include <mxnet/mxnet-cpp/op.h>
 10 
 11 namespace  mlp{
 12     class MlpNet{
 13     public:
 14         static mx_float OutputAccuracy(mx_float* pred, mx_float* target) {
 15             int right = 0;
 16             for (int i = 0; i < 128; ++i) {
 17                 float mx_p = pred[i * 10 + 0];
 18                 float p_y = 0;
 19                 for (int j = 0; j < 10; ++j) {
 20                     if (pred[i * 10 + j] > mx_p) {
 21                         mx_p = pred[i * 10 + j];
 22                         p_y = j;
 23                     }
 24                 }
 25                 if (p_y == target[i]) right++;
 26             }
 27             return right / 128.0;
 28         }
 29 
 30        static void net(){
 31             using mxnet::cpp::Symbol;
 32             using mxnet::cpp::NDArray;
 33 
 34             Symbol x = Symbol::Variable("X");
 35             Symbol y = Symbol::Variable("label");
 36 
 37             std::vector<std::int32_t> shapes({512 , 10});
 38             //定義一個兩層的網絡. wx + b
 39             Symbol weight_0 = Symbol::Variable("weight_0");
 40             Symbol biases_0 = Symbol::Variable("biases_0");
 41 
 42             Symbol fc_0 = mxnet::cpp::FullyConnected("fc_0",x,weight_0,biases_0
 43                     ,512);
 44 
 45             Symbol output_0 = mxnet::cpp::LeakyReLU("relu_0",fc_0,mxnet::cpp::LeakyReLUActType::kLeaky);
 46 
 47             Symbol weight_1 = Symbol::Variable("weight_1");
 48             Symbol biases_1 = Symbol::Variable("biases_1");
 49             Symbol fc_1 = mxnet::cpp::FullyConnected("fc_1",output_0,weight_1,biases_1,10);
 50             Symbol output_1 = mxnet::cpp::LeakyReLU("relu_1",fc_1,mxnet::cpp::LeakyReLUActType::kLeaky);
 51             Symbol pred = mxnet::cpp::SoftmaxOutput("softmax",output_1,y);  //目標函數,loss函數
 52             mxnet::cpp::Context ctx = mxnet::cpp::Context::cpu( 0);
 53 
 54             //定義輸入數據
 55             std::shared_ptr< mx_float > aptr_x(new mx_float[128*28] , [](mx_float* aptr_x){ delete [] aptr_x ;});
 56             std::shared_ptr< mx_float > aptr_y(new mx_float[128] , [](mx_float * aptr_y){ delete [] aptr_y ;});
 57 
 58             //初始化數據
 59             for(int i=0 ; i<128 ; i++){
 60                 for(int j=0;j<28 ; j++){
 61                     //定義x
 62                     aptr_x.get()[i*28+j]= i % 10 +0.1f;
 63                 }
 64 
 65                 //定義y
 66                 aptr_y.get()[i]= i % 10;
 67             }
 68            std::map<std::string, mxnet::cpp::NDArray> args_map;
 69            //導入數據
 70            NDArray arr_x(mxnet::cpp::Shape(128,28),ctx, false);
 71            NDArray arr_y(mxnet::cpp::Shape( 128 ),ctx,false);
 72            //將數據轉換到NDArray中
 73            arr_x.SyncCopyFromCPU(aptr_x.get(),128*28);
 74            arr_x.WaitToRead();
 75 
 76            arr_y.SyncCopyFromCPU(aptr_y.get(),128);
 77            arr_y.WaitToRead();
 78 
 79            args_map["X"]=arr_x.Slice(0,128).Copy(ctx) ;    
 80            args_map["label"]=arr_y.Slice(0,128).Copy(ctx);
 81            NDArray::WaitAll();
 82             //綁定網絡
 83            mxnet::cpp::Executor *executor = pred.SimpleBind(ctx,args_map);
 84             //選擇優化器
 85            mxnet::cpp::Optimizer *opt = mxnet::cpp::OptimizerRegistry::Find("sgd");
 86            mx_float learning_rate = 0.0001; //學習率
 87            mx_float weight_decay = 1e-4; //權重
 88            opt->SetParam("momentum", 0.9)
 89                    ->SetParam("lr", learning_rate)
 90                    ->SetParam("wd", weight_decay);
 91            //定義各個層參數的數組
 92            NDArray arr_w_0(mxnet::cpp::Shape(512,28),ctx, false);
 93            NDArray arr_b_0(mxnet::cpp::Shape( 512 ),ctx,false);
 94            NDArray arr_w_1(mxnet::cpp::Shape(10 , 512 ) , ctx , false);
 95            NDArray arr_b_1(mxnet::cpp::Shape( 10 ) , ctx , false);
 96 
 97            //初始化權重參數
 98            arr_w_0 = 0.01f;
 99            arr_b_1 = 0.01f;
100            arr_w_1 = 0.01f;
101            arr_b_1 = 0.01f;
102 
103             //初始化參數
104             executor->arg_dict()["weight_0"]=arr_w_0;
105             executor->arg_dict()["biases_0"]=arr_b_0;
106             executor->arg_dict()["weight_1"]=arr_w_1;
107             executor->arg_dict()["biases_1"]=arr_b_1;
108 
109             mxnet::cpp::NDArray::WaitAll();
110             //訓練
111             std::cout<<" Training "<<std::endl;
112 
113             int max_iters = 20000;  //最大迭代次數
114            //獲取訓練網絡的參數列表
115            std::vector<std::string>  args_name = pred.ListArguments();
116             for (int iter = 0; iter < max_iters ; ++iter) {
117                 executor->Forward(true);
118                 executor->Backward();
119 
120                 if(iter % 100 == 0){
121                     std::vector<NDArray> & out = executor->outputs;
122                     std::shared_ptr<mx_float> tp_x( new mx_float[128*28] ,
123                                                     [](mx_float * tp_x){ delete [] tp_x ;});
124                     out[0].SyncCopyToCPU(tp_x.get(),128*10);
125                     NDArray::WaitAll();
126                     std::cout<<"epoch "<<iter<<"  "<<"Accuracy: "<<  OutputAccuracy(tp_x.get() , aptr_y.get())<<std::endl;
127                 }
128                 //args_name.
129                 for(size_t arg_ind=0; arg_ind<args_name.size(); ++arg_ind){
130                     //執行
131                     if(args_name[arg_ind]=="X" || args_name[arg_ind]=="label")
132                         continue;
133 
134                     opt->Update(arg_ind,executor->arg_arrays[arg_ind],executor->grad_arrays[arg_ind]);
135                 }
136                 NDArray::WaitAll();
137 
138             }
139 
140 
141         }
142     };
143 }
144 
145 int main(int argc , char * argv[]){
146     mlp::MlpNet::net();
147     MXNotifyShutdown();
148     return EXIT_SUCCESS;
149 }
View Code

結果:

Training 
epoch 0  Accuracy: 0.09375
epoch 100  Accuracy: 0.304688
epoch 200  Accuracy: 0.195312
epoch 300  Accuracy: 0.203125
epoch 400  Accuracy: 0.304688
epoch 500  Accuracy: 0.296875
epoch 600  Accuracy: 0.304688
epoch 700  Accuracy: 0.304688
epoch 800  Accuracy: 0.398438
epoch 900  Accuracy: 0.5
epoch 1000  Accuracy: 0.5
epoch 1100  Accuracy: 0.40625
epoch 1200  Accuracy: 0.5
epoch 1300  Accuracy: 0.398438
epoch 1400  Accuracy: 0.40625
epoch 1500  Accuracy: 0.703125
epoch 1600  Accuracy: 0.609375
epoch 1700  Accuracy: 0.507812
epoch 1800  Accuracy: 0.703125
epoch 1900  Accuracy: 0.703125
epoch 2000  Accuracy: 0.804688
epoch 2100  Accuracy: 0.703125
epoch 2200  Accuracy: 0.804688
epoch 2300  Accuracy: 0.804688
epoch 2400  Accuracy: 0.804688
epoch 2500  Accuracy: 0.90625
epoch 2600  Accuracy: 0.90625
epoch 2700  Accuracy: 0.90625
epoch 2800  Accuracy: 1
epoch 2900  Accuracy: 1

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM