keras中激活函數自定義(以mish函數為列)


 若使用keras框架直接編輯函數調用會導致編譯錯誤。因此,有2種方法可以實現keras的調用,其一使用lamda函數調用,

其二使用繼承Layer層調用(如下代碼)。如果使用繼承layer層調用,那你可以將你想要實現的方式,通過call函數編輯,而

call函數的參數一般為輸入特征變量[batch,h,w,c],具體實現如下代碼:

class Mish(Layer):
'''
Mish Activation Function.
.. math::
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
tanh=(1 - e^{-2x})/(1 + e^{-2x})
Shape:
- Input: Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
- Output: Same shape as the input.
Examples:
>>> X_input = Input(input_shape)
>>> X = Mish()(X_input)
'''

def __init__(self, **kwargs):
super(Mish, self).__init__(**kwargs)
self.supports_masking = True

def call(self, inputs):
return inputs * K.tanh(K.softplus(inputs))

def get_config(self):
config = super(Mish, self).get_config()
return config

def compute_output_shape(self, input_shape):
'''
compute_output_shape(self, input_shape):為了能讓Keras內部shape的匹配檢查通過,
這里需要重寫compute_output_shape方法去覆蓋父類中的同名方法,來保證輸出shape是正確的。
父類Layer中的compute_output_shape方法直接返回的是input_shape這明顯是不對的,
所以需要我們重寫這個方法。所以這個方法也是4個要實現的基本方法之一。
'''
return input_shape





有了mish激活函數,該如何調呢?以下代碼簡單演示其調用方式:

cov1=conv2d(卷積參數)(input) # 將輸入input進行卷積操作
Mish()(cov1)  # 將卷積結果輸入定義的激活類中,實現mish激活








免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM