Base Class#
In this section, we are going to talk about:

- The Base class for the BrainPy ecosystem.
- The Collector to facilitate variable collection and manipulation.
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
brainpy.Base#
The foundation of BrainPy is brainpy.Base. A Base instance is an object that has variables and methods. All methods in a Base object can be JIT compiled or automatically differentiated. In other words, any class whose instances will be JIT compiled or automatically differentiated must inherit from brainpy.Base.
A Base object can have many variables, children Base objects, integrators, and methods. Below is an implementation of a FitzHugh-Nagumo neuron model as an example.
class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integral
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    V, w = self.integral(self.V, self.w, _t, x)
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V
Note this model has three variables: self.V, self.w, and self.spike. It also has an integrator self.integral.
The naming system#
Every Base object has a unique name. Users can specify a unique name when instantiating a Base class. Using a duplicate name will cause an error.
FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
  FHN(10, name='Y').name
except Exception as e:
  print(type(e).__name__, ':', e)
If a name is not specified for the Base object, BrainPy will automatically assign a name to it. The rule for generating object names is class_name + number_of_instances. For example, FHN0, FHN1, etc.
FHN(10).name
'FHN0'
FHN(10).name
'FHN1'
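The auto-naming rule can be sketched in plain Python. This is a simplified illustration of the class_name + instance_count scheme, not BrainPy's actual implementation:

```python
class AutoNamed:
  # per-class instance counters (simplified sketch, not BrainPy's code)
  _counts = {}

  def __init__(self, name=None):
    cls_name = type(self).__name__
    if name is None:
      # generate "ClassName0", "ClassName1", ...
      count = AutoNamed._counts.get(cls_name, 0)
      name = f'{cls_name}{count}'
      AutoNamed._counts[cls_name] = count + 1
    self.name = name

class FHN(AutoNamed):
  pass

print(FHN().name)  # FHN0
print(FHN().name)  # FHN1
```

A user-supplied name simply bypasses the counter; BrainPy additionally checks it for uniqueness.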
Therefore, in BrainPy, you can access any object by its unique name, no matter how deeply it is nested.
Collection functions#
Two important collection functions are implemented for each Base object. Specifically, they are:

- nodes(): to collect all instances of Base objects, including children nodes in a node.
- vars(): to collect all variables defined in the Base node and in its children nodes.
fhn = FHN(10)
All variables in a Base object can be collected through Base.vars(). The returned container is a TensorCollector (a subclass of Collector).
vars = fhn.vars()
vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
False, False], dtype=bool),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
type(vars)
brainpy.base.collector.TensorCollector
All nodes in the model can also be collected through the method Base.nodes(). The returned container is an instance of Collector.
nodes = fhn.nodes()
nodes # note: integrator is also a node
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>,
'FHN2': <__main__.FHN at 0x2155a7a65e0>}
type(nodes)
brainpy.base.collector.Collector
All integrators can be collected by:
ints = fhn.nodes().subset(bp.integrators.Integrator)
ints
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>}
type(ints)
brainpy.base.collector.Collector
Now, let's make a more complicated model by using the previously defined model FHN.
class FeedForwardCircuit(bp.Base):
  def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FeedForwardCircuit, self).__init__(name=name)
    self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
    self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
    self.conn = bm.ones((num1, num2), dtype=bool) * w
    bm.fill_diagonal(self.conn, 0.)

  def update(self, _t, _dt, x):
    self.pre.update(_t, _dt, x)
    x2 = self.pre.spike @ self.conn
    self.post.update(_t, _dt, x2)
This model, FeedForwardCircuit, defines two layers, each modeled as a FitzHugh-Nagumo neuron group (FHN). The first layer is densely connected to the second layer. The input to the second layer is the matrix product of the first layer's spikes and the connection matrix with strength w.
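The input computation in update() can be sketched in plain Python with toy sizes (an illustration of the `spike @ conn` product, not BrainPy code):

```python
# toy sizes: 3 pre neurons, 2 post neurons, connection strength w
num1, num2, w = 3, 2, 0.1
spikes = [True, False, True]  # pre-layer spikes

# dense connection matrix with the "diagonal" zeroed, like fill_diagonal
conn = [[0.0 if i == j else w for j in range(num2)] for i in range(num1)]

# matrix product: each post neuron sums the weights of spiking pre neurons
x2 = [sum(w_ij for s, w_ij in zip(spikes, col) if s) for col in zip(*conn)]
print(x2)  # [0.1, 0.2]
```

Post neuron 0 receives input only from pre neuron 2 (its weight from pre neuron 0 is zeroed), while post neuron 1 receives input from both spiking pre neurons.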
net = FeedForwardCircuit(8, 5)
We can retrieve all variables by .vars():
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
And retrieve all nodes (instances of the Base class) by .nodes():
net.nodes()
{'FHN3': <__main__.FHN at 0x2155ace3130>,
'FHN4': <__main__.FHN at 0x2155a798d30>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>,
'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x2155a743c40>}
If we only care about one subtype of Base, we can retrieve those nodes through:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}
Absolute paths#
It is worth noting that there are two ways to access variables, integrators, and nodes: "absolute" paths and "relative" paths. The default is the absolute path.
For absolute paths, all keys in the resulting Collector (e.g., from Base.nodes()) have the format key = node_name [+ field_name].
.nodes() example 1: In the above fhn instance, there are two nodes: "fhn" and its integrator "fhn.integral".
fhn.integral.name, fhn.name
('RK45', 'FHN2')
Calling .nodes() returns their names and models.
fhn.nodes().keys()
dict_keys(['RK45', 'FHN2'])
.nodes() example 2: In the above net instance, there are five nodes:
net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0')
Calling .nodes() also returns the names and instances of all models.
net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0'])
.vars() example 1: In the above fhn instance, there are three variables: "V", "w", and "spike". Calling .vars() returns a dict of <node_name + var_name, var_value>.
fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
False, False], dtype=bool),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
.vars() example 2: This also applies to the net instance:
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Relative paths#
Variables, integrators, and nodes can also be accessed by relative paths. For example, the pre instance in the net can be accessed by
net.pre
<__main__.FHN at 0x2155ace3130>
Relative paths preserve the dependence relationships. For example, all nodes retrieved from the perspective of net are:
net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x2155a743c40>,
'pre': <__main__.FHN at 0x2155ace3130>,
'post': <__main__.FHN at 0x2155a798d30>,
'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}
However, nodes retrieved from the starting point of net.pre will be:
net.pre.nodes('relative')
{'': <__main__.FHN at 0x2155ace3130>,
'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>}
Variables can also be inferred relatively from a model. For example, variables that can be relatively accessed from net include:
net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'pre.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'post.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'post.spike': Variable([False, False, False, False, False], dtype=bool),
'post.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
While variables relatively accessed from net.post are:
net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'spike': Variable([False, False, False, False, False], dtype=bool),
'w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Elements in containers#
One drawback of the collection functions is that they do not look for elements inside a list, dict, or any other container structure.
class ATest(bp.Base):
  def __init__(self):
    super(ATest, self).__init__()
    self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6))]
    self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()
The above class defines a list of variables and a dict of children nodes, but these variables and children nodes cannot be retrieved by the collection functions vars() and nodes().
t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x2155ae60a00>}
To solve this problem, BrainPy provides implicit_vars and implicit_nodes (instances of "dict") to hold variables and nodes stored in container structures. Variables registered in implicit_vars, as well as integrators and nodes registered in implicit_nodes, can be retrieved by the collection functions.
class AnotherTest(bp.Base):
  def __init__(self):
    super(AnotherTest, self).__init__()
    self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6))]
    self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}

    # the input must be a dict
    self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})
    # the input must be a dict
    self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator.
# Therefore, there are five Base objects.
t2.nodes()
{'T1': <__main__.FHN at 0x2155ae6bca0>,
'T2': <__main__.FHN at 0x2155ae6b3a0>,
'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae6f8e0>,
'RK411': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae66610>,
'AnotherTest0': <__main__.AnotherTest at 0x2155ae6f0d0>}
# This model has two "FHN" nodes, each of which has three variables.
# Moreover, this model has two implicit variables.
t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'T1.spike': Variable([False, False, False, False, False, False, False, False,
False, False], dtype=bool),
'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'T2.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'T2.spike': Variable([False, False, False, False, False], dtype=bool),
'T2.w': Variable([0., 0., 0., 0., 0.], dtype=float32),
'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.], dtype=float32),
'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.], dtype=float32)}
Saving and loading#
Because Base.vars() returns a Python dictionary (a Collector), variables can be easily saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details, please see Saving and Loading). Specifically, they are implemented by Base.save_states() and Base.load_states().
Save#
Base.save_states(PATH, [vars])
Models exported from BrainPy support various standard file formats, including:

- .h5, .hdf5 (HDF5 file format)
- .npz (NumPy file format)
- .pkl (Python's pickle utility)
- .mat (Matlab file format)
net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
Load#
Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
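As a rough sketch of what exporting to .pkl looks like, the name-to-value dictionary returned by .vars() can be dumped and restored with Python's standard pickle module. This only mimics the idea; BrainPy's actual save_states/load_states dispatch on the file extension and handle Variable objects internally:

```python
import os
import pickle
import tempfile

# a stand-in for the name -> value dict returned by net.vars()
states = {'FHN3.V': [0.0] * 8, 'FHN3.w': [0.0] * 8}
path = os.path.join(tempfile.mkdtemp(), 'net.pkl')

# save: dump the name -> value mapping to disk
with open(path, 'wb') as f:
  pickle.dump(states, f)

# load: read the mapping back; values are then reassigned by name
with open(path, 'rb') as f:
  restored = pickle.load(f)

print(restored == states)  # True
```

Because keys are the unique absolute paths, a loaded dictionary can be matched back to the right variables by name alone.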
Collector#
Collection functions return a brainpy.Collector, which is a dictionary mapping names to elements. It has some useful methods.
subset()#
Collector.subset(cls) returns the subset of elements whose type matches the given cls. For example, Base.nodes() returns all instances of the Base class. If you are only interested in one type, such as ODEIntegrator, you can use:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}
Actually, Collector.subset(cls) traverses all elements in the collection and finds the elements whose type matches the given cls.
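That traversal can be sketched with a plain isinstance filter (a simplified stand-in for Collector.subset(), not the real implementation):

```python
class Collector(dict):
  def subset(self, cls):
    # keep only the entries whose value is an instance of `cls`
    return Collector({k: v for k, v in self.items() if isinstance(v, cls)})

c = Collector({'a': 1, 'b': 2.5, 'c': 'text'})
print(c.subset(float))  # {'b': 2.5}
```

Because isinstance() respects inheritance, passing a base class (like bp.integrators.Integrator) also matches all of its subclasses (like RK4).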
unique()#
It is common in machine learning that weights are shared among several objects, or that the same weight can be accessed through different dependence relationships. The collection functions of Base often return a collection in which the same value has multiple keys; duplicate elements are not automatically excluded. However, it is important not to apply operations such as gradient descent twice or more to the same elements.
Therefore, the Collector provides Collector.unique() to handle this problem automatically. Collector.unique() returns a copy of the collection in which all elements are unique.
class ModelA(bp.Base):
  def __init__(self):
    super(ModelA, self).__init__()
    self.a = bm.Variable(bm.zeros(5))

class SharedA(bp.Base):
  def __init__(self, source):
    super(SharedA, self).__init__()
    self.source = source
    self.a = source.a  # shared variable

class Group(bp.Base):
  def __init__(self):
    super(Group, self).__init__()
    self.A = ModelA()
    self.A_shared = SharedA(self.A)
g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
'A_shared.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
'A_shared.source.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.vars('relative').unique()  # only one unique path is returned
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.nodes('relative') # "ModelA" is accessed twice
{'': <__main__.Group at 0x2155b13b550>,
'A': <__main__.ModelA at 0x2155b13b460>,
'A_shared': <__main__.SharedA at 0x2155a7a6580>,
'A_shared.source': <__main__.ModelA at 0x2155b13b460>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x2155b13b550>,
'A': <__main__.ModelA at 0x2155b13b460>,
'A_shared': <__main__.SharedA at 0x2155a7a6580>}
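The deduplication above keys on object identity, not on the path string. A minimal sketch of that idea (an illustration, not BrainPy's actual code):

```python
class Collector(dict):
  def unique(self):
    # keep the first path encountered for each distinct object (by id)
    seen = set()
    out = Collector()
    for key, value in self.items():
      if id(value) not in seen:
        seen.add(id(value))
        out[key] = value
    return out

shared = [0.0] * 5  # one object reachable by two paths
c = Collector({'A.a': shared, 'A_shared.a': shared})
print(list(c.unique().keys()))  # ['A.a']
```

This is why a shared weight appears only once after unique(): both paths point to the same underlying object.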
update()#
The Collector can also catch potential conflicts during assignment. Both the bracket assignment of a Collector ([key]) and Collector.update() check whether the same key is mapped to a different value. If it is, an error will be raised.
tc = bp.Collector({'a': bm.zeros(10)})
tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
try:
  tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
  print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
  tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
  print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
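The conflict check can be sketched as a dict subclass that refuses to silently rebind an existing key to a different object (an illustration of the behavior, not BrainPy's code):

```python
class Collector(dict):
  def __setitem__(self, key, value):
    # reject rebinding an existing key to a different object
    if key in self and self[key] is not value:
      raise ValueError(f'Name "{key}" conflicts: same name for two values.')
    super().__setitem__(key, value)

  def update(self, other):
    for k, v in other.items():
      self[k] = v  # goes through the checking __setitem__

c = Collector({'a': [0.0] * 10})
try:
  c['a'] = [0.0]  # same key, different object
except ValueError as e:
  print(type(e).__name__)  # ValueError
```

Re-assigning the exact same object to an existing key is allowed; only a different object under an existing name is rejected.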
replace()#
Collector.replace(old_key, new_value) is used to update the value of an existing key.
tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
tc.replace('a', bm.ones(3))
tc
{'a': JaxArray([1., 1., 1.], dtype=float32)}
__add__()#
Two Collectors can be merged.
a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})
a + b
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'b': JaxArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}
TensorCollector#
TensorCollector is a subclass of Collector, but it is designed specifically to collect tensors.