keras中自定義Layer


最近在學習SSD的源碼,其中有兩個自定的層,特此學習一下並記錄。

 1 import keras.backend as K
 2 from keras.engine.topology import InputSpec
 3 from keras.engine.topology import Layer
 4 import numpy as np
 5 
 6 class L2Normalization(Layer):
 7     '''
 8     Performs L2 normalization on the input tensor with a learnable scaling parameter
 9     as described in the paper "Parsenet: Looking Wider to See Better" (see references)
10     and as used in the original SSD model.
11 
12     Arguments:
13         gamma_init (int): The initial scaling parameter. Defaults to 20 following the
14             SSD paper.
15 
16     Input shape:
17         4D tensor of shape `(batch, channels, height, width)` if `dim_ordering = 'th'`
18         or `(batch, height, width, channels)` if `dim_ordering = 'tf'`.
19 
20     Returns:
21         The scaled tensor. Same shape as the input tensor.
22     '''
23 
24     def __init__(self, gamma_init=20, **kwargs):
25         if K.image_dim_ordering() == 'tf':
26             self.axis = 3
27         else:
28             self.axis = 1
29         self.gamma_init = gamma_init
30         super(L2Normalization, self).__init__(**kwargs)
31 
32     def build(self, input_shape):
33         self.input_spec = [InputSpec(shape=input_shape)]
34         gamma = self.gamma_init * np.ones((input_shape[self.axis],))
35         self.gamma = K.variable(gamma, name='{}_gamma'.format(self.name))
36         self.trainable_weights = [self.gamma]
37         super(L2Normalization, self).build(input_shape)
38 
39     def call(self, x, mask=None):
40         output = K.l2_normalize(x, self.axis)
41         output *= self.gamma
42         return output

首先說一下這個層是用來做什么的。就是對於每一個通道進行歸一化,不過通道使用的是不同的歸一化參數,也就是說這個參數是需要進行學習的,因此需要通過 自定義層來完成。

在keras中,每個層都是對象,真的,可以通過dir(Layer對象)來查看具有哪些屬性。

具體說來:

__init__():用來進行初始化的(這不是廢話么),gamma就是要學習的參數。

bulid():是用來創建這層的權重向量的,也就是要學習的參數“殼”。

33:設置該層的input_spec,這個是通過InputSpec函數來實現。

34:分配權重“殼”的實際空間大小

35,:由於底層使用的Tensorflow來進行實現的,因此這里使用Tensorflow中的variable來保存變量。

36:根據keras官網的要求,可訓練的權重是要添加至trainable_weights列表中的

37:我不想說了,官網給的實例都是這么做的。

call():用來進行具體實現操作的。

40:沿着指定的軸對輸入數據進行L2正則化

41:使用學習的gamma來對正則化后的數據進行加權

42:將最后的數據最為該層的返回值,這里由於是和輸入形式相同的,因此就沒有了compute_output_shape函數,如果輸入和輸出的形式不同,就需要進行輸入的調整。

就這樣子吧。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM