一、基本定义方法
当然,Lambda层仅仅适用于不需要增加训练参数的情形,如果想要实现的功能需要往模型新增参数,那么就必须要用到自定义Layer了。其实这也不复杂,相比于Lambda层只不过代码多了几行,官方文章已经写得很清楚了:
https://keras.io/layers/writing-your-own-keras-layers/
这里把它页面上的例子搬过来:
class MyLayer(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim # 可以自定义一些属性,方便调用 super(MyLayer, self).__init__(**kwargs) # 必须 def build(self, input_shape): # 添加可训练参数 self.kernel = self.add_weight(name='kernel', shape=(input_shape[1], self.output_dim), initializer='uniform', trainable=True) def call(self, x): # 定义功能,相当于Lambda层的功能函数 return K.dot(x, self.kernel) def compute_output_shape(self, input_shape): # 计算输出形状,如果输入和输出形状一致,那么可以省略,否则最好加上 return (input_shape[0], self.output_dim)
参考链接:https://spaces.ac.cn/archives/5765
如上所示, 其中有三个函数需要我们自己实现:
- build() 用来初始化定义weights, 这里可以用父类的self.add_weight() 函数来初始化数据, 该函数必须将 self.built 设置为True, 以保证该 Layer 已经成功 build , 通常如上所示, 使用 super(MyLayer, self).build(input_shape) 来完成。
- call() 用来执行 Layer 的职能, x就是该层的输入,x与权重kernel做点积,生成新的节点层,即当前 Layer 所有的计算过程均在该函数中完成。
- compute_output_shape() 用来计算输出张量的 shape。
例如输入input=【5,128】,5是batch_size,128是embedding向量的维度,input_shape[0]=5,input_shape[1]=128,假如output_dim=256,所以self.kernel的维度就是【128,256】,最后compute_output_shape的输出维度就是【5,256】。调用方式:Mylayer(256)(input)
二、add_weight源码
def add_weight(self, name, shape, dtype=None, initializer=None, regularizer=None, trainable=True, constraint=None): """Adds a weight variable to the layer. # Arguments name: String, the name for the weight variable. shape: The shape tuple of the weight. dtype: The dtype of the weight. initializer: An Initializer instance (callable). regularizer: An optional Regularizer instance. trainable: A boolean, whether the weight should be trained via backprop or not (assuming that the layer itself is also trainable). constraint: An optional Constraint instance. # Returns The created weight variable. """ initializer = initializers.get(initializer) if dtype is None: dtype = K.floatx() weight = K.variable(initializer(shape), dtype=dtype, name=name, constraint=constraint) if regularizer is not None: self.add_loss(regularizer(weight)) if trainable: self._trainable_weights.append(weight) else: self._non_trainable_weights.append(weight) return weight
从上述代码来看通过 add_weight 创建的参数, trainable 设置 True ,自动纳入训练参数中。