Gradient Descent
機器學習中很多模型的參數估計都要用到優化算法,梯度下降是其中最簡單也用得最多的優化算法之一。梯度下降(Gradient Descent)[3]也被稱之為最快梯度(Steepest Descent),可用於尋找函數的局部最小值。梯度下降的思路為,函數值在梯度反方向下降是最快的,只要沿着函數的梯度反方向移動足夠小的距離到一個新的點,那么函數值必定是非遞增的,如圖1所示。
梯度下降思想的數學表述如下:
其中f(x)f(x)為存在下界的可導函數。根據該思路,如果我們從x0x0為出發點,每次沿着當前函數梯度反方向移動一定距離αkαk,得到序列x0,x1,⋯,xnx0,x1,⋯,xn:
對應的各點函數值序列之間的關系為:
很顯然,當nn達到一定值時,函數f(x)f(x)是會收斂到局部最小值的。算法1簡單描述了一般化的梯度優化方法。在算法1中,我們需要選擇一個搜索方向dkdk滿足以下關系:
當dk=−∇f(x)dk=−∇f(x)時f(x)f(x)下降最快,但是只要滿足∇f(xk)Tdk<0∇f(xk)Tdk<0的dkdk都可以作為搜素方向。一般搜索方向表述為如下形式:
其中BkBk為正定矩陣。當Bk=IBk=I時對應最快梯度下降算法;當Bk=H(xk)−1Bk=H(xk)−1時對應牛頓法,如果H(xk)=∇2f(xk)H(xk)=∇2f(xk)為正定矩陣。 在迭代過程中用於更新xkxk的步長αkαk可以是常數也可以是變化的。如果αkαk足夠小,收斂是可以得到保證的,但這意味這迭代次數nn要很大時函數才會收斂(圖2(a));如果αkαk比較大,更新后的點很可能越過局部最優解(圖2(b))。有什么方法可以幫助我們自動確定最優步長呢?下面要說的線性搜索就包含一組解決方案。
Line Search
在給定搜索方向dkdk的前提下,線性搜索要解決的問題如下:
如果h(α)h(α)是可微的凸函數,我們能通過解析解直接求得上式最優的步長;但非線性的優化問題需要通過迭代形式求得近似的最優步長。對於上式,局部或全局最優解對應的導數為h′(α)=∇f(xk+αdk)Tdk=0h′(α)=∇f(xk+αdk)Tdk=0。因為dkdk與f(xk)f(xk)在xkxk處的梯度方向夾角大於90度,因此h′(0)≤0h′(0)≤0,如果能找到α^α^使得h′(α^)>0h′(α^)>0,那么必定存在α⋆∈[0,α^)α⋆∈[0,α^)使得h′(α⋆)=0h′(α⋆)=0。有多種迭代算法可以求得α⋆α⋆的近似值,下面選擇幾種典型的介紹。
Bisection Search
二分線性搜索(Bisection Line Search)[2]可用於求解函數的根,其思想很簡單,就是不斷將現有區間划分為兩半,選擇必定含有使h′(α)=0h′(α)=0的半個區間作為下次迭代的區間,直到尋得h′(α⋆)≈0h′(α⋆)≈0為止,算法描述見2。二分線性搜素可以確保h(α)h(α)是收斂的,只要h(α)h(α)在區間(0,α^)(0,α^)上是連續的且h′(0)h′(0)和h(α^)h(α^)異號。經歷nn次迭代后,當前區間[αl,αh][αl,αh]的長度為:
由迭代的終止條件之一αh−αl≥ϵαh−αl≥ϵ知迭代次數的上界為:
下面給出二分搜索的Python代碼
1 def bisection(dfun,theta,args,d,low,high,maxiter=1e4): 2 """ 3 #Functionality:find the root of the function(fun) in the interval [low,high] 4 #@Parameters 5 #dfun:compute the graident of function f(x) 6 #theta:Parameters of the model 7 #args:other variables needed to compute the value of dfun 8 #[low,high]:the interval which contains the root 9 #maxiter:the max number of iterations 10 """ 11 eps=1e-6 12 val_low=np.sum(dfun(theta+low*d,args)*d.T) 13 val_high=np.sum(dfun(theta+high*d,args)*d.T) 14 if val_low*val_high>0: 15 raise Exception('Invalid interval!') 16 iter_num=1 17 while iter_num<maxiter: 18 mid=(low+high)/2 19 val_mid=np.sum(dfun(theta+mid*d,args)*d.T) 20 if abs(val_mid)<eps or abs(high-low)<eps: 21 return mid 22 elif val_mid*val_low>0: 23 low=mid 24 else: 25 high=mid 26 iter_num+=1
Backtracking
回溯線性搜索(Backing Line Search)[1]基於Armijo准則計算搜素方向上的最大步長,其基本思想是沿着搜索方向移動一個較大的步長估計值,然后以迭代形式不斷縮減步長,直到該步長使得函數值f(xk+αdk)f(xk+αdk)相對與當前函數值f(xk)f(xk)的減小程度大於期望值(滿足Armijo准則)為止。Armijo准則(見圖3)的數學描述如下:
其中f:Rn→Rf:Rn→R,c1∈(0,1)c1∈(0,1),αα為步長,dk∈Rndk∈Rn為滿足f′(xk)Tdk<0f′(xk)Tdk<0的搜索方向。但是僅憑Armijo准則不足以求得較好的步長,根據前面的梯度下降的知識可知,只要αα足夠小就能滿足Armijo准則。因此常用的策略就是從較大的步長開始,然后以τ∈(0,1)τ∈(0,1)的速度縮短步長,直到滿足Armijo准則為止,這樣選出來的步長不至於太小,對應的算法描述見3。前面介紹的二分線性搜索的目標是求得滿足h′(α)≈0h′(α)≈0的最優步長近似值,而回溯線性搜索放松了對步長的約束,只要步長能使函數值有足夠大的變化即可。前者可以少計算幾次搜索方向,但在計算最優步長上花費了不少代價;后者退而求其次,找到一個差不多的步長即可,那么代價就是要多計算幾次搜索方向。
接下來,我們要證明回溯線性搜索在Armijo准則下的收斂性問題[6]。因為h′(0)=f′(xk)Tdk<0h′(0)=f′(xk)Tdk<0,且0<c1<10<c1<1,則有
根據導數的基本定義,結合上式,有如下關系:
因此,存在一個步長α^>0α^>0,對任意的α∈(0,α^)α∈(0,α^),下式均成立
即∀α∈(0,α^),f(xk+αdk)<f(xk)+cαf′(xk)Tdk∀α∈(0,α^),f(xk+αdk)<f(xk)+cαf′(xk)Tdk。 下面給出基於Armijo准則的線性搜索Python代碼:
1 def ArmijoBacktrack(fun,dfun,theta,args,d,stepsize=1,tau=0.5,c1=1e-3): 2 """ 3 #Functionality:find an acceptable stepsize via backtrack under Armijo rule 4 #@Parameters 5 #fun:compute the value of objective function 6 #dfun:compute the gradient of objective function 7 #theta:a vector of parameters of the model 8 #stepsize:initial step size 9 #c1:sufficient decrease Parameters 10 #tau:rate of shrink of stepsize 11 """ 12 slope=np.sum(dfun(theta,args)*d.T) 13 obj_old=costFunction(theta,args) 14 theta_new=theta+stepsize*d 15 obj_new=costFunction(theta_new,args) 16 while obj_new>obj_old+c1*stepsize*slope: 17 stepsize*=tau 18 theta_new=theta+stepsize*d 19 obj_new=costFunction(theta_new,args) 20 return stepsize
Interpolation
基於Armijo准則的回溯線性搜索的收斂速度無法得到保證,特別是要回退很多次后才能落入滿足Armijo准則的區間。如果我們根據已有的函數值和導數信息,采用多項式插值法(Interpolation)[12,6,5,9]擬合函數,然后根據該多項式函數估計函數的極值點,這樣選擇合適步長的效率會高很多。 假設我們只有xkxk處的函數值f(xk)f(xk)及其倒數f′(xk)f′(xk),且第一次嘗試的步長為α0α0。如果α0α0不滿足條件,那么我們根據這些信息可以構造一個二次近似函數hq(α)hq(α)
注意,該二次函數滿足hq(0)=h(0)hq(0)=h(0),h′q(0)=h′(0)hq′(0)=h′(0)和hq(α0)=h(α0)hq(α0)=h(α0),如圖4(a)所示。接下來,根據hq(α)hq(α)的最小值估計下一個步長:
如果α1α1仍然不滿足條件,我們可以繼續重復上述過程,直到得到的步長滿足條件為止。假設我們在整個線性搜索過程中都用二次插值函數,那么最好有c1∈(0,0.5]c1∈(0,0.5],為什么呢?簡單證明一下:如果α0α0不滿足Armijo准則,那么必定存在比α0α0小的步長滿足該准則,所以利用二次插值函數估算的步長α1<α0α1<α0才合理。結合α0α0不滿足Armijo准則和α1<α0α1<α0,可知c1≤0.5c1≤0.5。 如果我們已經嘗試了多個步長,卻每次只用上一次步長的相關信息構造二次函數,未免是對計算資源的浪費,其實我們可以利用多個步長的信息構造信息量更大更准確的插值函數的。在計算導數的代價大於計算函數值的代價時,應盡量避免計算h′(α)h′(α),下面給出一個三次插值函數hc(α)hc(α),如圖4(b)所示
其中
對hc(α)hc(α)求導,可得極值點αi+1∈[0,αi]αi+1∈[0,αi]的形式如下:
利用以上的三次插值函數求解下一個步長的過程不斷重復,直到步長滿足條件為止。如果出現a=0a=0的情況,三次插值函數退化為二次插值函數,在實現該算法時需要注意這點。在此過程中,如果αiαi太小或αi−1αi−1與αiαi太接近,需要重置αi=αi−1/2αi=αi−1/2,該保護措施(safeguards)保證下一次的步長不至於太小[6,5]。為什么會有這個作用呢?1)因為αi+1∈[0,αi]αi+1∈[0,αi],所以當αiαi很小時αi+1αi+1也很小;2)當αi−1αi−1與αiαi太靠近時有a≈b≈∞a≈b≈∞,根據αi+1αi+1的表達式可知αi+1≈0αi+1≈0。 但是,在很多情況下,計算函數值后只需付出較小的代價就能順帶計算出導數值或其近似值,這使得我們可以用更精確的三次Hermite多項式[6]進行插值,如圖4(c)所示