Neural networks in BrainPy are used to build dynamical systems. The brainpy.nn module provides various classes representing the nodes of a neural network. All of them are subclasses of the
brainpy.nn.Node base class.
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
What is a node?#
In BrainPy, the
Node instance is the basic element to form a network model. It is a unit on a graph, connected to other nodes by edges.
In general, each
Node instance in BrainPy has four components:
It is worth noting that each
Node instance may have multiple feedforward or feedback connections; however, it has only one state and one output.
By default, the output component is used in both feedforward and feedback connections, which means the feedforward and feedback outputs are the same. However, customizing a different feedback output is also easy (see the Customization of a Node tutorial).
Each node has the following attributes:
feedforward_shapes: the shapes of the feedforward inputs.
feedback_shapes: the shapes of the feedback inputs.
output_shape: the output shape of the node.
state: the state of the node. It can be None if the node has no state to hold.
fb_output: the feedback output of the node. It is None when no feedback connections are established to this node. By default, it takes the output value of the
forward() function.
It also has several boolean attributes:
trainable: whether the node is trainable.
is_initialized: whether the node has been initialized.
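To make this bookkeeping concrete, here is a minimal framework-free sketch in plain Python. It is not BrainPy's implementation; the class name SketchNode and its constructor are hypothetical, and it only mirrors the attribute names listed above.

```python
# A minimal, hypothetical sketch of the attribute bookkeeping described above.
# This is NOT BrainPy's implementation; it only mirrors the attribute names.

class SketchNode:
    def __init__(self, num_unit, input_shape=None, trainable=False):
        # Shapes use None for the (unknown) batch dimension.
        self.feedforward_shapes = (
            None if input_shape is None else ((None, input_shape),)
        )
        self.feedback_shapes = None          # no feedback connections yet
        self.output_shape = (None, num_unit)
        self.state = None                    # None: no state to hold
        self.fb_output = None                # None until feedback is set up
        self.trainable = trainable
        self.is_initialized = False

    def initialize(self, num_batch=1):
        # A real node would build its parameters and state here.
        self.is_initialized = True
        return self

node = SketchNode(num_unit=100, input_shape=128).initialize(num_batch=1)
```

After initialization, the sketch reports the same kind of shape information a real node exposes: a feedforward shape of ((None, 128),) and an output shape of (None, 100).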
Creating a node#
A layer can be created as an instance of a
brainpy.nn.Node subclass. For example, a dense layer can be created as follows:

bp.nn.Dense(num_unit=100)

Dense(name=Dense0, forwards=None, feedbacks=None, output=(None, 100))
This will create a dense layer with 100 units.
Of course, if you already know the shapes of the feedforward connections, you can specify them via the input_shape argument:

bp.nn.Dense(num_unit=100, input_shape=128)

Dense(name=Dense1, forwards=((None, 128),), feedbacks=None, output=(None, 100))
This creates a densely connected layer that connects to an input layer with 128 dimensions.
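In terms of plain array arithmetic, such a layer maps a (batch, 128) input to a (batch, 100) output through a 128-by-100 weight matrix. A framework-free NumPy sketch (the names W, b, and x are illustrative, not BrainPy's internals):

```python
import numpy as np

# Framework-free sketch of the mapping a Dense(num_unit=100, input_shape=128)
# layer performs; variable names are illustrative, not BrainPy's internals.
rng = np.random.default_rng(0)
W = rng.normal(size=(128, 100))   # feedforward weight matrix
b = np.zeros(100)                 # bias vector
x = rng.normal(size=(1, 128))     # one batch of feedforward input
y = x @ W + b                     # shape (1, 100), matching output=(None, 100)
```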
Naming a node#
For convenience, you can name a layer by specifying the name keyword argument:
bp.nn.Dense(num_unit=100, input_shape=128, name='hidden_layer')
Dense(name=hidden_layer, forwards=((None, 128),), feedbacks=None, output=(None, 100))
Many nodes have parameters. We can set the parameters of a node in the following ways.
If a tensor variable instance is provided, it is used unchanged as the parameter variable. For example:
l = bp.nn.Dense(num_unit=50, input_shape=10,
                weight_initializer=bm.random.normal(0, 0.01, size=(10, 50)))
l.initialize(num_batch=1)
l.Wff.shape
If a callable function (which receives a
shape argument) is provided, it will be called with the desired shape to generate suitable initial parameter values, and the variable is then initialized with those values. For example:
def init(shape):
    return bm.random.random(shape)

l = bp.nn.Dense(num_unit=30, input_shape=20, weight_initializer=init)
l.initialize(num_batch=1)
l.Wff.shape
If a brainpy.init.Initializer instance is provided, the initial parameter values will be generated with the desired shape by the Initializer instance. For example:
l = bp.nn.Dense(num_unit=100, input_shape=20,
                weight_initializer=bp.init.Normal(0.01))
l.initialize(num_batch=1)
l.Wff.shape
The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.initialize for more information).
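As a sanity check of what such an initializer produces, here is a framework-free NumPy sketch that draws values from a normal distribution with standard deviation 0.01. The helper name normal_init is hypothetical, a stand-in for the Initializer, not BrainPy's API:

```python
import numpy as np

# Hypothetical stand-in for an initializer like bp.init.Normal(0.01):
# it is called with the desired shape and returns suitable initial values.
def normal_init(shape, scale=0.01, seed=0):
    rng = np.random.default_rng(seed)
    return rng.normal(loc=0.0, scale=scale, size=shape)

W = normal_init((20, 100))  # same shape as the Wff of the layer above
```

The sample standard deviation of the 2000 generated weights lands close to the requested 0.01.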
Some types of parameter variables can also be set to
None at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:
l = bp.nn.Dense(num_unit=100, input_shape=20, bias_initializer=None)
l.initialize(num_batch=1)
print(l.bias)
Calling the node#
Instantiating a node builds an input-to-output function mapping. To get the mapped output, you can directly call the created node.
l = bp.nn.Dense(num_unit=10, input_shape=20)
l.initialize()

l(bm.random.random((1, 20)))

JaxArray([[ 0.7788163 , 0.6352515 , 0.9846623 , 0.97518134, -1.0947354 , 0.29821265, -0.9927582 , -0.00511351, 0.6623081 , 0.72418994]], dtype=float32)

l(bm.random.random((2, 20)))

JaxArray([[ 0.21428639, 0.5546448 , 0.5172446 , 1.2533414 , -0.54073226, 0.6578476 , -0.31080672, 0.25883573, -0.0466502 , 0.50195456], [ 0.91855824, 0.503054 , 1.1109638 , 0.707477 , -0.8442794 , -0.12064239, -0.81839114, -0.2828313 , -0.660355 , 0.20748737]], dtype=float32)
Moreover, JIT compilation of the created model is also supported.
jit_l = bm.jit(l)
%timeit l(bm.random.random((2, 20)))
2.34 ms ± 370 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jit_l(bm.random.random((2, 20)))
2.04 ms ± 54.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Setting a node to be trainable or non-trainable is easy: it is controlled by the
trainable argument when initializing the node.
For example, for a non-trainable dense layer, the weights and bias are JaxArray instances.
l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=False)
l.initialize(num_batch=1)
l.Wff
JaxArray([[ 0.56564915, -0.70626205, 0.03569109], [-0.10908064, -0.63869774, -0.37541717], [-0.80857176, 0.22993006, 0.02752776], [ 0.32151228, -0.45234612, 0.9239818 ]], dtype=float32)
When creating a layer with trainable=True,
TrainVar instances will be created for its parameters and initialized automatically. For example:
l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=True)
l.initialize(num_batch=1)
l.Wff
TrainVar([[-0.20390746, 0.7101851 , -0.2881384 ], [ 0.07779109, -1.1979834 , 0.09109607], [-0.41889605, 0.3983429 , -1.1674007 ], [-0.14914905, -1.1085916 , -0.10857478]], dtype=float32)
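The practical difference is that a trainer collects only the variables marked as trainable. A framework-free sketch of that filtering, where the hypothetical Param class stands in for the TrainVar-versus-JaxArray distinction:

```python
import numpy as np

# Hypothetical stand-in for the TrainVar / JaxArray distinction:
# a trainer gathers only the parameters flagged as trainable.
class Param:
    def __init__(self, value, trainable=False):
        self.value = value
        self.trainable = trainable

params = {
    'Wff': Param(np.zeros((4, 3)), trainable=True),   # like a TrainVar
    'bias': Param(np.zeros(3), trainable=False),      # left untrained
}
to_train = [name for name, p in params.items() if p.trainable]
```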
Moreover, for a subclass of
brainpy.nn.RecurrentNode, the
state can be set to be trainable or non-trainable through the
state_trainable argument. When setting
state_trainable=True for an instance of
brainpy.nn.RecurrentNode, a new attribute .train_state will be created.
rnn = bp.nn.VanillaRNN(3, input_shape=(1,), state_trainable=True)
rnn.initialize(3)
rnn.train_state
TrainVar([0.7986958 , 0.3421112 , 0.24420719], dtype=float32)
Note the difference between the .train_state and the original .state:
.train_state has no batch axis.
When the node.initialize() function is called, all values in the .state will be filled with the .train_state:

rnn.state
Variable([[0.7986958 , 0.3421112 , 0.24420719], [0.7986958 , 0.3421112 , 0.24420719], [0.7986958 , 0.3421112 , 0.24420719]], dtype=float32)
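The filling step amounts to broadcasting the batch-free .train_state along a new batch axis. A framework-free NumPy sketch, with the values copied from the output above (the broadcasting call is an assumption about the mechanics, not BrainPy's code):

```python
import numpy as np

# .train_state has no batch axis ...
train_state = np.array([0.7986958, 0.3421112, 0.24420719])
num_batch = 3

# ... while .state repeats it once per batch element.
state = np.broadcast_to(train_state, (num_batch, train_state.size)).copy()
```

Each of the three rows of the resulting (3, 3) array is a copy of the trained state vector, matching the Variable shown above.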