RNNCell
- class brainpy.dyn.RNNCell(num_in, num_out, state_initializer=ZeroInit, Wi_initializer=XavierNormal(scale=1.0, mode=fan_avg, in_axis=-2, out_axis=-1, distribution=truncated_normal, rng=[ 479025946 4206239744]), Wh_initializer=XavierNormal(scale=1.0, mode=fan_avg, in_axis=-2, out_axis=-1, distribution=truncated_normal, rng=[ 479025946 4206239744]), b_initializer=ZeroInit, activation='relu', mode=None, train_state=False, name=None)
Basic fully-connected RNN core.
Given \(x_t\) and the previous hidden state \(h_{t-1}\), the core computes
\[h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]
The output is equal to the new state, \(h_t\). ReLU is the default nonlinearity; a different one can be selected with the activation argument.
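To make the update rule concrete, here is a minimal NumPy sketch of a single step of the recurrence above. It follows the formula literally (weights multiply column vectors, two separate biases, ReLU). The function name `rnn_step` and the weight shapes `(num_out, num_in)` and `(num_out, num_out)` are illustrative assumptions, not BrainPy's internal parameter layout.

```python
import numpy as np

def rnn_step(x_t, h_prev, w_i, b_i, w_h, b_h):
    """One step of h_t = ReLU(w_i x_t + b_i + w_h h_{t-1} + b_h).

    Hypothetical helper for illustration. Assumed shapes:
    x_t: (num_in,), h_prev: (num_out,),
    w_i: (num_out, num_in), w_h: (num_out, num_out), b_i, b_h: (num_out,).
    """
    pre_activation = w_i @ x_t + b_i + w_h @ h_prev + b_h
    return np.maximum(pre_activation, 0.0)  # element-wise ReLU

# Usage: 2 inputs, 3 hidden units, hidden state starting at zero.
rng = np.random.default_rng(0)
w_i = rng.normal(size=(3, 2))
w_h = rng.normal(size=(3, 3))
b_i = np.zeros(3)
b_h = np.zeros(3)
h = np.zeros(3)
for x_t in [np.array([1.0, 0.5]), np.array([-0.2, 0.3])]:
    h = rnn_step(x_t, h, w_i, b_i, w_h, b_h)  # the output is the new state
print(h)
```

Because \(b_i\) and \(b_h\) only ever appear as a sum, they act as a single bias vector \(b = b_i + b_h\); a single bias is therefore enough to express the same model.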
- Parameters:
num_in (int) – The dimension of the input vector.
num_out (int) – The number of hidden units in the node.
state_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The state initializer.
Wi_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The input weight initializer.
Wh_initializer (callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The hidden weight initializer.
b_initializer (optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray) – The bias weight initializer.
activation (str, callable) – The activation function, given either as a string name or as a callable. See brainpy.math.activations for the available functions.
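The activation argument accepts either a name or a callable. The sketch below shows one plausible way such an argument can be resolved. The lookup table `_ACTIVATIONS` and the function `resolve_activation` are hypothetical and cover only two names; consult brainpy.math.activations for the names BrainPy actually accepts.

```python
import math
from typing import Callable, Union

# Hypothetical name -> function table; BrainPy's real set of activations
# lives in brainpy.math.activations and is much larger.
_ACTIVATIONS = {
    "relu": lambda v: max(v, 0.0),
    "tanh": math.tanh,
}

def resolve_activation(
    activation: Union[str, Callable[[float], float]]
) -> Callable[[float], float]:
    """Look up a string name, or pass a callable through unchanged."""
    if isinstance(activation, str):
        try:
            return _ACTIVATIONS[activation]
        except KeyError:
            raise ValueError(f"Unknown activation: {activation!r}") from None
    if callable(activation):
        return activation
    raise TypeError("activation must be a string or a callable")

print(resolve_activation("relu")(-1.5))        # 0.0
print(resolve_activation(lambda v: 2 * v)(3))  # 6
```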
- reset_state(batch_or_mode=None, **kwargs)
Reset the local states of this model.
This method implements the logic for resetting the local variables of this node.
See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
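For this cell, the local state is the hidden vector \(h\). The toy class below sketches what resetting it can look like: the state is refilled with zeros (mirroring the default state_initializer=ZeroInit), optionally with a leading batch dimension. The class `ToyRNNCell` and its `reset_state(batch_size=None)` signature are hypothetical; BrainPy's own reset_state takes batch_or_mode, and its exact behavior is described at the link above.

```python
import numpy as np

class ToyRNNCell:
    """Hypothetical stand-in showing state reset; not BrainPy's implementation."""

    def __init__(self, num_in, num_out, seed=0):
        rng = np.random.default_rng(seed)
        self.num_out = num_out
        # Row-vector convention (x @ W), i.e. the formula with weights transposed.
        self.w_i = rng.normal(size=(num_in, num_out))   # input weights
        self.w_h = rng.normal(size=(num_out, num_out))  # recurrent weights
        self.b = np.zeros(num_out)                      # single combined bias
        self.reset_state()

    def reset_state(self, batch_size=None):
        # Zero-initialize the hidden state, with an optional leading batch axis.
        shape = (self.num_out,) if batch_size is None else (batch_size, self.num_out)
        self.state = np.zeros(shape)

    def update(self, x):
        # x has shape (num_in,) or (batch_size, num_in), matching the state.
        self.state = np.maximum(x @ self.w_i + self.state @ self.w_h + self.b, 0.0)
        return self.state

cell = ToyRNNCell(num_in=2, num_out=3)
cell.update(np.ones(2))
cell.reset_state(batch_size=4)
print(cell.state.shape)  # (4, 3)
```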