Optimization Methods Summary, Continued: The Downhill Simplex Method, with Python Example Code


The downhill simplex method (also known as the Nelder-Mead method) is a widely used "derivative-free" optimization algorithm. It is generally not very efficient, but Ref. [1] notes that "the downhill simplex method may frequently be the *best* method to use if the figure of merit is 'get something working quickly' for a problem whose computational burden is small."

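The basic idea of the simplex method is to construct a non-degenerate initial simplex (\(N+1\) vertices) in \(N\)-dimensional space and then apply a sequence of geometric operations, such as reflection, expansion and contraction, that gradually move the simplex toward an extremum. Because these operations are essentially all aimed at moving the simplex toward a minimum, the method is called the downhill simplex method.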

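Let \(f(\mathbf{x})\) be the function to be minimized, and let the \(N+1\) vertices of the simplex \(Z\) in \(N\)-dimensional space, sorted by function value in ascending order, be \(\mathbf{x}_{0},\mathbf{x}_{1},\cdots,\mathbf{x}_{N}\). Define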


\(\bar{\mathbf{x}}=\frac{1}{N}\sum_{i=0}^{N-1}\mathbf{x}_{i}\)


as the centroid of all vertices of \(Z\) except \(\mathbf{x}_{N}\).
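In NumPy, with the vertices stored as rows of an array (an assumed layout, chosen for illustration), the centroid \(\bar{\mathbf{x}}\) is a column-wise mean over every row except the worst one. A minimal sketch; `centroid_without_worst` is a hypothetical helper name. Note the `axis=0`: calling `np.mean` on a list of vectors without it averages every coordinate of every vertex into a single scalar.

```python
import numpy as np

def centroid_without_worst(vertices, fvals):
    """Centroid of all simplex vertices except the one with the largest f.

    vertices: array of shape (N + 1, N), one vertex per row (assumed layout)
    fvals:    array of shape (N + 1,), function value at each vertex
    """
    order = np.argsort(fvals)      # ascending: best first, worst last
    kept = vertices[order[:-1]]    # drop the worst vertex x_N
    return kept.mean(axis=0)       # (1/N) * sum of the remaining N rows

# a triangle (N = 2) in the plane
verts = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]])
fvals = np.array([1.0, 3.0, 2.0])  # vertex [2, 0] is the worst
print(centroid_without_worst(verts, fvals))  # [0. 1.]
```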

The line through \(\bar{\mathbf{x}}\) and \(\mathbf{x}_{N}\) can be written as:


\(\bar{\mathbf{x}}(t)=(1-t)\bar{\mathbf{x}}+t\mathbf{x}_{N}\)


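At each iteration, the downhill simplex method tries a few special step lengths \(t\) along the line \(\bar{\mathbf{x}}(t)\) to find a replacement for \(\mathbf{x}_{N}\) whose function value is smaller than that of \(\mathbf{x}_{N}\). If no such replacement is found, every vertex except \(\mathbf{x}_{0}\) is moved toward \(\mathbf{x}_{0}\).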
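The iteration uses four special step lengths on this line: \(t=-1\) (reflection), \(t=-2\) (expansion), \(t=-0.5\) (outside contraction) and \(t=0.5\) (inside contraction). Negative \(t\) lands on the far side of \(\bar{\mathbf{x}}\), away from the worst vertex. A small sketch evaluating them for a made-up 2-D centroid and worst vertex (the numbers are purely illustrative):

```python
import numpy as np

def line(t, x_bar, x_worst):
    """Point on the line x_bar(t) = (1 - t) * x_bar + t * x_worst."""
    return (1 - t) * x_bar + t * x_worst

x_bar = np.array([1.0, 1.0])    # centroid of the N best vertices (made-up value)
x_worst = np.array([3.0, 2.0])  # worst vertex x_N (made-up value)

trial_steps = {
    "reflection": -1.0,           # 2 * x_bar - x_worst
    "expansion": -2.0,            # 3 * x_bar - 2 * x_worst
    "outside contraction": -0.5,  # 1.5 * x_bar - 0.5 * x_worst
    "inside contraction": 0.5,    # midpoint of x_bar and x_worst
}
for name, t in trial_steps.items():
    print(f"{name:20s} t={t:+.1f} -> {line(t, x_bar, x_worst)}")
```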

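Assume we are minimizing, so a point with a smaller function value is considered better and one with a larger value is considered worse. Each iteration of the downhill simplex method proceeds as follows: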

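  1. First compute the reflection of the worst point through the centroid along the line \(\bar{\mathbf{x}}(t)\), i.e. \(\mathbf{x}_{r}=\bar{\mathbf{x}}(-1)=2\bar{\mathbf{x}}-\mathbf{x}_{N}\);
  2. If the reflected point lies between the best and the second-worst point, \(f(\mathbf{x}_{0})\le f(\mathbf{x}_{r})<f(\mathbf{x}_{N-1})\), accept it (reflection);
  3. If the reflected point is better than the best point, try a bolder step in the same direction, \(t=-2\); if the new trial point is better than the reflected point, accept it (expansion), otherwise accept the reflected point (reflection);
  4. If the reflected point lies between the second-worst and the worst point, \(f(\mathbf{x}_{N-1})\le f(\mathbf{x}_{r})<f(\mathbf{x}_{N})\), try a more cautious step in that direction, \(t=-0.5\); if the new trial point is better than the reflected point, accept it (outside contraction);
  5. If the reflected point is no better than the worst point, \(f(\mathbf{x}_{r})\ge f(\mathbf{x}_{N})\), try the opposite side of the centroid, \(t=0.5\); if the new trial point is better than the worst point, accept it (inside contraction);
  6. If the contraction tried in step 4 or step 5 is rejected, shrink every point except the best one halfway toward the best point, \(\mathbf{x}_{i}\leftarrow(\mathbf{x}_{0}+\mathbf{x}_{i})/2\) for \(i=1,\ldots,N\) (shrink).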
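As a sketch of steps 1-6, here is one self-contained iteration, assuming the vertices are stored as rows of a NumPy array; `nelder_mead_step` and `sphere` are hypothetical names. It accepts an inside contraction only if it beats the worst vertex (the usual Nelder-Mead convention) and shrinks whenever the attempted contraction is rejected.

```python
import numpy as np

def nelder_mead_step(f, simplex, fvals):
    """One downhill simplex iteration (minimization), following steps 1-6.

    simplex: array of shape (N + 1, N), one vertex per row
    fvals:   array of shape (N + 1,), f evaluated at each row
    Returns (new_simplex, new_fvals, name_of_operation); inputs are not modified.
    """
    order = np.argsort(fvals)
    simplex, fvals = simplex[order], fvals[order]  # copies; x_0 best ... x_N worst
    x_bar = simplex[:-1].mean(axis=0)              # centroid without the worst vertex

    def point(t):
        # point on the line (1 - t) * x_bar + t * x_N
        return (1 - t) * x_bar + t * simplex[-1]

    def replace_worst(x, fx, name):
        simplex[-1], fvals[-1] = x, fx
        return simplex, fvals, name

    x_r = point(-1.0)                              # step 1: reflection
    f_r = f(x_r)
    if fvals[0] <= f_r < fvals[-2]:                # step 2
        return replace_worst(x_r, f_r, "reflection")
    if f_r < fvals[0]:                             # step 3: try expansion
        x_e = point(-2.0)
        f_e = f(x_e)
        if f_e < f_r:
            return replace_worst(x_e, f_e, "expansion")
        return replace_worst(x_r, f_r, "reflection")
    if f_r < fvals[-1]:                            # step 4: outside contraction
        x_c = point(-0.5)
        f_c = f(x_c)
        if f_c < f_r:
            return replace_worst(x_c, f_c, "outside contraction")
    else:                                          # step 5: inside contraction
        x_c = point(0.5)
        f_c = f(x_c)
        if f_c < fvals[-1]:
            return replace_worst(x_c, f_c, "inside contraction")
    # step 6: shrink every vertex except the best halfway toward x_0
    simplex[1:] = 0.5 * (simplex[0] + simplex[1:])
    fvals[1:] = [f(x) for x in simplex[1:]]
    return simplex, fvals, "shrink"


def sphere(x):
    """Test function with minimum 0 at x = (1, ..., 1)."""
    return float(np.sum((x - 1.0) ** 2))

simplex = np.vstack([np.zeros(2), np.eye(2)])      # x_0 = 0, x_i = x_0 + e_i
fvals = np.array([sphere(x) for x in simplex])
for _ in range(200):
    simplex, fvals, op = nelder_mead_step(sphere, simplex, fvals)
print(simplex[np.argmin(fvals)])                   # close to [1, 1]
```

On the first call from the unit simplex, the reflected point is \((1,1)\) with \(f=0\); the expansion point \((1.5,1.5)\) is worse, so step 3 falls back to the reflection.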


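The six steps are easier to follow on a number line of function values divided by three breakpoints: the best, second-worst and worst values. Corresponding pseudocode can be found in Section 9.5 of Ref. [3].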

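The iteration above needs an initial simplex, which can be constructed with the method given in Ref. [1]: first pick an arbitrary point \(\mathbf{x}_{0}\), then use the following formula: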


\(\mathbf{x}_{i}=\mathbf{x}_{0}+\lambda\mathbf{e}_{i}\)


where \(\mathbf{e}_{i}\) (\(i=1,\ldots,N\)) are the unit vectors of \(N\)-dimensional space and \(\lambda\) is the step length.
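A small sketch of this construction, plus a check that the result is non-degenerate: the \(N\) edge vectors \(\mathbf{x}_{i}-\mathbf{x}_{0}\) must be linearly independent, which holds here for any \(\lambda\neq0\). Vertices are stored as rows of a NumPy array; `initial_simplex` and `is_non_degenerate` are hypothetical helper names.

```python
import numpy as np

def initial_simplex(x0, lam):
    """Return an (N + 1, N) array: row 0 is x_0, row i is x_0 + lam * e_i."""
    x0 = np.asarray(x0, dtype=float)
    return np.vstack([x0, x0 + lam * np.eye(x0.size)])

def is_non_degenerate(simplex, tol=1e-12):
    """True if the N edge vectors x_i - x_0 are linearly independent."""
    edges = simplex[1:] - simplex[0]
    return np.linalg.matrix_rank(edges, tol=tol) == simplex.shape[1]

S = initial_simplex([1.0, -2.0, 0.5], lam=0.1)
print(S.shape)               # (4, 3): N + 1 = 4 vertices in N = 3 dimensions
print(is_non_degenerate(S))  # True
```

A degenerate "simplex" such as three collinear points in the plane fails the rank test, and the method could then only search along that one line.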


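Another trick that matters a lot in practice is restarting. During the updates the simplex can easily get stuck at some intermediate position: its *best* and *worst* points become nearly identical, its volume shrinks to a tiny size, and progress slows down dramatically. Choosing a sensible size for the initial simplex helps; a more effective remedy is to re-initialize the simplex whenever it gets stuck. When the initialization formula is applied again, the *best* point of the current simplex is kept as \(\mathbf{x}_{0}\), which guarantees that a restart does not discard the results computed so far.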
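One way to detect "stuck" is the relative spread between the best and worst function values, the same test the main `simplex` function in this post uses (`rtol`). A minimal sketch of restart-on-collapse under that assumption; `relative_spread`, `maybe_restart` and the objective `g` are hypothetical names. The best vertex and its value are reused as the new \(\mathbf{x}_{0}\), so only the \(N\) new vertices need evaluating.

```python
import numpy as np

def relative_spread(f_best, f_worst, eps=1e-9):
    """2|f_best - f_worst| / (|f_best| + |f_worst| + eps); tiny means collapsed."""
    return 2.0 * abs(f_best - f_worst) / (abs(f_best) + abs(f_worst) + eps)

def maybe_restart(f, simplex, fvals, lam, tol=1e-5):
    """Rebuild the simplex around its best vertex if the relative spread < tol.

    The best vertex and its function value are kept as the new x_0, so no
    progress is lost; only the N new vertices x_0 + lam * e_i are evaluated.
    """
    best = int(np.argmin(fvals))
    if relative_spread(fvals[best], np.max(fvals)) > tol:
        return simplex, fvals, False
    x0 = simplex[best]
    new_simplex = np.vstack([x0, x0 + lam * np.eye(x0.size)])
    new_fvals = np.array([fvals[best]] + [f(x) for x in new_simplex[1:]])
    return new_simplex, new_fvals, True

def g(x):
    """Example objective with its minimum at (4, 4)."""
    return float(np.sum((x - 4.0) ** 2))

# a collapsed simplex stuck near (1, 1), far from the minimum
stuck = np.array([[1.0, 1.0], [1.0 - 1e-8, 1.0], [1.0, 1.0 - 1e-8]])
fv = np.array([g(x) for x in stuck])
new_s, new_f, restarted = maybe_restart(g, stuck, fv, lam=1.0)
print(restarted, new_s[0])  # True [1. 1.]
```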


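Below is a Python implementation of the simplex method for a simple problem. The function to be minimized is \(f(\mathbf{x})=\frac{1}{2}(\mathbf{x}-\mathbf{x}_{0})^{T}(\mathbf{x}-\mathbf{x}_{1})\), where \(\mathbf{x}_{0}=5\cdot\mathbf{1}\) and \(\mathbf{x}_{1}=3\cdot\mathbf{1}\) are constant vectors (not simplex vertices; they are `v0` and `v1` in the code).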

First, define some helper functions:

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

def vertice_init(vertex_0, step_length):
    '''
    initialize the N + 1 vertices of the simplex
    using the following formula:
    $x_i = x_0 + step_length * e_i$, i = 1, ..., N
    '''

    emat = np.eye(vertex_0.size) * step_length
    vertice = [vertex_0]
    for ii in range(vertex_0.size):
        vertice.append(vertex_0 + emat[:, ii])
    return vertice


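def f(v):
    '''
    Test function f(v) = 0.5 * (v - v0) . (v - v1) with v0 = 5 and v1 = 3 in
    every coordinate; its minimum is -0.5 * v.size, reached at v = 4.
    '''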
    dim = v.size
    v0 = np.ones(dim) * 5
    v1 = np.ones(dim) * 3
    return 0.5 * np.dot(v - v0, v - v1)


def line(t, v1, v2):
    '''point (1 - t) * v1 + t * v2 on the line through v1 (t = 0) and v2 (t = 1)'''
    return (1 - t) * v1 + t * v2

Next, define the main function of the algorithm; note the restart part:

def simplex(f, vertice, maxit=1000, step_length=100, tol=1e-3):
    vertice_max_list = []  # store the max vertex during each iteration
    vertice_min_list = []  # store the min vertex during each iteration
    for jj in range(maxit):
        y = []
        for ii in vertice:
            y.append(f(ii))
        y = np.array(y)
        # only the highest (worst), next-highest, and lowest (best) vertices
        # are needed; argsort returns indices in ascending order of f
        idx = np.argsort(y)
        vertice_max_list.append(vertice[idx[-1]])
        vertice_min_list.append(vertice[idx[0]])

        # centroid of the best n vertices (the worst vertex idx[-1] is excluded);
        # axis=0 averages vertex-wise instead of collapsing to a single scalar
        v_mean = np.mean([vertice[ii] for ii in idx[:-1]], axis=0)

        # compute the candidate vertex and corresponding function value
        v_ref = line(-1, v_mean, vertice[idx[-1]])
        y_ref = f(v_ref)
        if y_ref >= y[idx[0]] and y_ref < y[idx[-2]]:
            # y_0<=y_ref<y_n, reflection (replace v_n+1 with v_ref)
            vertice[idx[-1]] = v_ref
            # print('reflection1')
        elif y_ref < y[idx[0]]:
            # y_ref<y_0, expand
            v_ref_e = line(-2, v_mean, vertice[idx[-1]])
            y_ref_e = f(v_ref_e)
            if y_ref_e < y_ref:
                vertice[idx[-1]] = v_ref_e
                # print('expand')
            else:
                vertice[idx[-1]] = v_ref
                # print('reflection2')
        elif y_ref >= y[idx[-2]]:
            if y_ref < y[idx[-1]]:
                # y_n<=y_ref<y_{n+1}, outside contraction
                v_ref_c = line(-0.5, v_mean, vertice[idx[-1]])
                accepted = f(v_ref_c) < y_ref
            else:
                # y_ref>=y_{n+1}, inside contraction
                v_ref_c = line(0.5, v_mean, vertice[idx[-1]])
                accepted = f(v_ref_c) < y[idx[-1]]
            if accepted:
                vertice[idx[-1]] = v_ref_c
                # print('contraction')
            else:
                # shrinkage: move every vertex except the best halfway toward it
                v_best = vertice[idx[0]]
                for ii in idx[1:]:
                    vertice[ii] = 0.5 * (v_best + vertice[ii])
                # print('shrinkage')
        # restart
        # restarting is very important during iteration, because the simplex
        # can easily get stuck at a non-optimal point
        rtol = 2.0 * abs(y[idx[0]] - y[idx[-1]]) / (
            abs(y[idx[0]]) + abs(y[idx[-1]]) + 1e-9)
        if rtol <= tol:
            vertice = vertice_init(vertice[idx[0]], step_length)
    return vertice_max_list, vertice_min_list

 

For the test, the number of unknowns is set to 15. From the definition of \(f(\mathbf{x})\), its minimum is easily seen to be \(-0.5\times 15=-7.5\).
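Completing the square makes this explicit: each of the 15 coordinates contributes \(\frac{1}{2}(x_i-5)(x_i-3)\), and \((x_i-5)(x_i-3)=(x_i-4)^2-1\).

```latex
f(\mathbf{x})=\frac{1}{2}\sum_{i=1}^{15}(x_i-5)(x_i-3)
            =\frac{1}{2}\sum_{i=1}^{15}\left[(x_i-4)^2-1\right]
\;\ge\; -\frac{15}{2}=-7.5,
\qquad\text{with equality at } \mathbf{x}=4\cdot\mathbf{1}.
```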

dim = 15
step_length = 5
v = np.random.randn(dim)
vertice = vertice_init(v, step_length)  # the choice of step length is crucial

vertice_max_list, vertice_min_list = simplex(
    f, vertice, maxit=2000, step_length=step_length, tol=1e-5)
print('min value is %s' % f(vertice_min_list[-1]))

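Plot the optimization progress: the function values at the highest and lowest vertices of the simplex at each iteration are drawn as curves. The plot clearly shows that the algorithm gets stuck at values above 0, where the maximum and minimum are very close, which means the simplex has shrunk severely. Restarting helps the algorithm escape these traps and speeds up convergence: after a few restarts it quickly converges to the true minimum. Note that a restart does not change the simplex's minimum value; it only makes the maximum much larger (the jumps in the red curve).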

f_max_list = []
f_min_list = []
for ii, jj in zip(vertice_max_list, vertice_min_list):
    f_max_list.append(f(ii))
    f_min_list.append(f(jj))

plt.plot(f_max_list, 'r', linewidth=2, label='max')
plt.plot(f_min_list, 'b', linewidth=2, label='min')
plt.legend(fontsize=15)
plt.show()

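Finally, the simplex method is used to train a three-layer neural network, which is then used for a simple prediction task. The input, hidden and output layers have 3, 2 and 1 neurons respectively, and the true relationship is \(y=\sum_{i=1}^{3}x_{i}\).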
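The `opt_variables` class in the script stores all weights and biases in one flat vector `v` and slices it with `ptrs`, so the simplex lives in a \((3+1)\times 2+(2+1)\times 1=11\)-dimensional parameter space. As a cross-check, here is a vectorized sketch of the same packing order and the same sum-of-squared-errors objective; `unpack`, `forward` and `sse_loss` are hypothetical names, not part of the original code.

```python
import numpy as np

input_dim, hidden_dim, output_dim = 3, 2, 1
total_dim = (input_dim + 1) * hidden_dim + (hidden_dim + 1) * output_dim  # 11

def unpack(theta):
    """Split a flat vector into W1 (2x3), b1 (2,), W2 (1x2), b2 (1,)."""
    sizes = [hidden_dim * input_dim, hidden_dim, output_dim * hidden_dim, output_dim]
    w1, b1, w2, b2 = np.split(theta, np.cumsum(sizes)[:-1])
    return (w1.reshape(hidden_dim, input_dim), b1,
            w2.reshape(output_dim, hidden_dim), b2)

def forward(theta, X):
    """Sigmoid hidden layer, linear output; returns shape (n_samples, output_dim)."""
    W1, b1, W2, b2 = unpack(theta)
    hidden = 1.0 / (1.0 + np.exp(-(X @ W1.T + b1)))
    return hidden @ W2.T + b2

def sse_loss(theta, X, y):
    """Sum of squared errors: the objective the simplex minimizes over theta."""
    return float(np.sum((forward(theta, X)[:, 0] - y) ** 2))

rng = np.random.default_rng(0)
X = rng.random((100, input_dim))
y = X.sum(axis=1)                          # target relation y = x_1 + x_2 + x_3
theta = rng.random(total_dim)
print(total_dim, forward(theta, X).shape)  # 11 (100, 1)
```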

import numpy as np
from matplotlib import pyplot as plt


class opt_variables:
    def __init__(self, input_dim, hidden_dim, output_dim, v):
        assert (input_dim + 1) * hidden_dim + (
            hidden_dim + 1) * output_dim == v.size, 'dimension mismatch!'
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.v = v
        self.dim = v.size
        w1_ptr = input_dim * hidden_dim
        b1_ptr = hidden_dim + w1_ptr
        w2_ptr = hidden_dim * output_dim + b1_ptr
        b2_ptr = output_dim + w2_ptr
        self.ptrs = np.array([w1_ptr, b1_ptr, w2_ptr, b2_ptr])

    def __add__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            return opt_variables(self.input_dim, self.hidden_dim,
                                 self.output_dim, self.v + other.v)
        return opt_variables(self.input_dim, self.hidden_dim, self.output_dim,
                             self.v + other)

    def __iadd__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            self.v += other.v
        else:
            self.v += other
        return self

    def __sub__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            return opt_variables(self.input_dim, self.hidden_dim,
                                 self.output_dim, self.v - other.v)
        return opt_variables(self.input_dim, self.hidden_dim, self.output_dim,
                             self.v - other)

    def __isub__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            self.v -= other.v
        else:
            self.v -= other
        return self

    def __mul__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim, self.output_dim,
                             self.v * constant)

    def __rmul__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim, self.output_dim,
                             self.v * constant)

    def __imul__(self, constant):
        self.v *= constant
        return self

    def __truediv__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim, self.output_dim,
                             self.v / constant)

    def __rtruediv__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim, self.output_dim,
                             constant / self.v)

    def __itruediv__(self, constant):
        self.v /= constant
        return self

    def __str__(self):
        return '%s' % self.v


def vertice_init(vertex_0, step_length):
    emat = np.eye(vertex_0.dim) * step_length
    vertice = [vertex_0]
    for ii in range(vertex_0.dim):
        vertice.append(vertex_0 + emat[:, ii])
    return vertice


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def f(v, X, y):
    assert v.ptrs.size == 4, 'dimension mismatch!'
    assert v.input_dim == X.shape[1], 'dimension mismatch!'
    w1 = np.reshape(v.v[:v.ptrs[0]], (v.hidden_dim, v.input_dim))
    b1 = np.reshape(v.v[v.ptrs[0]:v.ptrs[1]], v.hidden_dim)
    w2 = np.reshape(v.v[v.ptrs[1]:v.ptrs[2]], (v.output_dim, v.hidden_dim))
    b2 = np.reshape(v.v[v.ptrs[2]:], v.output_dim)
    loss = 0.0
    for ii in range(X.shape[0]):
        loss += (
            np.dot(w2, sigmoid(np.dot(w1, X[ii, :]) + b1)) + b2 - y[ii])**2
    return loss[0]

def pred(v, X):
    assert v.ptrs.size == 4, 'dimension mismatch!'
    assert v.input_dim == X.shape[1], 'dimension mismatch!'
    w1 = np.reshape(v.v[:v.ptrs[0]], (v.hidden_dim, v.input_dim))
    b1 = np.reshape(v.v[v.ptrs[0]:v.ptrs[1]], v.hidden_dim)
    w2 = np.reshape(v.v[v.ptrs[1]:v.ptrs[2]], (v.output_dim, v.hidden_dim))
    b2 = np.reshape(v.v[v.ptrs[2]:], v.output_dim)
    y_pred = []
    for ii in range(X.shape[0]):
        y_pred.append(np.dot(w2, sigmoid(np.dot(w1, X[ii, :]) + b1)) + b2)

    return np.array(y_pred)


def line(t, v1, v2):
    return (1 - t) * v1 + t * v2


def simplex(f, X, y_real, vertice, maxit=1000, tol=1e-7, step_length=100):

    vertice_max_list = []
    vertice_min_list = []
    for jj in range(maxit):
        y = []
        # evaluate the function value
        for ii in vertice:
            y.append(f(ii, X, y_real))
        y = np.array(y)
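        idx = np.argsort(y)  # ascending order: idx[0] best, idx[-1] worst
        vertice_max_list.append(vertice[idx[-1]])
        vertice_min_list.append(vertice[idx[0]])
        # centroid of all vertices except the worst one
        v_mean = sum((vertice[ii] for ii in idx[1:-1]),
                     vertice[idx[0]]) / (len(vertice) - 1)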
        v_ref = line(-1, v_mean, vertice[idx[-1]])
        y_ref = f(v_ref, X, y_real)
        if y_ref >= y[idx[0]] and y_ref < y[idx[-2]]:
            # y_0<=y_ref<y_n, reflection (replace v_n+1 with v_ref)
            vertice[idx[-1]] = v_ref
            # print('reflection1')
        elif y_ref < y[idx[0]]:
            # y_ref<y_0, expand
            v_ref_e = line(-2, v_mean, vertice[idx[-1]])
            y_ref_e = f(v_ref_e, X, y_real)
            if y_ref_e < y_ref:
                vertice[idx[-1]] = v_ref_e
                # print('expand')
            else:
                vertice[idx[-1]] = v_ref
                # print('reflection2')
        elif y_ref >= y[idx[-2]]:
            if y_ref < y[idx[-1]]:
                # y_n<=y_ref<y_{n+1}, outside contraction
                v_ref_c = line(-0.5, v_mean, vertice[idx[-1]])
                accepted = f(v_ref_c, X, y_real) < y_ref
            else:
                # y_ref>=y_{n+1}, inside contraction
                v_ref_c = line(0.5, v_mean, vertice[idx[-1]])
                accepted = f(v_ref_c, X, y_real) < y[idx[-1]]
            if accepted:
                vertice[idx[-1]] = v_ref_c
                # print('contraction')
            else:
                # shrinkage: move every vertex except the best halfway toward it
                v_best = vertice[idx[0]]
                for ii in idx[1:]:
                    vertice[ii] = 0.5 * (v_best + vertice[ii])
                # print('shrinkage')

        rtol = 2.0 * abs(y[idx[0]] - y[idx[-1]]) / (
            abs(y[idx[0]]) + abs(y[idx[-1]]) + 1e-9)
        if rtol <= tol:
            vertice = vertice_init(vertice[idx[0]], step_length)

    return vertice_max_list, vertice_min_list

# define the 3 layer NN
input_dim = 3
hidden_dim = 2
output_dim = 1
total_dim = (input_dim + 1) * hidden_dim + (hidden_dim + 1) * output_dim

# simplex initialize
v = opt_variables(input_dim, hidden_dim, output_dim, np.random.rand(total_dim))
step_length = 3
vertice = vertice_init(v, step_length)  # the choice of step length is crucial

# training data
X = np.random.rand(100, 3)
y_real = X.sum(axis=1)

# model training
vertice_max_list, vertice_min_list = simplex(
    f, X, y_real, vertice, maxit=200, tol=1e-3, step_length=step_length)


# plot
f_max_list = []
f_min_list = []
for ii, jj in zip(vertice_max_list, vertice_min_list):
    f_max_list.append(f(ii, X, y_real))
    f_min_list.append(f(jj, X, y_real))

plt.plot(f_max_list, 'r', linewidth=2, label='max')
plt.plot(f_min_list, 'b', linewidth=2, label='min')
plt.legend(fontsize=15)
plt.show()

# prediction

X_test = np.random.rand(100, 3)
y_real_test = X_test.sum(axis=1)


y_pred = pred(vertice_min_list[-1], X_test)

plt.plot(y_real_test, 'r', linewidth=2, label='real')
plt.plot(y_pred, 'b', linewidth=2, label='pred')
plt.legend(fontsize=15)
plt.show()

 

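The following figures show the model's predictions after 20, 100 and 200 iterations:

[Figure: predictions after 20 iterations]

[Figure: predictions after 100 iterations]

[Figure: predictions after 200 iterations]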

 

