Node Specification#
Neural networks in BrainPy are used to build dynamical systems. The brainpy.nn module provides various classes representing the nodes of a neural network. All of them are subclasses of the brainpy.nn.Node base class.
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
What is a node?#
In BrainPy, the Node instance is the basic element used to build a network model. It is a unit on a graph, connected to other nodes by edges.
In general, each Node instance in BrainPy has four components:
Feedforward inputs
Feedback inputs
State
Output
It is worth noting that each Node instance may have multiple feedforward or feedback connections. However, it has only one state and one output. The output component is used in both feedforward and feedback connections, which means the feedforward and feedback outputs are the same. However, customizing a different feedback output is also easy (see the Customization of a Node tutorial).
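To make this anatomy concrete, here is a minimal pure-Python sketch of a node with these four components. This is a conceptual illustration only, not BrainPy's actual Node implementation; the class and argument names are hypothetical:

```python
# Conceptual sketch of a node's anatomy; NOT BrainPy's real Node class.
class SketchNode:
    def __init__(self):
        self.state = 0.0  # exactly one internal state

    def forward(self, ff_inputs, fb_inputs=()):
        # A node may receive multiple feedforward and feedback inputs...
        total = sum(ff_inputs) + sum(fb_inputs)
        # ...but it holds exactly one state and produces one output.
        self.state = self.state + total
        output = self.state
        # The same output is sent along both feedforward and feedback edges.
        return output

node = SketchNode()
print(node.forward(ff_inputs=[1.0, 2.0]))              # 3.0
print(node.forward(ff_inputs=[1.0], fb_inputs=[0.5]))  # 4.5
```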
Each node has the following attributes:
feedforward_shapes: the shapes of the feedforward inputs.
feedback_shapes: the shapes of the feedback inputs.
output_shape: the output shape of the node.
state: the state of the node. It can be None if the node has no state to hold.
fb_output: the feedback output of the node. It is None when no feedback connections are established to this node. By default, the value of fb_output is the output of the forward() function.
It also has several boolean attributes:
trainable: whether the node is trainable.
is_initialized: whether the node has been initialized.
Creating a node#
A layer can be created as an instance of a brainpy.nn.Node subclass. For example, a dense layer can be created as follows:
bp.nn.Dense(num_unit=100)
Dense(name=Dense0, forwards=None,
feedbacks=None, output=(None, 100))
This will create a dense layer with 100 units.
Of course, if you already know the shape of the feedforward input, you can specify it with input_shape.
bp.nn.Dense(num_unit=100, input_shape=128)
Dense(name=Dense1, forwards=((None, 128),),
feedbacks=None, output=(None, 100))
This creates a densely connected layer that receives a 128-dimensional input.
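Under the hood, a dense layer of this kind computes an affine map y = x·W + b. The shape bookkeeping can be sketched with NumPy (illustrative only; this is not BrainPy's implementation, which stores the feedforward weight as Wff):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((1, 128))             # one batch of 128-dimensional input
W = rng.normal(0, 0.01, (128, 100))  # feedforward weight, shape (in, out)
b = np.zeros(100)                    # one bias per output unit

y = x @ W + b                        # affine transformation
print(y.shape)                       # (1, 100)
```

The leading axis is the batch axis, which is why the reprs above show shapes like (None, 128) and (None, 100): the batch size is left unspecified until the node is called.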
Naming a node#
For convenience, you can name a layer by specifying the name keyword argument:
bp.nn.Dense(num_unit=100, input_shape=128, name='hidden_layer')
Dense(name=hidden_layer, forwards=((None, 128),),
feedbacks=None, output=(None, 100))
Initializing parameters#
Many nodes have their own parameters, which can be set in the following ways.
Tensors
If a tensor variable instance is provided, it is used unchanged as the parameter variable. For example:
l = bp.nn.Dense(num_unit=50, input_shape=10,
weight_initializer=bm.random.normal(0, 0.01, size=(10, 50)))
l.initialize(num_batch=1)
l.Wff.shape
(10, 50)
Callable function
If a callable function (which receives a shape argument) is provided, it will be called with the desired shape to generate suitable initial parameter values. The variable is then initialized with those values. For example:
def init(shape):
return bm.random.random(shape)
l = bp.nn.Dense(num_unit=30, input_shape=20, weight_initializer=init)
l.initialize(num_batch=1)
l.Wff.shape
(20, 30)
Instance of brainpy.init.Initializer
If a brainpy.init.Initializer
instance is provided, the initial parameter values will be generated with the desired shape by using the Initializer instance. For example:
l = bp.nn.Dense(num_unit=100, input_shape=20,
weight_initializer=bp.init.Normal(0.01))
l.initialize(num_batch=1)
l.Wff.shape
(20, 100)
The weight matrix \(W\) of this dense layer will be initialized using samples from a normal distribution with standard deviation 0.01 (see brainpy.initialize for more information).
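The behavior of such an initializer can be approximated with NumPy (an assumption for illustration, not BrainPy's actual sampling code): it draws values with mean 0 and standard deviation 0.01 in the requested shape.

```python
import numpy as np

# Approximate what a Normal(0.01) initializer produces for a (20, 100)
# weight matrix: zero-mean samples with standard deviation 0.01.
rng = np.random.default_rng(42)
W = rng.normal(loc=0.0, scale=0.01, size=(20, 100))

print(W.shape)  # (20, 100)
print(W.std())  # close to 0.01
```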
None parameter
Some types of parameter variables can also be set to None
at initialization (e.g. biases). In that case, the parameter variable will be omitted. For example, creating a dense layer without biases is done as follows:
l = bp.nn.Dense(num_unit=100, input_shape=20, bias_initializer=None)
l.initialize(num_batch=1)
print(l.bias)
None
Calling the node#
Instantiating a node builds an input-to-output function mapping. To get the mapping output, you can directly call the created node.
l = bp.nn.Dense(num_unit=10, input_shape=20)
l.initialize()
l(bm.random.random((1, 20)))
JaxArray([[ 0.7788163 , 0.6352515 , 0.9846623 , 0.97518134,
-1.0947354 , 0.29821265, -0.9927582 , -0.00511351,
0.6623081 , 0.72418994]], dtype=float32)
l(bm.random.random((2, 20)))
JaxArray([[ 0.21428639, 0.5546448 , 0.5172446 , 1.2533414 ,
-0.54073226, 0.6578476 , -0.31080672, 0.25883573,
-0.0466502 , 0.50195456],
[ 0.91855824, 0.503054 , 1.1109638 , 0.707477 ,
-0.8442794 , -0.12064239, -0.81839114, -0.2828313 ,
-0.660355 , 0.20748737]], dtype=float32)
Moreover, the created node can also be JIT-compiled.
jit_l = bm.jit(l)
%timeit l(bm.random.random((2, 20)))
2.34 ms ± 370 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jit_l(bm.random.random((2, 20)))
2.04 ms ± 54.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
trainable settings#
Setting a node to be trainable or non-trainable is easy: it is controlled by the trainable argument when initializing the node.
For example, for a non-trainable dense layer, the weights and bias are JaxArray instances.
l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=False)
l.initialize(num_batch=1)
l.Wff
JaxArray([[ 0.56564915, -0.70626205, 0.03569109],
[-0.10908064, -0.63869774, -0.37541717],
[-0.80857176, 0.22993006, 0.02752776],
[ 0.32151228, -0.45234612, 0.9239818 ]], dtype=float32)
When creating a layer with trainable=True, TrainVar instances will be created for the parameters and initialized automatically. For example:
l = bp.nn.Dense(num_unit=3, input_shape=4, trainable=True)
l.initialize(num_batch=1)
l.Wff
TrainVar([[-0.20390746, 0.7101851 , -0.2881384 ],
[ 0.07779109, -1.1979834 , 0.09109607],
[-0.41889605, 0.3983429 , -1.1674007 ],
[-0.14914905, -1.1085916 , -0.10857478]], dtype=float32)
Moreover, for a subclass of brainpy.nn.RecurrentNode, the state can be set to be trainable or not through the state_trainable argument. When state_trainable=True is set for an instance of brainpy.nn.RecurrentNode, a new attribute .train_state will be created.
rnn = bp.nn.VanillaRNN(3, input_shape=(1,), state_trainable=True)
rnn.initialize(3)
rnn.train_state
TrainVar([0.7986958 , 0.3421112 , 0.24420719], dtype=float32)
Note the differences between .train_state and the original .state:
.train_state has no batch axis.
When using the node.init_state() or node.initialize() functions, all values in .state will be filled with .train_state.
rnn.state
Variable([[0.7986958 , 0.3421112 , 0.24420719],
[0.7986958 , 0.3421112 , 0.24420719],
[0.7986958 , 0.3421112 , 0.24420719]], dtype=float32)
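The filling behavior shown above can be sketched with NumPy: the trainable state, which has no batch axis, is copied along a new batch axis whenever the state is (re)initialized. This is an illustrative sketch, not BrainPy's internal code:

```python
import numpy as np

train_state = np.array([0.79, 0.34, 0.24])  # shape (3,): no batch axis
num_batch = 3

# init_state()/initialize() conceptually fill every row of the batched
# state with a copy of train_state.
state = np.tile(train_state, (num_batch, 1))
print(state.shape)                    # (3, 3)
print(bool(np.all(state == train_state)))  # True: every row equals train_state
```

Because gradients flow into the single .train_state rather than the batched .state, the learned initial state is shared across all batch elements.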