class brainpy.dyn.GRUCell(num_in, num_out, Wi_initializer=Orthogonal(scale=1.0, axis=-1), Wh_initializer=Orthogonal(scale=1.0, axis=-1), b_initializer=ZeroInit, state_initializer=ZeroInit, activation='tanh', mode=None, train_state=False, name=None)[source]#

Gated Recurrent Unit.

The implementation is based on (Chung, et al., 2014) [1] with biases.

Given \(x_t\) and the previous state \(h_{t-1}\) the core computes

\[\begin{split}\begin{array}{ll}
z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\
r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\
a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \odot h_{t-1}) + b_a) \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot a_t
\end{array}\end{split}\]

where \(z_t\) and \(r_t\) are the update and reset gates, respectively.

The output is equal to the new hidden state, \(h_t\).
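The equations above can be sketched as a single GRU step in plain NumPy. This is a hypothetical minimal reimplementation for illustration only, not brainpy's actual code; the weight and bias names mirror the symbols in the formulas:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, W_i, W_h, b):
    """One GRU step. W_i: (3H, I), W_h: (3H, H), b: (3H,),
    stacked in the order [z, r, a] to mirror the equations above."""
    W_iz, W_ir, W_ia = np.split(W_i, 3)
    W_hz, W_hr, W_ha = np.split(W_h, 3)
    b_z, b_r, b_a = np.split(b, 3)
    z_t = sigmoid(W_iz @ x_t + W_hz @ h_prev + b_z)          # update gate
    r_t = sigmoid(W_ir @ x_t + W_hr @ h_prev + b_r)          # reset gate
    a_t = np.tanh(W_ia @ x_t + W_ha @ (r_t * h_prev) + b_a)  # candidate state
    h_t = (1.0 - z_t) * h_prev + z_t * a_t                   # new hidden state
    return h_t

rng = np.random.default_rng(0)
num_in, num_out = 4, 3
h = gru_step(rng.normal(size=num_in), np.zeros(num_out),
             rng.normal(size=(3 * num_out, num_in)),
             rng.normal(size=(3 * num_out, num_out)),
             np.zeros(3 * num_out))
```

Because \(h_t\) is a convex combination of the previous state and a \(\tanh\) candidate, the hidden state stays bounded in \((-1, 1)\) when the previous state is.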

Warning: Backwards compatibility of GRU weights is currently unsupported.

  • num_in (int) – The dimension of the input vector.

  • num_out (int) – The number of hidden units in the node.

  • state_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The state initializer.

  • Wi_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The input weight initializer.

  • Wh_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The hidden weight initializer.

  • b_initializer (optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The bias weight initializer.

  • activation (str, callable) – The activation function. It can be a string or a callable function. See brainpy.math.activations for more details.
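
By default the input and hidden weights are drawn from an orthogonal distribution, which helps preserve signal norms through the recurrence. As a rough illustration of what an orthogonal initializer produces (a hypothetical NumPy sketch via QR decomposition; brainpy's Orthogonal may differ in details such as the handling of non-square shapes):

```python
import numpy as np

def orthogonal_init(n_rows, n_cols, scale=1.0, seed=None):
    """Sample a (scaled) orthogonal weight matrix via QR decomposition."""
    rng = np.random.default_rng(seed)
    a = rng.normal(size=(max(n_rows, n_cols), min(n_rows, n_cols)))
    q, r = np.linalg.qr(a)
    # Fix the sign ambiguity of QR so columns are deterministic up to the seed.
    q = q * np.sign(np.diag(r))
    if n_rows < n_cols:
        q = q.T
    return scale * q[:n_rows, :n_cols]

W = orthogonal_init(3, 3, seed=42)
```

For a square matrix the result satisfies \(W^\top W = I\), so repeated multiplication neither explodes nor shrinks activations.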


reset_state(batch_or_mode=None, **kwargs)[source]#

Reset function which resets local states in this model.

Simply speaking, this function should implement the logic of resetting the local variables in this node.

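As an illustration of what such a reset typically does (a hypothetical sketch in plain NumPy; the real method operates on brainpy Variables and respects the computing mode), re-creating the hidden state for a requested batch size is analogous to re-applying state_initializer=ZeroInit:

```python
import numpy as np

class ToyRecurrentState:
    """Minimal holder mimicking a recurrent cell's local state."""
    def __init__(self, num_out):
        self.num_out = num_out
        self.state = np.zeros((1, num_out))  # default: batch of 1

    def reset_state(self, batch_size=None):
        # Rebuild the hidden state for the requested batch size,
        # discarding whatever the network accumulated so far.
        if batch_size is None:
            batch_size = self.state.shape[0]
        self.state = np.zeros((batch_size, self.num_out))

cell = ToyRecurrentState(num_out=8)
cell.state += 1.0             # pretend the network has run for a while
cell.reset_state(batch_size=4)
```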


The update function, which specifies the state updating rule of this cell.