Adam (1) - Python Implementation


  • Algorithm characteristics
    ①. A convex combination of gradients controls the iteration direction; ②. a convex combination of squared gradients controls the iteration step size; ③. each optimization variable adapts its own search.
  • Algorithm derivation
    Part Ⅰ Algorithm details
    Denote the objective function by $J$; its gradient is then written as
    \begin{equation}
    g = \nabla J
    \label{eq_1}
    \end{equation}
    Following the momentum gradient method, a convex combination of gradients controls the iteration direction (the first moment),
    \begin{equation}
    m_{k} = \beta_1m_{k-1} + (1 - \beta_1)g_{k}
    \label{eq_2}
    \end{equation}
    where $\beta_1$ is the convex-combination coefficient, i.e. the exponential decay rate.
    Following RMSProp, a convex combination of squared gradients controls the iteration step size (the second raw moment),
    \begin{equation}
    v_{k} = \beta_2v_{k-1} + (1 - \beta_2)g_{k}\odot g_{k}
    \label{eq_3}
    \end{equation}
    where $\beta_2$ is the convex-combination coefficient, i.e. the exponential decay rate.
    Since the first moment and the second raw moment are both initialized to 0, they are corrected as follows to reduce the bias that the decay rates introduce into the early iterations,
    \begin{gather}
    \hat{m}_{k} = \frac{m_{k}}{1 - \beta_1^{k}}\label{eq_4} \\
    \hat{v}_{k} = \frac{v_{k}}{1 - \beta_2^{k}}\label{eq_5}
    \end{gather}
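    For instance, at the first iteration $m_1 = (1 - \beta_1)g_1$, which with $\beta_1 = 0.9$ retains only one tenth of the gradient; the correction recovers the gradient exactly, since
    \begin{equation*}
    \hat{m}_{1} = \frac{(1 - \beta_1)g_{1}}{1 - \beta_1^{1}} = g_{1}, \qquad \hat{v}_{1} = \frac{(1 - \beta_2)g_{1}\odot g_{1}}{1 - \beta_2^{1}} = g_{1}\odot g_{1}
    \end{equation*}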
    Without loss of generality, write the $k$-th iteration as
    \begin{equation}
    x_{k+1} = x_k + \alpha_kd_k
    \label{eq_6}
    \end{equation}
    where $\alpha_k$ and $d_k$ denote the step size and the search direction at iteration $k$, with
    \begin{gather}
    \alpha_k = \frac{\alpha}{\sqrt{\hat{v}_k} + \epsilon}\label{eq_7} \\
    d_k = -\hat{m}_k\label{eq_8}
    \end{gather}
    where $\alpha$ is the step-size parameter and $\epsilon$ is a sufficiently small positive number that keeps the denominator of the step size away from 0.
    Part Ⅱ Algorithm procedure
    Initialize the step-size parameter $\alpha$, the sufficiently small positive number $\epsilon$, and the exponential decay rates $\beta_1$, $\beta_2$.
    Initialize the convergence criterion $\zeta$ and the starting point $x_1$.
    Compute the initial gradient $g_1=\nabla J(x_1)$, set the first moment $m_0 = 0$, the second moment $v_0 = 0$, $k = 1$, and repeat the following steps,
      step1: if $\|g_k\| < \zeta$, convergence is reached and the iteration stops
      step2: update the first moment $m_k = \beta_1m_{k-1} + (1 - \beta_1)g_{k}$
      step3: update the second moment $v_k = \beta_2v_{k-1} + (1 - \beta_2)g_{k}\odot g_{k}$
      step4: compute the bias-corrected first moment $\displaystyle \hat{m}_{k} = \frac{m_{k}}{1 - \beta_1^{k}}$
      step5: compute the bias-corrected second moment $\displaystyle \hat{v}_{k} = \frac{v_{k}}{1 - \beta_2^{k}}$
      step6: compute the step size $\displaystyle \alpha_k = \frac{\alpha}{\sqrt{\hat{v}_k} + \epsilon}$
      step7: compute the search direction $d_k = -\hat{m}_k$
      step8: update the iterate $x_{k+1} = x_k + \alpha_kd_k$
      step9: update the gradient $g_{k+1}=\nabla J(x_{k+1})$
      step10: set $k = k+1$ and return to step1
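    Before the class-based reference implementation in the next section, the following is a minimal NumPy sketch of a single iteration (steps 2 through 8); the function name adam_step and its signature are illustrative only,

    import numpy

    def adam_step(x, g, m, v, k, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1.e-8):
        # steps 2-3: update first moment and second raw moment
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # steps 4-5: bias-corrected moments
        m_hat = m / (1 - beta1 ** k)
        v_hat = v / (1 - beta2 ** k)
        # steps 6-8: per-coordinate step size, search direction, new iterate
        x_new = x - alpha / (numpy.sqrt(v_hat) + epsilon) * m_hat
        return x_new, m, v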
  • Code implementation
    The algorithm is now applied to the following unconstrained convex optimization problem,
    \begin{equation*}
    \min\quad 5x_1^2 + 2x_2^2 + 3x_1 - 10x_2 + 4
    \end{equation*}
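    Setting the gradient to zero gives the analytic optimum $x_1^{*} = -0.3$, $x_2^{*} = 2.5$ with objective value $J^{*} = -8.95$, which serves as a check on the numerical solution.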
    The Adam implementation is as follows,
    # Adam implementation

    import numpy
    from matplotlib import pyplot as plt


    # objective function: 0th-order information (function value)
    def func(X):
        funcVal = 5 * X[0, 0] ** 2 + 2 * X[1, 0] ** 2 + 3 * X[0, 0] - 10 * X[1, 0] + 4
        return funcVal


    # objective function: 1st-order information (gradient)
    def grad(X):
        grad_x1 = 10 * X[0, 0] + 3
        grad_x2 = 4 * X[1, 0] - 10
        gradVec = numpy.array([[grad_x1], [grad_x2]])
        return gradVec


    # generate a random starting point
    def seed(n=2):
        seedVec = numpy.random.uniform(-100, 100, (n, 1))
        return seedVec


    class Adam(object):

        def __init__(self, _func, _grad, _seed):
            '''
            _func: objective function to be optimized
            _grad: gradient of the objective function
            _seed: starting point of the iteration
            '''
            self.__func = _func
            self.__grad = _grad
            self.__seed = _seed

            self.__xPath = list()
            self.__JPath = list()


        def get_solu(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1.e-8, zeta=1.e-6, maxIter=3000000):
            '''
            obtain the numerical solution
            alpha: step-size parameter
            beta1: exponential decay rate of the first moment
            beta2: exponential decay rate of the second moment
            epsilon: sufficiently small positive number
            zeta: convergence criterion
            maxIter: maximum number of iterations
            '''
            self.__init_path()

            x = self.__init_x()
            JVal = self.__calc_JVal(x)
            self.__add_path(x, JVal)
            grad = self.__calc_grad(x)
            m, v = numpy.zeros(x.shape), numpy.zeros(x.shape)
            for k in range(1, maxIter + 1):
                # print("k: {:3d},   JVal: {}".format(k, JVal))
                if self.__converged1(grad, zeta):
                    self.__print_MSG(x, JVal, k)
                    return x, JVal, True

                m = beta1 * m + (1 - beta1) * grad            # first moment
                v = beta2 * v + (1 - beta2) * grad * grad     # second raw moment
                m_ = m / (1 - beta1 ** k)                     # bias-corrected first moment
                v_ = v / (1 - beta2 ** k)                     # bias-corrected second moment

                alpha_ = alpha / (numpy.sqrt(v_) + epsilon)   # per-coordinate step size
                d = -m_                                       # search direction
                xNew = x + alpha_ * d
                JNew = self.__calc_JVal(xNew)
                self.__add_path(xNew, JNew)
                if self.__converged2(xNew - x, JNew - JVal, zeta ** 2):
                    self.__print_MSG(xNew, JNew, k + 1)
                    return xNew, JNew, True

                gNew = self.__calc_grad(xNew)
                x, JVal, grad = xNew, JNew, gNew
            else:
                if self.__converged1(grad, zeta):
                    self.__print_MSG(x, JVal, maxIter)
                    return x, JVal, True

            print("Adam not converged after {} steps!".format(maxIter))
            return x, JVal, False


        def get_path(self):
            return self.__xPath, self.__JPath


        def __converged1(self, grad, epsilon):
            if numpy.linalg.norm(grad, ord=numpy.inf) < epsilon:
                return True
            return False


        def __converged2(self, xDelta, JDelta, epsilon):
            val1 = numpy.linalg.norm(xDelta, ord=numpy.inf)
            val2 = numpy.abs(JDelta)
            if val1 < epsilon or val2 < epsilon:
                return True
            return False


        def __print_MSG(self, x, JVal, iterCnt):
            print("Iteration steps: {}".format(iterCnt))
            print("Solution:\n{}".format(x.flatten()))
            print("JVal: {}".format(JVal))


        def __calc_JVal(self, x):
            return self.__func(x)


        def __calc_grad(self, x):
            return self.__grad(x)


        def __init_x(self):
            return self.__seed


        def __init_path(self):
            self.__xPath.clear()
            self.__JPath.clear()


        def __add_path(self, x, JVal):
            self.__xPath.append(x)
            self.__JPath.append(JVal)


    class AdamPlot(object):

        @staticmethod
        def plot_fig(adamObj):
            x, JVal, tab = adamObj.get_solu(0.1)
            xPath, JPath = adamObj.get_path()

            fig = plt.figure(figsize=(10, 4))
            ax1 = plt.subplot(1, 2, 1)
            ax2 = plt.subplot(1, 2, 2)

            # left panel: objective value versus iteration count
            ax1.plot(numpy.arange(len(JPath)), JPath, "k.", markersize=1)
            ax1.plot(0, JPath[0], "go", label="starting point")
            ax1.plot(len(JPath)-1, JPath[-1], "r*", label="solution")

            ax1.legend()
            ax1.set(xlabel="$iterCnt$", ylabel="$JVal$")

            # right panel: optimization path on the objective's contour plot
            x1 = numpy.linspace(-100, 100, 300)
            x2 = numpy.linspace(-100, 100, 300)
            x1, x2 = numpy.meshgrid(x1, x2)
            f = numpy.zeros(x1.shape)
            for i in range(x1.shape[0]):
                for j in range(x1.shape[1]):
                    f[i, j] = func(numpy.array([[x1[i, j]], [x2[i, j]]]))
            ax2.contour(x1, x2, f, levels=36)
            x1Path = list(item[0] for item in xPath)
            x2Path = list(item[1] for item in xPath)
            ax2.plot(x1Path, x2Path, "k--", lw=2)
            ax2.plot(x1Path[0], x2Path[0], "go", label="starting point")
            ax2.plot(x1Path[-1], x2Path[-1], "r*", label="solution")
            ax2.set(xlabel="$x_1$", ylabel="$x_2$")
            ax2.legend()

            fig.tight_layout()
            # plt.show()
            fig.savefig("plot_fig.png")


    if __name__ == "__main__":
        adamObj = Adam(func, grad, seed())

        AdamPlot.plot_fig(adamObj)
  • Results
    Running the script saves plot_fig.png: the left panel plots the objective value against the iteration count, and the right panel overlays the optimization path on a contour plot of the objective.
  • Usage suggestions
    ①. The accumulated local second moments reflect local curvature information to some extent, so using them to approximate and replace the Hessian matrix is reasonable;
    ②. The reference recommends the defaults $\alpha=0.001, \beta_1=0.9, \beta_2=0.999, \epsilon=10^{-8}$; in practice, tune the step-size parameter $\alpha$ first as needed, as sketched below.
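    A minimal usage sketch under these recommendations, assuming the Adam class defined above is in scope (the larger step size matches the one used in AdamPlot.plot_fig),

    adamObj = Adam(func, grad, seed())
    x, JVal, converged = adamObj.get_solu()           # paper defaults, alpha = 0.001
    x, JVal, converged = adamObj.get_solu(alpha=0.1)  # larger step size, tuned per suggestion ②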
  • References
    Kingma D P, Ba J. Adam: A method for stochastic optimization[J]. arXiv preprint arXiv:1412.6980, 2014.

