Keras multi-input/multi-output experiments: merge layers


Although the official documentation does include a multi-input/multi-output example [English] [Chinese translation], as a user I still had two questions about Keras multi-input/multi-output models:

1. Can layer outputs be reused across the graph with skips, i.e. can we build the shortcut connections of Deep Residual Learning? (See the sketch right after this list.)

2. When branches of the network are joined with a merge layer, can we define our own custom merge layer?
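
For question 1 the answer is yes: the functional API lets you feed any tensor into more than one layer, and an Add merge realizes the shortcut connection from Deep Residual Learning. A minimal sketch (my own example, not taken from the Keras docs):

import keras

inp = keras.layers.Input(shape=(16,))
x = keras.layers.Dense(16, activation='relu')(inp)
x = keras.layers.Dense(16)(x)
# y = F(x) + x: the shortcut skips over the two Dense layers
out = keras.layers.add([inp, x])
model = keras.models.Model(inputs=inp, outputs=out)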

The merge subclasses are these eight layers: Add, Subtract, Multiply, Average, Maximum, Minimum, Concatenate, and Dot.

The merge source code can be viewed on GitHub.

Let's first analyze the merge parent class, `_Merge`:

class _Merge(Layer):
    """Generic merge layer for elementwise merge functions.
    Used to implement `Sum`, `Average`, etc.
    # Arguments
        **kwargs: standard layer keyword arguments.
    """

    def __init__(self, **kwargs):
        super(_Merge, self).__init__(**kwargs)
        self.supports_masking = True

    def _merge_function(self, inputs):
        raise NotImplementedError

    def _compute_elemwise_op_output_shape(self, shape1, shape2):
        """Computes the shape of the resultant of an elementwise operation.
        # Arguments
            shape1: tuple or None. Shape of the first tensor
            shape2: tuple or None. Shape of the second tensor
        # Returns
            expected output shape when an element-wise operation is
            carried out on 2 tensors with shapes shape1 and shape2.
            tuple or None.
        # Raises
            ValueError: if shape1 and shape2 are not compatible for
                element-wise operations.
        """
        if None in [shape1, shape2]:
            return None
        elif len(shape1) < len(shape2):
            return self._compute_elemwise_op_output_shape(shape2, shape1)
        elif len(shape2) == 0:
            return shape1
        output_shape = list(shape1[:-len(shape2)])
        for i, j in zip(shape1[-len(shape2):], shape2):
            if i is None or j is None:
                output_shape.append(None)
            elif i == 1:
                output_shape.append(j)
            elif j == 1:
                output_shape.append(i)
            else:
                if i != j:
                    raise ValueError('Operands could not be broadcast '
                                     'together with shapes ' +
                                     str(shape1) + ' ' + str(shape2))
                output_shape.append(i)
        return tuple(output_shape)

    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape, list):
            raise ValueError('A merge layer should be called '
                             'on a list of inputs.')
        if len(input_shape) < 2:
            raise ValueError('A merge layer should be called '
                             'on a list of at least 2 inputs. '
                             'Got ' + str(len(input_shape)) + ' inputs.')
        batch_sizes = [s[0] for s in input_shape if s is not None]
        batch_sizes = set(batch_sizes)
        batch_sizes -= set([None])
        if len(batch_sizes) > 1:
            raise ValueError('Can not merge tensors with different '
                             'batch sizes. Got tensors with shapes : ' +
                             str(input_shape))
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]
        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
        # If the inputs have different ranks, we have to reshape them
        # to make them broadcastable.
        if None not in input_shape and len(set(map(len, input_shape))) == 1:
            self._reshape_required = False
        else:
            self._reshape_required = True

    def call(self, inputs):
        # Returns the merged tensor.
        if self._reshape_required:
            reshaped_inputs = []
            input_ndims = list(map(K.ndim, inputs))
            if None not in input_ndims:
                # If ranks of all inputs are available,
                # we simply expand each of them at axis=1
                # until all of them have the same rank.
                max_ndim = max(input_ndims)
                for x in inputs:
                    x_ndim = K.ndim(x)
                    for _ in range(max_ndim - x_ndim):
                        x = K.expand_dims(x, 1)
                    reshaped_inputs.append(x)
                return self._merge_function(reshaped_inputs)
            else:
                # Transpose all inputs so that batch size is the last dimension.
                # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
                transposed = False
                for x in inputs:
                    x_ndim = K.ndim(x)
                    if x_ndim is None:
                        x_shape = K.shape(x)
                        batch_size = x_shape[0]
                        new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
                        x_transposed = K.reshape(x, K.stack([batch_size, K.prod(x_shape[1:])]))
                        x_transposed = K.permute_dimensions(x_transposed, (1, 0))
                        x_transposed = K.reshape(x_transposed, new_shape)
                        reshaped_inputs.append(x_transposed)
                        transposed = True
                    elif x_ndim > 1:
                        dims = list(range(1, x_ndim)) + [0]
                        reshaped_inputs.append(K.permute_dimensions(x, dims))
                        transposed = True
                    else:
                        # We don't transpose inputs if they are 1D vectors or scalars.
                        reshaped_inputs.append(x)
                y = self._merge_function(reshaped_inputs)
                y_ndim = K.ndim(y)
                if transposed:
                    # If inputs have been transposed, we have to transpose the output too.
                    if y_ndim is None:
                        y_shape = K.shape(y)
                        y_ndim = K.shape(y_shape)[0]
                        batch_size = y_shape[y_ndim - 1]
                        new_shape = K.concatenate([K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
                        y = K.reshape(y, (-1, batch_size))
                        y = K.permute_dimensions(y, (1, 0))
                        y = K.reshape(y, new_shape)
                    elif y_ndim > 1:
                        dims = [y_ndim - 1] + list(range(y_ndim - 1))
                        y = K.permute_dimensions(y, dims)
                return y
        else:
            return self._merge_function(inputs)

    def compute_output_shape(self, input_shape):
        # Computes the shape of the returned value.
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]
        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
        batch_sizes = [s[0] for s in input_shape if s is not None]
        batch_sizes = set(batch_sizes)
        batch_sizes -= set([None])
        if len(batch_sizes) == 1:
            output_shape = (list(batch_sizes)[0],) + output_shape
        else:
            output_shape = (None,) + output_shape
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        if not isinstance(mask, list):
            raise ValueError('`mask` should be a list.')
        if not isinstance(inputs, list):
            raise ValueError('`inputs` should be a list.')
        if len(mask) != len(inputs):
            raise ValueError('The lists `inputs` and `mask` '
                             'should have the same length.')
        if all([m is None for m in mask]):
            return None
        masks = [K.expand_dims(m, 0) for m in mask if m is not None]
        return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
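
Note how `build` sets `self._reshape_required` when the input ranks differ, and `call` then expands the lower-rank inputs at axis=1 until all ranks match. A small sketch of what that buys you (my own example; behavior as observed with Keras 2.x, not from the official docs):

import keras

a = keras.layers.Input(shape=(4, 5))  # rank-3 tensor: (batch, 4, 5)
b = keras.layers.Input(shape=(5,))    # rank-2 tensor: (batch, 5)
# _Merge expands b to (batch, 1, 5), which then broadcasts against a.
out = keras.layers.Add()([a, b])
model = keras.models.Model(inputs=[a, b], outputs=out)
print(model.output_shape)  # (None, 4, 5)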

The lowercase functional interfaces that accompany each merge subclass simply instantiate the corresponding layer class and call it on the inputs:

def add(inputs, **kwargs):
    """Functional interface to the `Add` layer.
    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the sum of the inputs.
    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        added = keras.layers.add([x1, x2])
        out = keras.layers.Dense(4)(added)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    return Add(**kwargs)(inputs)


def subtract(inputs, **kwargs):
    """Functional interface to the `Subtract` layer.
    # Arguments
        inputs: A list of input tensors (exactly 2).
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the difference of the inputs.
    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        subtracted = keras.layers.subtract([x1, x2])
        out = keras.layers.Dense(4)(subtracted)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    return Subtract(**kwargs)(inputs)


def multiply(inputs, **kwargs):
    """Functional interface to the `Multiply` layer.
    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the element-wise product of the inputs.
    """
    return Multiply(**kwargs)(inputs)


def average(inputs, **kwargs):
    """Functional interface to the `Average` layer.
    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the average of the inputs.
    """
    return Average(**kwargs)(inputs)


def maximum(inputs, **kwargs):
    """Functional interface to the `Maximum` layer.
    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the element-wise maximum of the inputs.
    """
    return Maximum(**kwargs)(inputs)


def minimum(inputs, **kwargs):
    """Functional interface to the `Minimum` layer.
    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the element-wise minimum of the inputs.
    """
    return Minimum(**kwargs)(inputs)


def concatenate(inputs, axis=-1, **kwargs):
    """Functional interface to the `Concatenate` layer.
    # Arguments
        inputs: A list of input tensors (at least 2).
        axis: Concatenation axis.
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the concatenation of the inputs alongside axis `axis`.
    """
    return Concatenate(axis=axis, **kwargs)(inputs)


def dot(inputs, axes, normalize=False, **kwargs):
    """Functional interface to the `Dot` layer.
    # Arguments
        inputs: A list of input tensors (at least 2).
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
        **kwargs: Standard layer keyword arguments.
    # Returns
        A tensor, the dot product of the samples from the inputs.
    """
    return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)

The simple subclasses only need to override `_merge_function`; every other method is inherited from the parent class.

The Add layer:

class Add(_Merge):
    """Layer that adds a list of inputs.
    It takes as input a list of tensors,
    all of the same shape, and returns
    a single tensor (also of the same shape).
    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        added = keras.layers.Add()([x1, x2])  # equivalent to added = keras.layers.add([x1, x2])
        out = keras.layers.Dense(4)(added)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    # Accumulates every input onto the first, which means you can merge more than two inputs.
    def _merge_function(self, inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
            output += inputs[i]
        return output
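
Because the loop keeps accumulating onto the first input, nothing limits Add to two inputs. A quick sketch (my own example):

import keras

i1 = keras.layers.Input(shape=(8,))
i2 = keras.layers.Input(shape=(8,))
i3 = keras.layers.Input(shape=(8,))
# Element-wise sum of all three inputs
summed = keras.layers.Add()([i1, i2, i3])
model = keras.models.Model(inputs=[i1, i2, i3], outputs=summed)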

The Subtract layer:

class Subtract(_Merge):
    """Layer that subtracts two inputs.
    It takes as input a list of tensors of size 2,
    both of the same shape, and returns a single tensor, (inputs[0] - inputs[1]),
    also of the same shape.
    # Examples
    ```python
        import keras
        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        # Equivalent to subtracted = keras.layers.subtract([x1, x2])
        subtracted = keras.layers.Subtract()([x1, x2])
        out = keras.layers.Dense(4)(subtracted)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    # Exactly two inputs are allowed: the first minus the second.
    def _merge_function(self, inputs):
        if len(inputs) != 2:
            raise ValueError('`Subtract` layer should be called '
                             'on exactly 2 inputs')
        if inputs[0]._keras_shape != inputs[1]._keras_shape:
            raise ValueError('`Subtract` layer should be called '
                             'on inputs of the same shape')
        return inputs[0] - inputs[1]

The Multiply layer:

class Multiply(_Merge):
    # Multiplies every remaining input onto the first; any number of inputs can be merged.
    """Layer that multiplies (element-wise) a list of inputs.
    It takes as input a list of tensors,
    all of the same shape, and returns
    a single tensor (also of the same shape).
    """

    def _merge_function(self, inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
            output *= inputs[i]
        return output

The Average layer: takes the element-wise average of its inputs.

The Maximum layer: takes the element-wise maximum of its inputs.

The Minimum layer: takes the element-wise minimum of its inputs. Together with Multiply, these all follow the Add pattern; see the sketch below.
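
A minimal usage sketch of these element-wise merges (my own example, not from the Keras source):

import keras

x1 = keras.layers.Input(shape=(8,))
x2 = keras.layers.Input(shape=(8,))
prod = keras.layers.Multiply()([x1, x2])  # element-wise product
avg = keras.layers.Average()([x1, x2])    # element-wise mean
mx = keras.layers.Maximum()([x1, x2])     # element-wise max
mn = keras.layers.Minimum()([x1, x2])     # element-wise min
merged = keras.layers.concatenate([prod, avg, mx, mn])
model = keras.models.Model(inputs=[x1, x2], outputs=merged)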

The Concatenate layer:

class Concatenate(_Merge):
    # Concatenation is more involved than the element-wise merges, so this layer
    # defines its own build/shape/mask logic and carries an `axis` attribute:
    # the inputs are joined along the requested axis.
    """Layer that concatenates a list of inputs.
    It takes as input a list of tensors,
    all of the same shape except for the concatenation axis,
    and returns a single tensor, the concatenation of all inputs.
    # Arguments
        axis: Axis along which to concatenate.
        **kwargs: standard layer keyword arguments.
    """

    def __init__(self, axis=-1, **kwargs):
        super(Concatenate, self).__init__(**kwargs)
        self.axis = axis
        self.supports_masking = True

    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape, list):
            raise ValueError('`Concatenate` layer should be called '
                             'on a list of inputs')
        if all([shape is None for shape in input_shape]):
            return
        reduced_inputs_shapes = [list(shape) for shape in input_shape]
        shape_set = set()
        for i in range(len(reduced_inputs_shapes)):
            del reduced_inputs_shapes[i][self.axis]
            shape_set.add(tuple(reduced_inputs_shapes[i]))
        if len(shape_set) > 1:
            raise ValueError('`Concatenate` layer requires '
                             'inputs with matching shapes '
                             'except for the concat axis. '
                             'Got inputs shapes: %s' % (input_shape))

    # The backend op (e.g. tf.concat under TensorFlow) does the actual concatenation.
    def call(self, inputs):
        if not isinstance(inputs, list):
            raise ValueError('A `Concatenate` layer should be called '
                             'on a list of inputs.')
        return K.concatenate(inputs, axis=self.axis)

    # Computes the shape of this layer's output.
    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, list):
            raise ValueError('A `Concatenate` layer should be called '
                             'on a list of inputs.')
        input_shapes = input_shape
        output_shape = list(input_shapes[0])
        for shape in input_shapes[1:]:
            if output_shape[self.axis] is None or shape[self.axis] is None:
                output_shape[self.axis] = None
                break
            output_shape[self.axis] += shape[self.axis]
        return tuple(output_shape)

    # Handles masked (hidden) elements, if any.
    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        if not isinstance(mask, list):
            raise ValueError('`mask` should be a list.')
        if not isinstance(inputs, list):
            raise ValueError('`inputs` should be a list.')
        if len(mask) != len(inputs):
            raise ValueError('The lists `inputs` and `mask` '
                             'should have the same length.')
        if all([m is None for m in mask]):
            return None
        # Make a list of masks while making sure
        # the dimensionality of each mask
        # is the same as the corresponding input.
        masks = []
        for input_i, mask_i in zip(inputs, mask):
            if mask_i is None:
                # Input is unmasked. Append all 1s to masks,
                # but cast it to bool first
                masks.append(K.cast(K.ones_like(input_i), 'bool'))
            elif K.ndim(mask_i) < K.ndim(input_i):
                # Mask is smaller than the input, expand it
                masks.append(K.expand_dims(mask_i))
            else:
                masks.append(mask_i)
        concatenated = K.concatenate(masks, axis=self.axis)
        return K.all(concatenated, axis=-1, keepdims=False)

    def get_config(self):
        config = {
            'axis': self.axis,
        }
        # super() pulls in the parent class configuration.
        base_config = super(Concatenate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
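
A short usage sketch (my own example): the non-concat axes must match, and `compute_output_shape` sums the sizes along `axis`:

import keras

a = keras.layers.Input(shape=(4, 8))
b = keras.layers.Input(shape=(4, 16))
# Shapes match except on the last axis; output shape is (batch, 4, 24)
c = keras.layers.concatenate([a, b], axis=-1)
model = keras.models.Model(inputs=[a, b], outputs=c)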

The Dot layer: computes a dot product; exactly two inputs are merged.

class Dot(_Merge):
    """Layer that computes a dot product between samples in two tensors.
    E.g. if applied to two tensors `a` and `b` of shape `(batch_size, n)`,
    the output will be a tensor of shape `(batch_size, 1)`
    where each entry `i` will be the dot product between
    `a[i]` and `b[i]`.
    # Arguments
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
        **kwargs: Standard layer keyword arguments.
    """

    def __init__(self, axes, normalize=False, **kwargs):
        super(Dot, self).__init__(**kwargs)
        if not isinstance(axes, int):
            if not isinstance(axes, (list, tuple)):
                raise TypeError('Invalid type for `axes` - '
                                'should be a list or an int.')
            if len(axes) != 2:
                raise ValueError('Invalid format for `axes` - '
                                 'should contain two elements.')
            if not isinstance(axes[0], int) or not isinstance(axes[1], int):
                raise ValueError('Invalid format for `axes` - '
                                 'list elements should be "int".')
        self.axes = axes
        self.normalize = normalize
        self.supports_masking = True

    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('A `Dot` layer should be called '
                             'on a list of 2 inputs.')
        shape1 = input_shape[0]
        shape2 = input_shape[1]
        if shape1 is None or shape2 is None:
            return
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [self.axes % len(shape1), self.axes % len(shape2)]
            else:
                axes = [self.axes] * 2
        else:
            axes = self.axes
        if shape1[axes[0]] != shape2[axes[1]]:
            raise ValueError(
                'Dimension incompatibility '
                '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
                'Layer shapes: %s, %s' % (shape1, shape2))

    # Implements the dot product along the given axes; the actual operation
    # is K.batch_dot(x1, x2, axes).
    def call(self, inputs):
        x1 = inputs[0]
        x2 = inputs[1]
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [self.axes % K.ndim(x1), self.axes % K.ndim(x2)]
            else:
                axes = [self.axes] * 2
        else:
            axes = []
            for i in range(len(self.axes)):
                if self.axes[i] < 0:
                    axes.append(self.axes[i] % K.ndim(inputs[i]))
                else:
                    axes.append(self.axes[i])
        if self.normalize:
            x1 = K.l2_normalize(x1, axis=axes[0])
            x2 = K.l2_normalize(x2, axis=axes[1])
        output = K.batch_dot(x1, x2, axes)
        return output

    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('A `Dot` layer should be called '
                             'on a list of 2 inputs.')
        shape1 = list(input_shape[0])
        shape2 = list(input_shape[1])
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [self.axes % len(shape1), self.axes % len(shape2)]
            else:
                axes = [self.axes] * 2
        else:
            axes = self.axes
        shape1.pop(axes[0])
        shape2.pop(axes[1])
        shape2.pop(0)
        output_shape = shape1 + shape2
        if len(output_shape) == 1:
            output_shape += [1]
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        return None

    def get_config(self):
        config = {
            'axes': self.axes,
            'normalize': self.normalize,
        }
        base_config = super(Dot, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
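
A short usage sketch (my own example): with `normalize=True`, the batched dot product along the feature axis becomes cosine similarity:

import keras

a = keras.layers.Input(shape=(16,))
b = keras.layers.Input(shape=(16,))
# Dot over axis 1 (the 16-dim feature axis); output shape is (batch, 1)
sim = keras.layers.Dot(axes=1, normalize=True)([a, b])
model = keras.models.Model(inputs=[a, b], outputs=sim)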

Now that we know how each merge layer is implemented, we can define a custom merge layer:
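
Below is a hedged sketch, assuming the Keras 2.x layout where `_Merge` lives in `keras.layers.merge`; the class name `WeightedSum` and its `alpha` parameter are my own invention, not part of Keras. The recipe is the one the source shows: subclass `_Merge`, override `_merge_function`, and (if the layer has parameters) override `get_config`.

from keras.layers.merge import _Merge


class WeightedSum(_Merge):
    """Hypothetical custom merge: alpha * inputs[0] + (1 - alpha) * inputs[1]."""

    def __init__(self, alpha=0.5, **kwargs):
        super(WeightedSum, self).__init__(**kwargs)
        self.alpha = alpha

    def _merge_function(self, inputs):
        if len(inputs) != 2:
            raise ValueError('`WeightedSum` should be called on exactly 2 inputs.')
        # Element-wise weighted combination of the two inputs.
        return self.alpha * inputs[0] + (1 - self.alpha) * inputs[1]

    def get_config(self):
        # Mirror Concatenate/Dot above so the layer serializes cleanly.
        config = {'alpha': self.alpha}
        base_config = super(WeightedSum, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

It is then used like any built-in merge layer:

import keras

x1 = keras.layers.Input(shape=(8,))
x2 = keras.layers.Input(shape=(8,))
merged = WeightedSum(alpha=0.7)([x1, x2])
model = keras.models.Model(inputs=[x1, x2], outputs=merged)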

