RNN、LSTM TF源碼

RNN

class BasicRNNCell(RNNCell):
  """Simplest recurrent cell: one affine map followed by a nonlinearity.

  Each step computes ``act(W * input + U * state + B)`` and returns that
  same tensor twice: once as the step output and once as the next state.

  Args:
    num_units: int, the number of units in the RNN cell (the size of both
      the state and the output).
    activation: Nonlinearity to use.  Falls back to `tanh` when not given.
    reuse: (optional) Python boolean describing whether to reuse variables
     in an existing scope.  If not `True`, and the existing scope already has
     the given variables, an error is raised.
  """

  def __init__(self, num_units, activation=None, reuse=None):
    super(BasicRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    # Any falsy `activation` (e.g. None) selects the tanh default.
    self._activation = math_ops.tanh if not activation else activation
    # The affine layer is created on the first call, once the input
    # tensors are available.
    self._linear = None

  @property
  def state_size(self):
    # The hidden state has one column per unit.
    return self._num_units

  @property
  def output_size(self):
    # The output is the new state itself, so it has the same width.
    return self._num_units

  def call(self, inputs, state):
    """Run one step: output = new_state = act(W * input + U * state + B)."""
    if self._linear is None:
      # Lazily build the affine layer (with bias) over [inputs, state].
      self._linear = _Linear([inputs, state], self._num_units, True)
    pre_activation = self._linear([inputs, state])
    new_state = self._activation(pre_activation)
    return new_state, new_state
  
# Build a 128-unit RNN cell and run a single step on a batch of 32 inputs,
# each of dimension 100.
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print(cell.state_size)  # 128
inputs = tf.placeholder(tf.float32, shape=[32, 100])
h0 = cell.zero_state(32, tf.float32)  # initial hidden state, shape (32, 128)
output, h1 = cell(inputs=inputs, state=h0)
# For BasicRNNCell the output and the new state are the same tensor.
print(output.shape)  # (32, 128)
print(h1.shape)      # (32, 128)

#這裡我們首先初始化了一個神經元個數為 128 的 BasicRNNCell 類,然後構造了一個 shape 為 [32, 100] 的變量作為 inputs,其代表 batch_size 為 32、維度為 100;隨後調用 zero_state() 方法初始化了初始隱藏狀態,最終調用了其 call() 方法,得到 output 和 h1(兩者的 shape 均為 [32, 128])。

LSTM

class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  Each step applies one affine transform to ``[inputs, h]`` and splits the
  result into four equal blocks ``i, j, f, o``, then computes::

    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j)
    new_h = activation(new_c) * sigmoid(o)

  ``new_h`` is returned as the step output, and ``(new_c, new_h)`` forms the
  new state.

  Args:
    num_units: int, the number of units in the LSTM cell; both ``c`` and
      ``h`` have this many columns.
    forget_bias: float, added to the forget-gate pre-activation ``f`` before
      the sigmoid.  Default: 1.0.
    state_is_tuple: if True (default), the state is an
      ``LSTMStateTuple(c, h)``; if False, ``c`` and ``h`` are concatenated
      along axis 1 and a warning is logged at construction time.
    activation: nonlinearity applied to ``j`` and to ``new_c``.
      Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already
      has the given variables, an error is raised.
  """

  def __init__(self, num_units, forget_bias=1.0,
               state_is_tuple=True, activation=None, reuse=None):
    super(BasicLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warning("%s: Using a concatenated state is slower and will soon be "
                      "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh
    # Created on the first call, once the input tensors are available.
    self._linear = None

  @property
  def state_size(self):
    # (c, h) as a tuple, or a single tensor holding both side by side.
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    # The step output is new_h.
    return self._num_units

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors `(c, h)`, each shaped
        `[batch_size, num_units]`, if `state_is_tuple` has been set to
        `True`.  Otherwise, a `Tensor` shaped `[batch_size, 2 * num_units]`
        holding `c` and `h` concatenated along axis 1.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._linear is None:
      self._linear = _Linear([inputs, h], 4 * self._num_units, True)
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
    
# One LSTM step on a batch of 32 inputs, each of dimension 100.
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
inputs = tf.placeholder(tf.float32, shape=[32, 100])
# With the default state_is_tuple=True, the state is an LSTMStateTuple (c, h).
h0 = cell.zero_state(32, tf.float32)
# output is the new h, shape (32, 128); h1 is the new LSTMStateTuple (c, h).
output, h1 = cell(inputs=inputs, state=h0)

摘自:https://cuiqingcai.com/4925.html

©著作權歸作者所有,轉載或內容合作請聯系作者
【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳並發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容