The downhill simplex method is a widely used derivative-free optimization algorithm. In general it is not very efficient, but as reference [1] notes, "the downhill simplex method may frequently be the *best* method to use if the figure of merit is 'get something working quickly' for a problem whose computational burden is small."
The basic idea is to construct a non-degenerate initial simplex in \(N\)-dimensional space and then move it step by step toward the extremum through a sequence of geometric operations such as reflection, expansion, and contraction. Because these operations are essentially all designed to move the simplex toward a minimum, the method is called the downhill simplex method.
Suppose the function to be minimized is \(f(\mathbf{x})\), and let the \(N+1\) vertices of a simplex \(Z\) in \(N\)-dimensional space, sorted in ascending order of function value, be \(\mathbf{x}_{0},\mathbf{x}_{1},\cdots,\mathbf{x}_{N}\). Define
\(\bar{\mathbf{x}}=\frac{1}{N}\sum_{i=0}^{N-1}\mathbf{x}_{i}\)
as the centroid of all vertices of \(Z\) except the worst vertex \(\mathbf{x}_{N}\).
The line through \(\bar{\mathbf{x}}\) and \(\mathbf{x}_{N}\) can be parameterized as:
\(\bar{\mathbf{x}}(t)=(1-t)\bar{\mathbf{x}}+t\mathbf{x}_{N}\)
At each step, the downhill simplex method tries a few special step sizes along the line \(\bar{\mathbf{x}}(t)\), looking for a replacement for \(\mathbf{x}_{N}\) whose function value is smaller than \(f(\mathbf{x}_{N})\); if no such replacement is found, all vertices except \(\mathbf{x}_{0}\) are pulled toward \(\mathbf{x}_{0}\).
Assuming we are minimizing, a point with a smaller function value is considered better and one with a larger value worse. One iteration of the downhill simplex method then proceeds as follows:
- First compute the reflection of the worst point about the centroid \(\bar{\mathbf{x}}\) along the line \(\bar{\mathbf{x}}(t)\), i.e. the candidate at \(t=-1\) (the candidate points for all steps are written out explicitly after this list);
- If the reflected point lies between the best and the second-worst point, accept it (reflection);
- If the reflected point is better than the best point, make a bolder move in the same direction by setting \(t=-2\); if this new trial point is better than the reflected point, accept it (expansion), otherwise accept the reflected point (reflection);
- If the reflected point lies between the second-worst and the worst point, make a more cautious move in the same direction by setting \(t=-0.5\); if the new trial point is better than the reflected point, accept it (outside contraction);
- If the reflected point is even worse than the worst point, try the opposite direction by setting \(t=0.5\); if the new trial point is better than the reflected point, accept it (inside contraction);
- If the previous two steps fail, i.e. the reflected point is worse than the second-worst point and the attempted outside or inside contraction does not improve on it, shrink all points except the best point toward the best point (shrink).
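In terms of the line parameterization above, the candidate points used in these steps are \(\bar{\mathbf{x}}(-1)=2\bar{\mathbf{x}}-\mathbf{x}_{N}\) (reflection), \(\bar{\mathbf{x}}(-2)=3\bar{\mathbf{x}}-2\mathbf{x}_{N}\) (expansion), \(\bar{\mathbf{x}}(-0.5)=1.5\,\bar{\mathbf{x}}-0.5\,\mathbf{x}_{N}\) (outside contraction), and \(\bar{\mathbf{x}}(0.5)=0.5\,\bar{\mathbf{x}}+0.5\,\mathbf{x}_{N}\) (inside contraction); these are exactly the values of \(t\) passed to the `line` helper in the implementation below.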
The process above is clearer when drawn as an interval diagram whose three dividing points are the best, second-worst, and worst points. Pseudocode can be found in Section 9.5 of reference [3].
The iteration needs an initial simplex, which can be constructed using the method given in reference [1]: first pick an arbitrary point \(\mathbf{x}_{0}\), then apply
\(\mathbf{x}_{i}=\mathbf{x}_{0}+\lambda\mathbf{e}_{i},\qquad i=1,\cdots,N\)
where \(\mathbf{e}_{i}\) is the \(i\)-th unit vector of \(N\)-dimensional space and \(\lambda\) is the step length.
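For example, in two dimensions with \(\mathbf{x}_{0}=(0,0)\) and \(\lambda=1\), this gives \(\mathbf{x}_{1}=(1,0)\) and \(\mathbf{x}_{2}=(0,1)\), i.e. a non-degenerate triangle as the initial simplex.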
In practice there is another important trick: restarting. During the iteration the simplex easily gets stuck at some intermediate position, where its *best* and *worst* points are almost identical and its volume has shrunk so much that progress becomes very slow. One remedy is to choose the size of the initial simplex sensibly; a more effective one is to re-initialize the simplex whenever it gets stuck, which speeds up convergence. When applying the initialization formula for a restart, keep the current *best* point as \(\mathbf{x}_{0}\), so that restarting does not throw away the progress already made.
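A minimal sketch of such a stagnation test (the helper name `need_restart` is my own; the relative-difference formula is the same one used inside the `simplex` routine below):

```python
def need_restart(y_best, y_worst, tol=1e-3, eps=1e-9):
    """Return True when the best and worst function values are nearly equal,
    i.e. the simplex has collapsed and a restart is warranted."""
    rtol = 2.0 * abs(y_best - y_worst) / (abs(y_best) + abs(y_worst) + eps)
    return rtol <= tol

# if need_restart(y_best, y_worst): rebuild the simplex around the best vertex,
# e.g. vertice = vertice_init(best_vertex, step_length)
```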
Below is a Python implementation of the simplex method for a simple problem, where the function to be optimized is \(f(\mathbf{x})=\frac{1}{2}(\mathbf{x}-\mathbf{x}_{0})^{T}(\mathbf{x}-\mathbf{x}_{1})\) (in the code \(\mathbf{x}_{0}=5\cdot\mathbf{1}\) and \(\mathbf{x}_{1}=3\cdot\mathbf{1}\)).
First define some helper functions:
```python
import numpy as np
from matplotlib import pyplot as plt


def vertice_init(vertex_0, step_length):
    '''
    Initialize the vertices of the simplex using the formula
    $x_i = x_0 + step_length * e_i$
    '''
    emat = np.eye(vertex_0.size) * step_length
    vertice = [vertex_0]
    for ii in range(vertex_0.size):
        vertice.append(vertex_0 + emat[:, ii])
    return vertice


def f(v):
    '''
    Evaluate the objective function $f$
    '''
    dim = v.size
    v0 = np.ones(dim) * 5
    v1 = np.ones(dim) * 3
    return 0.5 * np.dot(v - v0, v - v1)


def line(t, v1, v2):
    return (1 - t) * v1 + t * v2
```
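A quick, illustrative sanity check of these helpers (not part of the original example):

```python
v_test = np.ones(3) * 4.0
print(f(v_test))                       # -> -1.5, i.e. -0.5 per dimension
print(len(vertice_init(v_test, 1.0)))  # -> 4, i.e. N + 1 vertices
```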
Next, define the main routine; note the restart logic inside:
```python
def simplex(f, vertice, maxit=1000, step_length=100, tol=1e-3):
    vertice_max_list = []  # store the worst vertex of each iteration
    vertice_min_list = []  # store the best vertex of each iteration
    for jj in range(maxit):
        # evaluate the function at every vertex
        y = []
        for ii in vertice:
            y.append(f(ii))
        y = np.array(y)
        # only the highest (worst), next-highest, and lowest (best) vertices
        # are needed
        idx = np.argsort(y)
        vertice_max_list.append(vertice[idx[-1]])
        vertice_min_list.append(vertice[idx[0]])

        # centroid of the vertices
        # NOTE: strictly the worst vertex should be excluded, but for
        # simplicity it is kept here
        v_mean = np.mean(vertice, axis=0)

        # compute the reflected candidate vertex and its function value
        v_ref = line(-1, v_mean, vertice[idx[-1]])
        y_ref = f(v_ref)
        if y_ref >= y[idx[0]] and y_ref < y[idx[-2]]:
            # y_0 <= y_ref < y_n: reflection (replace the worst vertex)
            vertice[idx[-1]] = v_ref
            # print('reflection1')
        elif y_ref < y[idx[0]]:
            # y_ref < y_0: try to expand
            v_ref_e = line(-2, v_mean, vertice[idx[-1]])
            y_ref_e = f(v_ref_e)
            if y_ref_e < y_ref:
                vertice[idx[-1]] = v_ref_e
                # print('expand')
            else:
                vertice[idx[-1]] = v_ref
                # print('reflection2')
        elif y_ref >= y[idx[-2]]:
            if y_ref < y[idx[-1]]:
                # y_n <= y_ref < y_{n+1}: outside contraction
                v_ref_c = line(-0.5, v_mean, vertice[idx[-1]])
                y_ref_c = f(v_ref_c)
                if y_ref_c < y_ref:
                    vertice[idx[-1]] = v_ref_c
                    # print('outside contraction')
                    continue
            else:
                # y_ref >= y_{n+1}: inside contraction
                v_ref_c = line(0.5, v_mean, vertice[idx[-1]])
                y_ref_c = f(v_ref_c)
                if y_ref_c < y_ref:
                    vertice[idx[-1]] = v_ref_c
                    # print('inside contraction')
                    continue
            # the contraction failed: shrink the remaining vertices, then
            # rebuild the simplex around the current best vertex (restart)
            for ii in range(1, len(vertice)):
                vertice[ii] = 0.5 * (vertice[0] + vertice[ii])
            print('shrinkage')
            vertice = vertice_init(vertice[idx[0]], step_length)  # restart

        # restarting is very important: the simplex can easily get stuck at a
        # non-optimal point; when the best and worst values are nearly equal,
        # rebuild the simplex around the current best vertex
        rtol = 2.0 * abs(y[idx[0]] - y[idx[-1]]) / (
            abs(y[idx[0]]) + abs(y[idx[-1]]) + 1e-9)
        if rtol <= tol:
            vertice = vertice_init(vertice[idx[0]], step_length)
    return vertice_max_list, vertice_min_list
```
For the test, the dimension of the unknown parameter vector is set to 15. From the definition of \(f(\mathbf{x})\), each component contributes \(\frac{1}{2}(x_{i}-5)(x_{i}-3)\), which attains its minimum \(-0.5\) at \(x_{i}=4\), so the global minimum is \(-0.5\times 15=-7.5\).
```python
dim = 15
step_length = 5
v = np.random.randn(dim)
vertice = vertice_init(v, step_length)  # the choice of step length is crucial
vertice_max_list, vertice_min_list = simplex(
    f, vertice, maxit=2000, step_length=step_length, tol=1e-5)
print('min value is %s' % f(vertice_min_list[-1]))
```
To visualize the optimization, plot the function values at the best and worst vertices of the simplex at each iteration. The figure clearly shows that the algorithm does get stuck above 0; at that point the maximum and minimum values are very close, indicating that the simplex has shrunk considerably. The restart trick helps escape these traps and speeds up convergence: after a few restarts the algorithm quickly converges to the true minimum. Note that a restart does not affect the minimum value of the simplex; it only makes the maximum jump up sharply (the jumps of the red curve in the figure).
```python
f_max_list = []
f_min_list = []
for ii, jj in zip(vertice_max_list, vertice_min_list):
    f_max_list.append(f(ii))
    f_min_list.append(f(jj))
plt.plot(f_max_list, 'r', linewidth=2, label='max')
plt.plot(f_min_list, 'b', linewidth=2, label='min')
plt.legend(fontsize=15)
plt.show()
```
Finally, the simplex method is used to train a three-layer neural network, which is then used for simple predictions. The input, hidden, and output layers have 3, 2, and 1 neurons respectively, and the true relationship is \(y=\sum_{i=1}^{N}x_{i}\).
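Concretely, with hidden weights \(W_{1}\), hidden biases \(\mathbf{b}_{1}\), output weights \(\mathbf{w}_{2}\), and output bias \(b_{2}\), the network computes

\(\hat{y}(\mathbf{x})=\mathbf{w}_{2}^{T}\,\sigma(W_{1}\mathbf{x}+\mathbf{b}_{1})+b_{2},\qquad \sigma(z)=\frac{1}{1+e^{-z}}\)

and the quantity minimized by the simplex method is the sum of squared errors over the training set, \(\sum_{k}\left(\hat{y}(\mathbf{x}_{k})-y_{k}\right)^{2}\), which is what the function `f` in the code below evaluates.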
```python
import numpy as np
from matplotlib import pyplot as plt


class opt_variables:
    '''Flat parameter vector of the network plus helpers to index into it.'''

    def __init__(self, input_dim, hidden_dim, output_dim, v):
        assert (input_dim + 1) * hidden_dim + (
            hidden_dim + 1) * output_dim == v.size, 'dimension mismatch!'
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.v = v
        self.dim = v.size
        # pointers separating w1, b1, w2, b2 inside the flat vector
        w1_ptr = input_dim * hidden_dim
        b1_ptr = hidden_dim + w1_ptr
        w2_ptr = hidden_dim * output_dim + b1_ptr
        b2_ptr = output_dim + w2_ptr
        self.ptrs = np.array([w1_ptr, b1_ptr, w2_ptr, b2_ptr])

    def __add__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            return opt_variables(self.input_dim, self.hidden_dim,
                                 self.output_dim, self.v + other.v)
        return opt_variables(self.input_dim, self.hidden_dim,
                             self.output_dim, self.v + other)

    def __iadd__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            self.v += other.v
        else:
            self.v += other
        return self

    def __sub__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            return opt_variables(self.input_dim, self.hidden_dim,
                                 self.output_dim, self.v - other.v)
        return opt_variables(self.input_dim, self.hidden_dim,
                             self.output_dim, self.v - other)

    def __isub__(self, other):
        if type(other) is type(self):
            assert self.dim == other.dim, 'dimension mismatch!'
            self.v -= other.v
        else:
            self.v -= other
        return self

    def __mul__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim,
                             self.output_dim, self.v * constant)

    def __rmul__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim,
                             self.output_dim, self.v * constant)

    def __imul__(self, constant):
        self.v *= constant
        return self

    def __truediv__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim,
                             self.output_dim, self.v / constant)

    def __rtruediv__(self, constant):
        return opt_variables(self.input_dim, self.hidden_dim,
                             self.output_dim, constant / self.v)

    def __itruediv__(self, constant):
        self.v /= constant
        return self

    def __str__(self):
        return '%s' % self.v


def vertice_init(vertex_0, step_length):
    emat = np.eye(vertex_0.dim) * step_length
    vertice = [vertex_0]
    for ii in range(vertex_0.dim):
        vertice.append(vertex_0 + emat[:, ii])
    return vertice


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def f(v, X, y):
    '''Sum of squared errors of the network over the data set (X, y).'''
    assert v.ptrs.size == 4, 'dimension mismatch!'
    assert v.input_dim == X.shape[1], 'dimension mismatch!'
    w1 = np.reshape(v.v[:v.ptrs[0]], (v.hidden_dim, v.input_dim))
    b1 = np.reshape(v.v[v.ptrs[0]:v.ptrs[1]], v.hidden_dim)
    w2 = np.reshape(v.v[v.ptrs[1]:v.ptrs[2]], (v.output_dim, v.hidden_dim))
    b2 = np.reshape(v.v[v.ptrs[2]:], v.output_dim)
    loss = 0.0
    for ii in range(X.shape[0]):
        loss += (
            np.dot(w2, sigmoid(np.dot(w1, X[ii, :]) + b1)) + b2 - y[ii])**2
    return loss[0]


def pred(v, X):
    '''Forward pass of the network for every row of X.'''
    assert v.ptrs.size == 4, 'dimension mismatch!'
    assert v.input_dim == X.shape[1], 'dimension mismatch!'
    w1 = np.reshape(v.v[:v.ptrs[0]], (v.hidden_dim, v.input_dim))
    b1 = np.reshape(v.v[v.ptrs[0]:v.ptrs[1]], v.hidden_dim)
    w2 = np.reshape(v.v[v.ptrs[1]:v.ptrs[2]], (v.output_dim, v.hidden_dim))
    b2 = np.reshape(v.v[v.ptrs[2]:], v.output_dim)
    y_pred = []
    for ii in range(X.shape[0]):
        y_pred.append(np.dot(w2, sigmoid(np.dot(w1, X[ii, :]) + b1)) + b2)
    return np.array(y_pred)


def line(t, v1, v2):
    return (1 - t) * v1 + t * v2


def simplex(f, X, y_real, vertice, maxit=1000, tol=1e-7, step_length=100):
    vertice_max_list = []
    vertice_min_list = []
    for jj in range(maxit):
        # evaluate the function value at every vertex
        y = []
        for ii in vertice:
            y.append(f(ii, X, y_real))
        y = np.array(y)
        idx = np.argsort(y)  # ascending order
        vertice_max_list.append(vertice[idx[-1]])
        vertice_min_list.append(vertice[idx[0]])

        v_mean = np.mean(vertice)
        v_ref = line(-1, v_mean, vertice[idx[-1]])
        y_ref = f(v_ref, X, y_real)
        if y_ref >= y[idx[0]] and y_ref < y[idx[-2]]:
            # y_0 <= y_ref < y_n: reflection (replace the worst vertex)
            vertice[idx[-1]] = v_ref
            # print('reflection1')
        elif y_ref < y[idx[0]]:
            # y_ref < y_0: try to expand
            v_ref_e = line(-2, v_mean, vertice[idx[-1]])
            y_ref_e = f(v_ref_e, X, y_real)
            if y_ref_e < y_ref:
                vertice[idx[-1]] = v_ref_e
                # print('expand')
            else:
                vertice[idx[-1]] = v_ref
                # print('reflection2')
        elif y_ref >= y[idx[-2]]:
            if y_ref < y[idx[-1]]:
                # y_n <= y_ref < y_{n+1}: outside contraction
                v_ref_c = line(-0.5, v_mean, vertice[idx[-1]])
                y_ref_c = f(v_ref_c, X, y_real)
                if y_ref_c < y_ref:
                    vertice[idx[-1]] = v_ref_c
                    # print('outside contraction')
                    continue
            else:
                # y_ref >= y_{n+1}: inside contraction
                v_ref_c = line(0.5, v_mean, vertice[idx[-1]])
                y_ref_c = f(v_ref_c, X, y_real)
                if y_ref_c < y_ref:
                    vertice[idx[-1]] = v_ref_c
                    # print('inside contraction')
                    continue
            # the contraction failed: shrink the remaining vertices, then
            # rebuild the simplex around the current best vertex (restart)
            for ii in range(1, len(vertice)):
                vertice[ii] = 0.5 * (vertice[0] + vertice[ii])
            print('shrinkage')
            vertice = vertice_init(vertice[idx[0]], step_length)  # restart

        # restart when the simplex has collapsed (best and worst values
        # nearly equal), keeping the current best vertex
        rtol = 2.0 * abs(y[idx[0]] - y[idx[-1]]) / (
            abs(y[idx[0]]) + abs(y[idx[-1]]) + 1e-9)
        if rtol <= tol:
            vertice = vertice_init(vertice[idx[0]], step_length)
    return vertice_max_list, vertice_min_list


# define the 3-layer NN
input_dim = 3
hidden_dim = 2
output_dim = 1
total_dim = (input_dim + 1) * hidden_dim + (hidden_dim + 1) * output_dim

# initialize the simplex
v = opt_variables(input_dim, hidden_dim, output_dim, np.random.rand(total_dim))
step_length = 3
vertice = vertice_init(v, step_length)  # the choice of step length is crucial

# training data
X = np.random.rand(100, 3)
y_real = X.sum(axis=1)

# model training
vertice_max_list, vertice_min_list = simplex(
    f, X, y_real, vertice, maxit=200, tol=1e-3, step_length=step_length)

# plot
f_max_list = []
f_min_list = []
for ii, jj in zip(vertice_max_list, vertice_min_list):
    f_max_list.append(f(ii, X, y_real))
    f_min_list.append(f(jj, X, y_real))
plt.plot(f_max_list, 'r', linewidth=2, label='max')
plt.plot(f_min_list, 'b', linewidth=2, label='min')
plt.legend(fontsize=15)
plt.show()

# prediction
X_test = np.random.rand(100, 3)
y_real_test = X_test.sum(axis=1)
y_pred = pred(vertice_min_list[-1], X_test)
plt.plot(y_real_test, 'r', linewidth=2, label='real')
plt.plot(y_pred, 'b', linewidth=2, label='pred')
plt.legend(fontsize=15)
plt.show()
```
The model's predictions after 20, 100, and 200 iterations are shown below:
20 iterations
100 iterations
200 iterations