Besides Gaussian elimination, the linear system Ax = b can also be solved by iterative methods; here we discuss the conjugate gradient method.
This only covers the case where A is symmetric (A = A^T), positive definite (x^T A x > 0 for every nonzero vector x), and real-valued. Under these conditions, both gradient descent and the conjugate gradient method can be used to solve the linear system Ax = b.
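As a quick numerical sanity check, both assumptions can be verified with NumPy. The matrix below is a made-up example: M^T M is always symmetric positive semidefinite, and adding the identity makes it strictly positive definite.

```python
import numpy as np

# Made-up example matrix: M^T M is symmetric PSD; adding I makes it strictly PD
rng = np.random.default_rng(0)
M = rng.random((3, 3))
A = M.T @ M + np.eye(3)

print(np.allclose(A, A.T))                 # True: A is symmetric
print(np.all(np.linalg.eigvalsh(A) > 0))   # True: all eigenvalues positive
```

For a symmetric matrix, positivity of all eigenvalues is equivalent to x^T A x > 0 for every nonzero x, so this eigenvalue check is a practical positive-definiteness test.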
Vectors u and v are said to be conjugate (with respect to A) if they satisfy: u^T A v = 0.
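A short sketch of this definition (the matrix and vectors here are assumed for illustration, not from the text): given an SPD matrix A and a vector u, a vector A-conjugate to u can be produced from any second vector by one A-weighted Gram-Schmidt step.

```python
import numpy as np

# Assumed example values (for illustration only)
A = np.array([[2.0, 1.0],
              [1.0, 3.0]])          # symmetric positive definite
u = np.array([1.0, 0.0])
w = np.array([0.0, 1.0])

# Remove from w its A-projection onto u: v = w - (u^T A w / u^T A u) * u
v = w - (u @ A @ w) / (u @ A @ u) * u

print(u @ A @ v)                    # 0.0: u and v are conjugate w.r.t. A
```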
In the figure below, each pair of vectors is 'conjugate' with respect to the matrix at the gradient location shown:
Applying a change of variables to the gradient shows that 'conjugacy' is really a kind of orthogonality:
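One way to make this concrete: since A is SPD it has a symmetric square root A^{1/2}, and u^T A v = (A^{1/2} u)^T (A^{1/2} v), so A-conjugacy of u and v is ordinary orthogonality of the transformed vectors A^{1/2} u and A^{1/2} v. A sketch with assumed values:

```python
import numpy as np

# Assumed example: u and v are A-conjugate (u^T A v = 0)
A = np.array([[2.0, 1.0],
              [1.0, 3.0]])
u = np.array([1.0, 0.0])
v = np.array([-0.5, 1.0])

# Symmetric square root of A from its eigendecomposition A = Q diag(lam) Q^T
lam, Q = np.linalg.eigh(A)
A_half = Q @ np.diag(np.sqrt(lam)) @ Q.T

print(u @ A @ v)                      # ~0: conjugate with respect to A
print((A_half @ u) @ (A_half @ v))    # ~0: orthogonal after the transform
```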
=============================================
Solving with the conjugate gradient method:
Algorithm steps (from Wikipedia). To solve Ax = b starting from an initial guess x_0:
  r_0 = b - A x_0,  p_0 = r_0
  repeat for k = 0, 1, 2, ...:
    alpha_k = (r_k^T r_k) / (p_k^T A p_k)
    x_{k+1} = x_k + alpha_k * p_k
    r_{k+1} = r_k - alpha_k * A p_k
    stop if r_{k+1} is sufficiently small
    beta_k = (r_{k+1}^T r_{k+1}) / (r_k^T r_k)
    p_{k+1} = r_{k+1} + beta_k * p_k
  return x_{k+1}
---------------------------------------------
Python code (source: OpenAI Baselines, https://github.com/openai/baselines, a reinforcement learning library):
import numpy as np


def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
    """Conjugate gradient descent (Demmel p. 312)."""
    p = b.copy()
    r = b.copy()
    x = np.zeros_like(b)
    rdotr = r.dot(r)

    fmtstr = "%10i %10.3g %10.3g"
    titlestr = "%10s %10s %10s"
    if verbose:
        print(titlestr % ("iter", "residual norm", "soln norm"))

    for i in range(cg_iters):
        if callback is not None:
            callback(x)
        if verbose:
            print(fmtstr % (i, rdotr, np.linalg.norm(x)))
        z = f_Ax(p)
        v = rdotr / p.dot(z)
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr
        p = r + mu * p
        rdotr = newrdotr
        if rdotr < residual_tol:
            break

    if callback is not None:
        callback(x)
    if verbose:
        print(fmtstr % (i + 1, rdotr, np.linalg.norm(x)))  # pylint: disable=W0631
    return x
Test code:
import numpy as np
from gg import cg  # import the conjugate gradient function cg

# A = np.array([[1.0, 0.0, 0.0],
#               [0.0, 1.0, 0.0],
#               [0.0, 0.0, 1.0]])
A = np.random.rand(3, 3)  # ensures all leading principal minors are positive
A = np.dot(A.T, A)        # makes A a symmetric matrix


def f_Ax(p):
    """f_Ax: takes a column vector p and returns the matrix-vector product A p."""
    return np.dot(A, p)


x = np.random.rand(3)
b = np.dot(A, x)

print("matrix: \n", A)
print("x: \n", x)
print("b: \n", b)
print("...........................")
print("Iteration log:")
result = cg(f_Ax, b, verbose=True)
print("Eigenvalues of matrix A:")
print(np.linalg.eig(A)[0])
print("True x:")
print(x)
print("Computed x:")
print(result)
Output:
matrix: 
 [[1.33507088 0.69389736 0.579944  ]
 [0.69389736 0.76303172 0.47845562]
 [0.579944   0.47845562 0.41679907]]
x: 
 [0.40139385 0.12481318 0.38628268]
b: 
 [0.84651911 0.55858167 0.45350579]
...........................
Iteration log:
      iter residual norm  soln norm
         0       1.23          0
         1   0.000553      0.523
         2   0.000169      0.535
         3   4.11e-28      0.571
Eigenvalues of matrix A:
[2.12734118 0.31861571 0.06894478]
True x:
[0.40139385 0.12481318 0.38628268]
Computed x:
[0.40139385 0.12481318 0.38628268]
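Since the test script builds b from a known x, the result can also be cross-checked against NumPy's direct solver. This check is our addition (with a fresh random system, not the run above) and is independent of the cg code:

```python
import numpy as np

# Build a system the same way the test script does (seeded for repeatability);
# adding the identity keeps A well-conditioned.
rng = np.random.default_rng(1)
M = rng.random((3, 3))
A = M.T @ M + np.eye(3)   # symmetric positive definite
x_true = rng.random(3)
b = A @ x_true

x_direct = np.linalg.solve(A, b)
print(np.allclose(x_direct, x_true))  # True
```

On an SPD 3x3 system, exact-arithmetic CG terminates in at most 3 iterations, which is consistent with the residual dropping to ~4e-28 at iteration 3 in the log above.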
=============================================
References:
Figure source:
------------------------------------------------------------------------------