在深度學習領域,Keras是一個高度封裝的庫並被廣泛應用,可以通過調用其內置網絡模塊(各種網絡層)實現針對性的模型結構;當所需要的網絡層功能不被包含時,則需要通過自定義網絡層或模型實現。
如何在keras框架下自定義層,基本“套路”如下。
一般地,keras中的網絡層是一個類,所以自定義層即編寫一個類,更為重要的是這個類(即自定義層)需要繼承Layer父類,而且需要實現以下四種方法:
- __init __ (self, output_dim, **kwargs)
這個方法是用來初始化並自定義自定義層所需的屬性,比如output_dim;
此外,該方法需要執行super().__init __(**kwargs),這行代碼是執行Layer類中的初始化函數;
當執行上述代碼就沒有必要去管input_shape,weights,trainable等關鍵字參數,因為父類(Layer)的初始化函數實現了它們與layer實例的綁定。
- build(self, input_shape)
這個方法是用來創建層的權重;
在該方法中,根據之前的繼承,通過Layer類的add_weight方法來自定義並添加一個權重矩陣,這個方法需要input_shape參數;
該方法必須設self.built = True,目的是為了保證這個層的權重定義函數build被執行過了;
在built函數中,需要說明這個權重各方面的屬性,比如shape、初始化方式以及可訓練性等信息。
- call(self, x)
這個方法是用來編寫層的功能邏輯;
在該方法中,需要關注傳入call的第一個參數:輸入張量x;x只能是一種形式變量,不能是具體的變量,即它不能被定義;
這個call函數就是該層的計算邏輯,當創建好這個層實例后,該實例可以執行call函數;
可見,這個層的核心應該是一段符號式的輸入張量到輸出張量的計算過程。
- compute_output_shape(self, input_shape)
這個方法是用來保證輸出shape是正確的;
這里重寫compute_output_shape方法去覆蓋父類中的同名方法,來保證輸出的shape符合實際;
父類Layer中的compute_output_shape方法直接返回的是input_shape這明顯是不對的,所以需要重寫該方法。
示例
結合官方文檔的例子,給出如下一個自定義層的代碼:
使用自定義層,就如同使用keras內置網絡層一樣,如下圖所示:(另外,本例使用kears內置的激活函數層ReLU承接自定義層的輸出,從而避免將激活函數的功能加入到自定義層中)