class brainpy.dyn.LSTMCell(num_in, num_out, Wi_initializer=XavierNormal(scale=1.0, mode=fan_avg, in_axis=-2, out_axis=-1, distribution=truncated_normal, rng=[ 426696668 3289036839]), Wh_initializer=XavierNormal(scale=1.0, mode=fan_avg, in_axis=-2, out_axis=-1, distribution=truncated_normal, rng=[ 426696668 3289036839]), b_initializer=ZeroInit, state_initializer=ZeroInit, activation='tanh', mode=None, train_state=False, name=None)

Long short-term memory (LSTM) RNN core.

The implementation is based on (Zaremba et al., 2014) [1]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\), the core computes

\[\begin{split}\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\end{split}\]

where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

The output is equal to the new hidden state, \(h_t\).
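The update equations above can be sketched in plain NumPy as follows. This is an illustrative sketch, not BrainPy's implementation: the function name `lstm_step`, the fused weight layout with shapes `(num_in, 4 * num_out)` and `(num_out, 4 * num_out)`, the gate ordering i, f, g, o inside the fused matrices, and the row-vector convention `x @ W` are all assumptions made for the example.

```python
import numpy as np


def sigmoid(x):
    """Logistic function used for the three gates."""
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h_prev, c_prev, W_i, W_h, b):
    """One LSTM step (hypothetical helper, not part of BrainPy).

    x:      (batch, num_in)        input x_t
    h_prev: (batch, num_out)       previous hidden state h_{t-1}
    c_prev: (batch, num_out)       previous memory cell c_{t-1}
    W_i:    (num_in, 4 * num_out)  input weights, gates fused as [i, f, g, o]
    W_h:    (num_out, 4 * num_out) hidden weights, same fused layout
    b:      (4 * num_out,)         biases, same fused layout
    """
    gates = x @ W_i + h_prev @ W_h + b
    i, f, g, o = np.split(gates, 4, axis=-1)
    i = sigmoid(i)          # input gate
    f = sigmoid(f)          # forget gate
    g = np.tanh(g)          # candidate cell update
    o = sigmoid(o)          # output gate
    c = f * c_prev + i * g  # new memory cell c_t
    h = o * np.tanh(c)      # new hidden state h_t (the output)
    return h, c


# Small demo with random weights and zero initial state.
rng = np.random.default_rng(0)
num_in, num_out, batch = 3, 4, 2
W_i = rng.normal(scale=0.1, size=(num_in, 4 * num_out))
W_h = rng.normal(scale=0.1, size=(num_out, 4 * num_out))
b = np.zeros(4 * num_out)

x = rng.normal(size=(batch, num_in))
h0 = np.zeros((batch, num_out))
c0 = np.zeros((batch, num_out))
h1, c1 = lstm_step(x, h0, c0, W_i, W_h, b)
print(h1.shape, c1.shape)
```

Because \(h_t = o_t \tanh(c_t)\) with \(o_t \in (0, 1)\), every entry of the returned hidden state lies strictly between -1 and 1.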


Forget gate initialization: Following (Jozefowicz et al., 2015) [2], we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.
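The effect of the forget-gate offset can be seen with a small calculation. The sketch below assumes the forget-gate pre-activation is otherwise zero (zero-initialized bias and negligible weighted input early in training) and that the input gate contributes nothing, so only the decay term \(f_t c_{t-1}\) matters. The 10-step horizon and the variable names are illustrative choices, not part of the library.

```python
import math


def sigmoid(x):
    """Logistic function used by the forget gate."""
    return 1.0 / (1.0 + math.exp(-x))


# Forget-gate activation when the pre-activation is just the bias.
f_zero_bias = sigmoid(0.0)  # b_f = 0 gives exactly 0.5
f_plus_one = sigmoid(1.0)   # b_f = 1 gives roughly 0.73

# Fraction of the initial memory cell that survives `steps` updates
# if the forget gate stays at that value and nothing new is written.
steps = 10
retained_zero = f_zero_bias ** steps
retained_one = f_plus_one ** steps

print(f"f with b_f=0: {f_zero_bias:.3f}, retained after {steps} steps: {retained_zero:.5f}")
print(f"f with b_f=1: {f_plus_one:.3f}, retained after {steps} steps: {retained_one:.5f}")
```

Under these assumptions, the offset lets roughly 4% of the initial cell content survive ten steps, against about 0.1% without it, which is the sense in which early forgetting is reduced.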

  • num_in (int) – The dimension of the input vector.

  • num_out (int) – The number of hidden units in the node.

  • state_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The state initializer.

  • Wi_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The input weight initializer.

  • Wh_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The hidden weight initializer.

  • b_initializer (optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The bias weight initializer.

  • activation (str, callable) – The activation function. It can be a string or a callable function. See brainpy.math.activations for more details.


property c

Memory cell.

property h

Hidden state.

reset_state(batch_or_mode=None, **kwargs)

Reset function which resets local states in this model.

Simply put, this function should implement the logic for resetting the local variables in this node.

See for details.


The function to specify the updating rule.