tensorflow源碼分析——BasicLSTMCell


BasicLSTMCell 是最簡單的LSTMCell,源碼位於:/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py。
BasicLSTMCell 繼承了RNNCell,源碼位於:/tensorflow/python/ops/rnn_cell_impl.py
注意事項:
1. input_size 這個參數不能使用,使用的是num_units
2. state_is_tuple 官方建議設置為True。此時,輸入和輸出的states為c(cell狀態)和h(輸出)的二元組
3. 輸入、輸出、cell的維度相同,都是 batch_size * num_units,
cell = tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=0.0, state_is_tuple=True)  #指定num_units
_initial_state = cell.zero_state(batch_size, tf.float32)                   #指定batch_size,將c和h全部初始化為0,shape全是batch_size * num_units,

4.
class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full LSTMCell that follows.
  """

  def __init__(self, num_units, forget_bias=1.0, input_size=None,
               state_is_tuple=True, activation=tanh):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.
    """
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with vs.variable_scope(scope or "basic_lstm_cell"):
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    # 線性計算 concat = [inputs, h]W + b
    # 線性計算,分配W和b,W的shape為(2*num_units, 4*num_units), b的shape為(4*num_units,),共包含有四套參數,
# concat shape(batch_size, 4*num_units)
  # 注意:只有cell 的input和output的size相等時才可以這樣計算,否則要定義兩套W,b.每套再包含四套參數 concat
= _linear([inputs, h], 4 * self._num_units, True, scope=scope) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j)) new_h = self._activation(new_c) * sigmoid(o) if self._state_is_tuple: new_state = LSTMStateTuple(new_c, new_h) else: new_state = array_ops.concat([new_c, new_h], 1) return new_h, new_state

 5. lstm層,每一batch的運算

        with tf.variable_scope("RNN"):
            for time_step in range(num_steps):
                if time_step > 0: tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

6. 每一epoch

7.全部運算


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM