孔子曰,吾日三省吾身。我們如果跟程序打交道,除了一日三省吾身外,還要三日一省吾代碼。看代碼是否可以更簡潔,更易懂,更容易擴展,更通用,算法是否可以再優化,結構是否可以再往上抽象。代碼在不斷的重構過程中,更臻化境。佝僂者承蜩如是,大匠鑄劍亦復如是,藝雖小,其道一也。所謂苟日新,再日新,日日新。
本次對前兩篇文章代碼進行重構,主要重構函數接口體系,和權重矩陣的封裝。
簡單函數
所說函數,是數學概念上的函數。數學上的函數,一般有一自變量$x$(輸入)和對應的值$y=f(x)$(輸出)。其中$x$可以是個數字,一個向量,一個矩陣等等。我們用泛型定義如下:
public interface Function<I,O> {
O valueAt(I x);
}
I代表輸入類型,O代表輸出類型。
有的函數是可微的,比如神經網絡的激活函數。可微函數除了是一個函數,還可求出給定$x$處的導數,或者梯度。而且梯度類型與自變量類型一致。用泛型定義如下:
public interface DifferentiableFunction<I,O> extends Function<I,O> {
I derivativeAt(I x);
}
同時,考慮到某些函數,在求得值和導數時,共同用到了一些中間變量,或者后一個可以用到前一個的結果,我們定義了PreCaculate接口。當我們判定一個函數實現了PreCaculate接口時,我們首先調用它的PreCaculate接口,讓它預先計算出一些有用的中間變量,然后再調用其valueAt和derivativeAt求得其具體的值,這樣可以節省一些操作步驟。定義如下:
public interface PreCaculate<I> {
void preCaculate(I x);
}
基於上面的定義,我們定義神經網絡的激活函數的類型為:
public interface ActivationFunction extends DifferentiableFunction<DoubleMatrix, DoubleMatrix>
即我們激活函數是一個可微函數,輸入為一個矩陣(netResult),輸出為一個矩陣(finalResult)。
帶參函數
有些函數,除了自變量外,還有一些其它的系數,或者參數,我們稱為超參數。比如誤差函數,目標值為參數,輸出值為自變量。這類函數接口定義如下:
public interface ParamFunction<I,O,P> {
O valueAt(I x,P param);
}
類似的,定義其微分接口如下:
public interface DifferentiableParamFunction<I, O, P> extends ParamFunction<I, O, P> {
I derivativeAt(I x,P param);
}
我們的誤差函數定義如下:
public interface CostFunction extends DifferentiableParamFunction<DoubleMatrix,DoubleMatrix,DoubleMatrix>
輸入,輸出,參數都為矩陣。
組合矩陣
在神經網絡的概念中,每兩層之間有一個權重矩陣,偏置矩陣,如果輸入字向量也要調整,那么還有一個字典矩陣。這些所有的矩陣隨着迭代過程不斷更新,以期使誤差函數達到最小。從廣義上來講,訓練樣本就是超參數,這些所有的矩陣為自變量,誤差函數就是優化函數。那么實質上,在調整權重矩陣時,自變量即這一系列的矩陣可以展開拉長拼接成一個超長的向量而已,其內部的結構已無關緊要。在jare的源碼中,是把這些權重矩陣的值存儲在一個長的double[]中,計算完畢后,再從這個doulbe[]中還原出各矩陣的結構。在這里,我們定義了一個類CompactDoubleMatrix名為超矩陣來從更高一層封裝這些矩陣變量,使其對外表現出好像就是一個矩陣。
這個CompactDoubleMatrix的實現方式為,在內部維護一個DoubleMatrix的有序列表List<DoubleMatrix>,然后再執行加減乘除操作時,會批量的對列表中的所有矩陣執行。這樣的封裝,我們隨后會發現將簡化了我們大量代碼。先把完整定義放上來。
public class CompactDoubleMatrix {
List<DoubleMatrix> mats = new ArrayList<DoubleMatrix>();
@SafeVarargs
public CompactDoubleMatrix(List<DoubleMatrix>... matListArray) {
super();
this.append(matListArray);
}
public CompactDoubleMatrix(DoubleMatrix... matArray) {
super();
this.append(matArray);
}
public CompactDoubleMatrix() {
super();
}
public CompactDoubleMatrix addi(CompactDoubleMatrix other) {
this.assertSize(other);
for (int i = 0; i < this.length(); i++)
this.get(i).addi(other.get(i));
return this;
}
public void subi(CompactDoubleMatrix other) {
this.assertSize(other);
for (int i = 0; i < this.length(); i++)
this.get(i).subi(other.get(i));
}
public CompactDoubleMatrix add(CompactDoubleMatrix other) {
this.assertSize(other);
CompactDoubleMatrix result = new CompactDoubleMatrix();
for (int i = 0; i < this.length(); i++) {
result.append(this.get(i).add(other.get(i)));
}
return result;
}
public CompactDoubleMatrix sub(CompactDoubleMatrix other) {
this.assertSize(other);
CompactDoubleMatrix result = new CompactDoubleMatrix();
for (int i = 0; i < this.length(); i++) {
result.append(this.get(i).sub(other.get(i)));
}
return result;
}
public CompactDoubleMatrix mul(CompactDoubleMatrix other) {
this.assertSize(other);
CompactDoubleMatrix result = new CompactDoubleMatrix();
for (int i = 0; i < this.length(); i++) {
result.append(this.get(i).mul(other.get(i)));
}
return result;
}
public CompactDoubleMatrix muli(double d) {
for (int i = 0; i < this.length(); i++) {
this.get(i).muli(d);
}
return this;
}
public CompactDoubleMatrix mul(double d) {
CompactDoubleMatrix result = new CompactDoubleMatrix();
for (int i = 0; i < this.length(); i++) {
result.append(this.get(i).mul(d));
}
return result;
}
public CompactDoubleMatrix dup() {
CompactDoubleMatrix result = new CompactDoubleMatrix();
for (int i = 0; i < this.length(); i++) {
result.append(this.get(i).dup());
}
return result;
}
public double dot(CompactDoubleMatrix other) {
double sum = 0;
for (int i = 0; i < this.length(); i++) {
sum += this.get(i).dot(other.get(i));
}
return sum;
}
public double norm() {
double sum = 0;
for (int i = 0; i < this.length(); i++) {
double subNorm = this.get(i).norm2();
sum += subNorm * subNorm;
}
return Math.sqrt(sum);
}
public void assertSize(CompactDoubleMatrix other) {
assert (other != null && this.length() == other.length());
for (int i = 0; i < this.length(); i++) {
assert (this.get(i).sameSize(other.get(i)));
}
}
@SuppressWarnings("unchecked")
public void append(List<DoubleMatrix>... matListArray) {
for (List<DoubleMatrix> list : matListArray) {
this.mats.addAll(list);
}
}
public void append(DoubleMatrix... matArray) {
for (DoubleMatrix mat : matArray)
this.mats.add(mat);
}
public int length() {
return mats.size();
}
public DoubleMatrix get(int index) {
return this.mats.get(index);
}
public DoubleMatrix getLast() {
return this.mats.get(this.length() - 1);
}
}
以上介紹了對各抽象概念的封裝,下章介紹使用這些封裝如何簡化我們的代碼。
