本文的目的是解析 ceres-solver AutoDiff 的實現,說明它是一種類似於 matlab 符號運算的方法。
ceres-solver 使用 ceres::CostFunction
作為計算誤差與雅克比的結構。ceres::CostFunction 是一個純虛類,用戶代碼繼承這個類,並通過實現其純虛方法 bool Evaluate(double const* const* parameters, double* residuals, double** jacobians);
提供使用待優化參數塊(parameters)計算誤差(residuals)與雅克比(jacobians) 的方法。對於需要快速驗證想法的用戶,計算雅克比是繁瑣的。
ceres 提供了兩種自動計算雅克比的方法——AutoDiff 與 NumericDiff,用戶可以分別繼承 ceres::AutoDiffCostFunction
、 ceres::NumericDiffCostFunction
以使用這兩種方法。選擇使用這兩種方法之后,用戶代碼僅需告知 ceres 如何使用 parameters 計算 residuals,至於 jacobians 如何計算,ceres 自行尋找方法。
ceres 的 AutoDiff 使用 Dual Number 計算雅克比。所謂 Dual Number 就是將一個實數寫成其自身(為方便將其稱為“大量”)與小量(e)的和,並且定義 \(e^2 = 0\) (在計算一階導數時這么定義)。ceres 實現的 Dual Number 的結構是 ceres::Jet
,Jet 結構中的大量是 T a;
,小量是 Eigen::Matrix<T, N, 1> v;
(此處小量使用一個 Eigen::Vector 表達是介於多元函數對多個變量求導的考慮,后面會解釋)。
文件 jet.h
中有一些注釋解釋 Dual Number 是如何計算導數的。現在摘抄一個注釋中的例子如下。
// To handle derivatives of functions taking multiple arguments, different
// infinitesimals are used, one for each variable to take the derivative of. For
// example, consider a scalar function of two scalar parameters x and y:
//
// f(x, y) = x^2 + x * y
//
// Following the technique above, to compute the derivatives df/dx and df/dy for
// f(1, 3) involves doing two evaluations of f, the first time replacing x with
// x + e, the second time replacing y with y + e.
//
// For df/dx:
//
// f(1 + e, y) = (1 + e)^2 + (1 + e) * 3
// = 1 + 2 * e + 3 + 3 * e
// = 4 + 5 * e
//
// --> df/dx = 5
//
// For df/dy:
//
// f(1, 3 + e) = 1^2 + 1 * (3 + e)
// = 1 + 3 + e
// = 4 + e
//
// --> df/dy = 1
//
求函數 f(x, y) = x^2 + x * y
在 (1, 3) 上現在我用微積分的數學方式計算導數。
以上注釋說明了使用 Dual Number 計算函數 \(f(x,y)=x^2+xy\) 在 \((1, 3)\) 處對 \(x, y\) 的導數的過程。對 x 的偏導,是將 x 用 Dual Number 1 + e 表示,將 y 用實數 3 表示,代入函數式計算,最終得到的 e 的一次項系數就是函數在 (1, 3) 上對 x 的偏導。實際上這是 L'Hospital 法則的計算機實現。現在使用在微積分中學到的方法計算導數。
注釋:(*) 使用一次 L'Hospital 法則,即分子分母分別對 \(\Delta x\) 求一次導數。
分析:在求導數的時候分母一般為 1 次項,即 \(\Delta x\)。使用 L'Hospital 法則,對分母求導,會將 \(\Delta x\) 0 次的項求導消失;而剛好是 \(\Delta x\) 1 次的項求導后是常數;\(\Delta x\) 高於 1 次的項在求導后還會留下 \(\Delta x\),在求極限之后會消失。所以,導數是 1 次項對應的系數,在程序實現中是 e 對應的系數。(但是此處我還沒有考慮 \(\Delta x\) 0 次以下的項,現在搞不定。)同理,求二次導數,就是取 \(\Delta x^2\) 的系數。
緊接着下面的注釋給出了小量為何使用 Eigen::Vector 表示
的解釋。
// To take the gradient of f with the implementation of dual numbers ("jets") in
// this file, it is necessary to create a single jet type which has components
// for the derivative in x and y, and passing them to a templated version of f:
//
// template<typename T>
// T f(const T &x, const T &y) {
// return x * x + x * y;
// }
//
// // The "2" means there should be 2 dual number components.
// // It computes the partial derivative at x=10, y=20.
// Jet<double, 2> x(10, 0); // Pick the 0th dual number for x.
// Jet<double, 2> y(20, 1); // Pick the 1st dual number for y.
// Jet<double, 2> z = f(x, y);
//
// LOG(INFO) << "df/dx = " << z.v[0]
// << "df/dy = " << z.v[1];
//
如果想直接求對所有變量的導數,那么 Dual Number 的 e 的個數就要增加了,有兩個變量,就需要 2 個 e,使用 ceres::Jet<double, 2>
。實驗驗證,可以在使用 AutoDiff 時,於用戶代碼實現的模板函數 operator() 中故意使模板特例化錯誤,檢查 typename 是否特例化為 Jet<double, [N]>,N 是待優化參數的個數(注意,是 parameters 的個數,不是 parameter blocks 的個數)。
Jet 作為 Dual Number 實現求導具體是要實現求導的一般法則與一些基本函數的導數公式。這些相關的內容可以參考 WikiPedia Differentiation rules。現在對一般法則與基本函數的導數分別舉一個在文件 jet.h
中找得到的例子。
“一般法則”舉例乘法法則。在 C++ 中基本運算的 operator 僅有 +, -, *, / 四種,僅需對這四種運算實現對應的 Jet operator 即可。在 Python 中有冪運算符 **
,大概 Python 實現還需要考慮這個吧。
乘法法則在數學中可以表達如下。
在 Jet 中對應的 operator*
如下。
template <typename T, int N>
inline Jet<T, N> operator*(const Jet<T, N>& f, const Jet<T, N>& g) {
return Jet<T, N>(f.a * g.a, f.a * g.v + f.v * g.a);
}
“基本函數的導數”舉例正弦函數。
正弦函數的導數在數學中表達如下。
在 Jet 中對應的函數 sin
如下。
template <typename T, int N>
inline Jet<T, N> sin(const Jet<T, N>& f) {
return Jet<T, N>(sin(f.a), cos(f.a) * f.v);
}
以上兩個例子,代碼中形成的 Jet 的小量,對應於數學公式的導數。
另,需要注意在 ceres 中使用 AutoDiff,模板函數 operator() 計算 residuals 過程中使用到的基礎函數需要從 ceres 中獲得,即不可直接使用 std::sin
函數,應使用 ceres::sin
,以上 sin
函數體內使用到的 sin
, cos
函數是 std::sin
, std::cos
。即 stl 中的模板是無法實例化 ceres::Jet 的。
對於變量負數次冪的處理可以參考代碼 operator/(T s, const Jet<T, N>& g)
,即“Scalar 除以 Jet”。
綜上所述,ceres-solver 使用 ceres::Jet,實現了 AutoDiff。具體的實現,是通過 ceres::Jet 豐富的 operator 與定義的一系列基本函數(的導數)。