用java寫bp神經網絡(三)


孔子曰,吾日三省吾身。我們如果跟程序打交道,除了一日三省吾身外,還要三日一省吾代碼。看代碼是否可以更簡潔,更易懂,更容易擴展,更通用,算法是否可以再優化,結構是否可以再往上抽象。代碼在不斷的重構過程中,更臻化境。佝僂者承蜩如是,大匠鑄劍亦復如是,藝雖小,其道一也。所謂苟日新,再日新,日日新。

本次對前兩篇文章代碼進行重構,主要重構函數接口體系,和權重矩陣的封裝。

簡單函數

所說函數,是數學概念上的函數。數學上的函數,一般有一自變量$x$(輸入)和對應的值$y=f(x)$(輸出)。其中$x$可以是個數字,一個向量,一個矩陣等等。我們用泛型定義如下:

public interface Function<I,O> {
  O valueAt(I x);
}

 I代表輸入類型,O代表輸出類型。

有的函數是可微的,比如神經網絡的激活函數。可微函數除了是一個函數,還可求出給定$x$處的導數,或者梯度。而且梯度類型與自變量類型一致。用泛型定義如下:

public interface DifferentiableFunction<I,O> extends Function<I,O> {
  I derivativeAt(I x);
}

同時,考慮到某些函數,在求得值和導數時,共同用到了一些中間變量,或者后一個可以用到前一個的結果,我們定義了PreCaculate接口。當我們判定一個函數實現了PreCaculate接口時,我們首先調用它的PreCaculate接口,讓它預先計算出一些有用的中間變量,然后再調用其valueAt和derivativeAt求得其具體的值,這樣可以節省一些操作步驟。定義如下:

public interface PreCaculate<I> {
	void preCaculate(I x);
}

 基於上面的定義,我們定義神經網絡的激活函數的類型為:

public interface ActivationFunction extends DifferentiableFunction<DoubleMatrix, DoubleMatrix>

 即我們激活函數是一個可微函數,輸入為一個矩陣(netResult),輸出為一個矩陣(finalResult)。

帶參函數

有些函數,除了自變量外,還有一些其它的系數,或者參數,我們稱為超參數。比如誤差函數,目標值為參數,輸出值為自變量。這類函數接口定義如下:

public interface ParamFunction<I,O,P> {
	O valueAt(I x,P param);
}

 類似的,定義其微分接口如下:

public interface DifferentiableParamFunction<I, O, P> extends ParamFunction<I, O, P> {
	I derivativeAt(I x,P param);
}

 我們的誤差函數定義如下:

public interface CostFunction extends DifferentiableParamFunction<DoubleMatrix,DoubleMatrix,DoubleMatrix>

 輸入,輸出,參數都為矩陣。

組合矩陣

在神經網絡的概念中,每兩層之間有一個權重矩陣,偏置矩陣,如果輸入字向量也要調整,那么還有一個字典矩陣。這些所有的矩陣隨着迭代過程不斷更新,以期使誤差函數達到最小。從廣義上來講,訓練樣本就是超參數,這些所有的矩陣為自變量,誤差函數就是優化函數。那么實質上,在調整權重矩陣時,自變量即這一系列的矩陣可以展開拉長拼接成一個超長的向量而已,其內部的結構已無關緊要。在jare的源碼中,是把這些權重矩陣的值存儲在一個長的double[]中,計算完畢后,再從這個doulbe[]中還原出各矩陣的結構。在這里,我們定義了一個類CompactDoubleMatrix名為超矩陣來從更高一層封裝這些矩陣變量,使其對外表現出好像就是一個矩陣。

這個CompactDoubleMatrix的實現方式為,在內部維護一個DoubleMatrix的有序列表List<DoubleMatrix>,然后再執行加減乘除操作時,會批量的對列表中的所有矩陣執行。這樣的封裝,我們隨后會發現將簡化了我們大量代碼。先把完整定義放上來。

public class CompactDoubleMatrix {
	List<DoubleMatrix> mats = new ArrayList<DoubleMatrix>();

	@SafeVarargs
	public CompactDoubleMatrix(List<DoubleMatrix>... matListArray) {
		super();
		this.append(matListArray);
	}

	public CompactDoubleMatrix(DoubleMatrix... matArray) {
		super();
		this.append(matArray);
	}

	public CompactDoubleMatrix() {
		super();
	}

	public CompactDoubleMatrix addi(CompactDoubleMatrix other) {
		this.assertSize(other);
		for (int i = 0; i < this.length(); i++)
			this.get(i).addi(other.get(i));
		return this;
	}

	public void subi(CompactDoubleMatrix other) {
		this.assertSize(other);
		for (int i = 0; i < this.length(); i++)
			this.get(i).subi(other.get(i));
	}

	public CompactDoubleMatrix add(CompactDoubleMatrix other) {
		this.assertSize(other);
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).add(other.get(i)));
		}
		return result;
	}

	public CompactDoubleMatrix sub(CompactDoubleMatrix other) {
		this.assertSize(other);
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).sub(other.get(i)));
		}
		return result;
	}

	public CompactDoubleMatrix mul(CompactDoubleMatrix other) {
		this.assertSize(other);
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).mul(other.get(i)));
		}
		return result;
	}

	public CompactDoubleMatrix muli(double d) {

		for (int i = 0; i < this.length(); i++) {
			this.get(i).muli(d);
		}
		return this;
	}

	public CompactDoubleMatrix mul(double d) {
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).mul(d));
		}
		return result;
	}

	public CompactDoubleMatrix dup() {
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).dup());
		}
		return result;
	}

	public double dot(CompactDoubleMatrix other) {
		double sum = 0;
		for (int i = 0; i < this.length(); i++) {
			sum += this.get(i).dot(other.get(i));
		}
		return sum;
	}

	public double norm() {
		double sum = 0;
		for (int i = 0; i < this.length(); i++) {
			double subNorm = this.get(i).norm2();
			sum += subNorm * subNorm;
		}
		return Math.sqrt(sum);
	}

	public void assertSize(CompactDoubleMatrix other) {
		assert (other != null && this.length() == other.length());
		for (int i = 0; i < this.length(); i++) {
			assert (this.get(i).sameSize(other.get(i)));
		}
	}

	@SuppressWarnings("unchecked")
	public void append(List<DoubleMatrix>... matListArray) {
		for (List<DoubleMatrix> list : matListArray) {
			this.mats.addAll(list);
		}
	}

	public void append(DoubleMatrix... matArray) {
		for (DoubleMatrix mat : matArray)
			this.mats.add(mat);
	}

	public int length() {
		return mats.size();
	}

	public DoubleMatrix get(int index) {
		return this.mats.get(index);
	}

	public DoubleMatrix getLast() {
		return this.mats.get(this.length() - 1);
	}
}

 以上介紹了對各抽象概念的封裝,下章介紹使用這些封裝如何簡化我們的代碼。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM