Node Specification#

@Chaoming Wang

In BrainPy, neural networks are used to build dynamical systems. The brainpy.nn module provides various classes representing the nodes of a neural network, all of which are subclasses of the brainpy.nn.Node base class.

import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

What is a node?#

In BrainPy, a Node instance is the basic element for building a network model. It is a unit in a graph, connected to other nodes by edges.

In general, each Node instance in BrainPy has four components:

  • Feedforward inputs

  • Feedback inputs

  • State

  • Output

It is worth noting that each Node instance may have multiple feedforward or feedback connections, but only one state and one output. The output component is used in both feedforward and feedback connections, which means the feedforward and feedback outputs are the same. However, customizing a different feedback output is also easy (see the Customization of a Node tutorial).
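
For example, several nodes can feed into the same node, which nevertheless keeps a single state and a single output. A minimal sketch, assuming the >> (feedforward connection) and & (merge) operators described in the Node Operations tutorial:

# Two input nodes feed the same Dense node; it receives two
# feedforward connections but still produces one output.
d = bp.nn.Dense(num_unit=10)
net = (bp.nn.Input(3) >> d) & (bp.nn.Input(4) >> d)
net.initialize(num_batch=1)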

Each node has the following attributes:

  • feedforward_shapes: the shapes of the feedforward inputs.

  • feedback_shapes: the shapes of the feedback inputs.

  • output_shape: the output shape of the node.

  • state: the state of the node. It can be None if the node has no state to hold.

  • fb_output: the feedback output of the node. It is None when no feedback connections are established to this node. By default, the value of fb_output is the output of the forward() function.

It also has several boolean attributes (both sets of attributes are inspected in the sketch after this list):

  • trainable: whether the node is trainable.

  • is_initialized: whether the node has been initialized.
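
These attributes can be inspected directly on a node instance. A minimal sketch (the values in the comments are what we would expect for a Dense node and may differ in detail):

l = bp.nn.Dense(num_unit=100, input_shape=128)
print(l.feedforward_shapes)  # ((None, 128),), where None is the batch axis
print(l.output_shape)        # (None, 100)
print(l.is_initialized)      # False, since initialize() has not been called
l.initialize(num_batch=1)
print(l.is_initialized)      # True
print(l.state)               # None, because a Dense layer holds no state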

Creating a node#

A layer can be created as an instance of a brainpy.nn.Node subclass. For example, a dense layer can be created as follows:

bp.nn.Dense(num_unit=100) 
Dense(name=Dense0, forwards=None, 
      feedbacks=None, output=(None, 100))

This will create a dense layer with 100 units.

Of course, if you already know the shapes of the feedforward connections, you can specify them with the input_shape argument.

bp.nn.Dense(num_unit=100, input_shape=128) 
Dense(name=Dense1, forwards=((None, 128),), 
      feedbacks=None, output=(None, 100))

This creates a dense layer connected to an input layer with 128 dimensions.

Naming a node#

For convenience, you can name a layer by specifying the name keyword argument:

bp.nn.Dense(num_unit=100, input_shape=128, name='hidden_layer')
Dense(name=hidden_layer, forwards=((None, 128),), 
      feedbacks=None, output=(None, 100))

Initializing parameters#

Many nodes have parameters. We can set the parameters of a node in the following ways.

  • Tensors

If a tensor variable instance is provided, this is used unchanged as the parameter variable. For example:

l = bp.nn.Dense(num_unit=50, input_shape=10, 
                weight_initializer=bm.random.normal(0, 0.01, size=(10, 50)))
l.initialize(num_batch=1)

l.Wff.shape
(10, 50)
  • Callable function

If a callable function (which receives a shape argument) is provided, the callable will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:

def init(shape):
    return bm.random.random(shape)

l = bp.nn.Dense(num_unit=30, input_shape=20, weight_initializer=init)
l.initialize(num_batch=1)

l.Wff.shape
(20, 30)
  • Instance of brainpy.init.Initializer

If a brainpy.init.Initializer instance is provided, the initial parameter values will be generated with the desired shape by using the Initializer instance. For example:

l = bp.nn.Dense(num_unit=100, input_shape=20, 
                weight_initializer=bp.init.Normal(0.01))
l.initialize(num_batch=1)

l.Wff.shape
(20, 100)

The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.initialize for more information).

  • None parameter

Some types of parameter variables can also be set to None at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:

l = bp.nn.Dense(num_unit=100, input_shape=20, bias_initializer=None)
l.initialize(num_batch=1)

print(l.bias)
None

Calling the node#

Instantiating a node builds an input-to-output function mapping. To get the mapping's output, you can directly call the created node.

l = bp.nn.Dense(num_unit=10, input_shape=20)
l.initialize()
l(bm.random.random((1, 20)))
JaxArray([[ 0.7788163 ,  0.6352515 ,  0.9846623 ,  0.97518134,
           -1.0947354 ,  0.29821265, -0.9927582 , -0.00511351,
            0.6623081 ,  0.72418994]], dtype=float32)
l(bm.random.random((2, 20)))
JaxArray([[ 0.21428639,  0.5546448 ,  0.5172446 ,  1.2533414 ,
           -0.54073226,  0.6578476 , -0.31080672,  0.25883573,
           -0.0466502 ,  0.50195456],
          [ 0.91855824,  0.503054  ,  1.1109638 ,  0.707477  ,
           -0.8442794 , -0.12064239, -0.81839114, -0.2828313 ,
           -0.660355  ,  0.20748737]], dtype=float32)
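
Note that the leading axis of the input is the batch axis: as shown above, the same layer can be called with inputs of different batch sizes.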

Moreover, the created model can also be JIT-compiled.

jit_l = bm.jit(l)
%timeit l(bm.random.random((2, 20)))
2.34 ms ± 370 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jit_l(bm.random.random((2, 20)))
2.04 ms ± 54.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

trainable settings#

A node can easily be set as trainable or non-trainable. This is controlled by the trainable argument when creating the node.

For example, for a non-trainable dense layer, the weights and bias are JaxArray instances.

l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=False)
l.initialize(num_batch=1)

l.Wff
JaxArray([[ 0.56564915, -0.70626205,  0.03569109],
          [-0.10908064, -0.63869774, -0.37541717],
          [-0.80857176,  0.22993006,  0.02752776],
          [ 0.32151228, -0.45234612,  0.9239818 ]], dtype=float32)

When creating a layer with trainable=True, TrainVar instances will be created for its parameters and initialized automatically. For example:

l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=True)
l.initialize(num_batch=1)

l.Wff
TrainVar([[-0.20390746,  0.7101851 , -0.2881384 ],
          [ 0.07779109, -1.1979834 ,  0.09109607],
          [-0.41889605,  0.3983429 , -1.1674007 ],
          [-0.14914905, -1.1085916 , -0.10857478]], dtype=float32)
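
All TrainVar instances of a trainable node can be gathered at once, for example to hand them to an optimizer. A minimal sketch, assuming the train_vars() method that nodes inherit from brainpy.Base:

# Collect every trainable variable of the node into a
# dict-like collector keyed by variable name.
l.train_vars()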

Moreover, for a subclass of brainpy.nn.RecurrentNode, the state can be set as trainable or non-trainable via the state_trainable argument. When state_trainable=True is set for an instance of brainpy.nn.RecurrentNode, a new attribute .train_state will be created.

rnn = bp.nn.VanillaRNN(3, input_shape=(1,), state_trainable=True)
rnn.initialize(3)

rnn.train_state
TrainVar([0.7986958 , 0.3421112 , 0.24420719], dtype=float32)

Note the differences between .train_state and the original .state:

  1. .train_state has no batch axis.

  2. When calling node.init_state() or node.initialize(), all values in .state will be filled with the value of .train_state.

rnn.state
Variable([[0.7986958 , 0.3421112 , 0.24420719],
          [0.7986958 , 0.3421112 , 0.24420719],
          [0.7986958 , 0.3421112 , 0.24420719]], dtype=float32)
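
Because re-initialization refills .state from .train_state, updating the trainable state propagates to every batch row. A hypothetical sketch (the in-place .value assignment and the num_batch argument of init_state() are assumptions, not confirmed API):

rnn.train_state.value = bm.zeros(3)  # assumed in-place update of the TrainVar
rnn.init_state(num_batch=3)          # assumed keyword; refills .state
rnn.state                            # expected: three identical rows of zeros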