keras自定義網絡層


在深度學習領域,Keras是一個高度封裝的庫並被廣泛應用,可以通過調用其內置網絡模塊(各種網絡層)實現針對性的模型結構;當所需要的網絡層功能不被包含時,則需要通過自定義網絡層或模型實現。

如何在keras框架下自定義層,基本“套路”如下。

一般地,keras中的網絡層是一個類,所以自定義層即編寫一個類,更為重要的是這個類(即自定義層)需要繼承Layer父類,而且需要實現以下四種方法:

  1. __init __ (self, output_dim, **kwargs)

這個方法是用來初始化並自定義自定義層所需的屬性,比如output_dim;
此外,該方法需要執行super().__init __(**kwargs),這行代碼是執行Layer類中的初始化函數;
當執行上述代碼就沒有必要去管input_shape,weights,trainable等關鍵字參數,因為父類(Layer)的初始化函數實現了它們與layer實例的綁定。

  1. build(self, input_shape)

這個方法是用來創建層的權重;
在該方法中,根據之前的繼承,通過Layer類的add_weight方法來自定義並添加一個權重矩陣,這個方法需要input_shape參數;
該方法必須設self.built = True,目的是為了保證這個層的權重定義函數build被執行過了;
在built函數中,需要說明這個權重各方面的屬性,比如shape、初始化方式以及可訓練性等信息。

  1. call(self, x)

這個方法是用來編寫層的功能邏輯;
在該方法中,需要關注傳入call的第一個參數:輸入張量x;x只能是一種形式變量,不能是具體的變量,即它不能被定義;
這個call函數就是該層的計算邏輯,當創建好這個層實例后,該實例可以執行call函數;
可見,這個層的核心應該是一段符號式的輸入張量到輸出張量的計算過程。

  1. compute_output_shape(self, input_shape)

這個方法是用來保證輸出shape是正確的;
這里重寫compute_output_shape方法去覆蓋父類中的同名方法,來保證輸出的shape符合實際;
父類Layer中的compute_output_shape方法直接返回的是input_shape這明顯是不對的,所以需要重寫該方法。

示例

結合官方文檔的例子,給出如下一個自定義層的代碼:

使用自定義層,就如同使用keras內置網絡層一樣,如下圖所示:(另外,本例使用kears內置的激活函數層ReLU承接自定義層的輸出,從而避免將激活函數的功能加入到自定義層中)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM