BrainPyObject and Collector#
In this section, we are going to talk about:
- The BrainPyObject class for the BrainPy ecosystem.
- The Collector to facilitate variable collection and manipulation.
import brainpy as bp
import brainpy.math as bm
# bm.set_platform('cpu')
brainpy.BrainPyObject#
The foundation of BrainPy is brainpy.BrainPyObject. A BrainPyObject instance is an object that has variables and methods. All methods in a BrainPyObject can be JIT compiled or automatically differentiated. In other words, any class whose instances will be JIT compiled or automatically differentiated must inherit from brainpy.BrainPyObject.
A BrainPyObject object can have many variables, children BrainPyObject objects, integrators, and methods. Below is the implementation of a FitzHugh-Nagumo neuron model as an example.
class FHN(bp.BrainPyObject):
    def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FHN, self).__init__(name=name)

        # parameters
        self.num = num
        self.a = a
        self.b = b
        self.tau = tau
        self.Vth = Vth

        # variables
        self.V = bm.Variable(bm.zeros(num))
        self.w = bm.Variable(bm.zeros(num))
        self.spike = bm.Variable(bm.zeros(num, dtype=bool))

        # integral
        self.integral = bp.odeint(method='rk4', f=self.derivative)

    def derivative(self, V, w, t, Iext):
        dw = (V + self.a - self.b * w) / self.tau
        dV = V - V * V * V / 3 - w + Iext
        return dV, dw

    def update(self, _t, _dt, x):
        V, w = self.integral(self.V, self.w, _t, x)
        self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
        self.w[:] = w
        self.V[:] = V
Note that this model has three variables: self.V, self.w, and self.spike. It also has an integrator self.integral.
The naming system#
Every BrainPyObject object has a unique name. Users can specify a unique name when instantiating a BrainPyObject class. Using an already-taken name will cause an error.
FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)
UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x0000013BEA0DDF10> has a used name "Y".
If you try to run multiple trials, you may need
>>> brainpy.base.clear_name_cache()
to clear all cached names.
If a name is not specified for the BrainPyObject object, BrainPy will assign a name to this object automatically. The rule for generating object names is class_name + number_of_instances. For example, FHN0, FHN1, etc.
FHN(10).name
'FHN0'
FHN(10).name
'FHN1'
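The automatic-naming rule above can be sketched in plain Python. This is a simplified illustration of the `class_name + number_of_instances` convention, not BrainPy's actual implementation; `get_unique_name` and `_name_counts` are hypothetical names:

```python
# A simplified sketch of BrainPy's automatic-naming rule:
# every class has its own counter, and an unnamed instance
# receives "<class_name><number_of_instances>" as its name.
_name_counts = {}

def get_unique_name(cls_name):
    """Return the next auto-generated name for the given class."""
    count = _name_counts.get(cls_name, 0)
    _name_counts[cls_name] = count + 1
    return f'{cls_name}{count}'

print(get_unique_name('FHN'))  # FHN0
print(get_unique_name('FHN'))  # FHN1
```

Clearing the counter dictionary would correspond to `brainpy.base.clear_name_cache()` in this sketch.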
Therefore, in BrainPy, you can access any object by its unique name, no matter how insignificant this object is.
Collection functions#
Two important collection functions are implemented for each BrainPyObject object. Specifically, they are:
- nodes(): to collect all instances of BrainPyObject objects, including children nodes in a node.
- vars(): to collect all variables defined in the BrainPyObject node and in its children nodes.
fhn = FHN(10)
All variables in a BrainPyObject object can be collected through BrainPyObject.vars(). The returned container is an ArrayCollector (a subclass of Collector).
vars = fhn.vars()
vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
False], dtype=bool),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
type(vars)
brainpy.base.collector.ArrayCollector
All nodes in the model can also be collected through the method BrainPyObject.nodes(). The returned container is an instance of Collector.
nodes = fhn.nodes()
nodes # note: integrator is also a node
{'FHN2': <__main__.FHN at 0x13bea0ddb20>,
'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea1200a0>}
type(nodes)
brainpy.base.collector.Collector
All integrators can be collected by:
ints = fhn.nodes().subset(bp.integrators.Integrator)
ints
{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea1200a0>}
type(ints)
brainpy.base.collector.Collector
Now, let’s make a more complicated model by using the previously defined model FHN.
class FeedForwardCircuit(bp.BrainPyObject):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        self.conn = bm.ones((num1, num2), dtype=bool) * w
        bm.fill_diagonal(self.conn, 0.)

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)
This model FeedForwardCircuit defines two layers. Each layer is modeled as a FitzHugh-Nagumo model (FHN). The first layer is densely connected to the second layer. The input to the second layer is the product of the first layer’s spikes and the connection strength w.
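The dense projection in update() (`self.pre.spike @ self.conn`) is an ordinary matrix-vector product. A plain-Python sketch with hypothetical numbers (`project` is an illustrative helper, not a BrainPy function):

```python
# Sketch of the input to the second layer: the boolean spike
# vector of the pre layer multiplied by the connection matrix,
# i.e. x2[j] = sum_i spikes[i] * conn[i][j].
def project(spikes, conn):
    """Dense projection of a boolean spike vector through conn."""
    num_post = len(conn[0])
    return [sum(conn[i][j] for i, s in enumerate(spikes) if s)
            for j in range(num_post)]

spikes = [True, False, True]   # 3 pre neurons, 2 of them fired
conn = [[0.1, 0.1],            # 3x2 connection matrix, weight w=0.1
        [0.1, 0.1],
        [0.1, 0.1]]
print(project(spikes, conn))   # [0.2, 0.2]
```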
net = FeedForwardCircuit(8, 5)
We can retrieve all variables by .vars():
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
And retrieve all nodes (instances of the BrainPyObject class) by .nodes():
net.nodes()
{'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x13bea120dc0>,
'FHN3': <__main__.FHN at 0x13bea139eb0>,
'FHN4': <__main__.FHN at 0x13bea120430>,
'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}
If we only care about a subtype of class, we can retrieve them through:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}
Absolute paths#
It is worth noting that there are two ways to access variables, integrators, and nodes: “absolute” paths and “relative” paths. The default way is the absolute path.
For absolute paths, all keys in the resulting Collector (BrainPyObject.nodes()) have the format of key = node_name [+ field_name].
.nodes() example 1: In the above fhn instance, there are two nodes: “fhn” and its integrator “fhn.integral”.
fhn.integral.name, fhn.name
('RK44', 'FHN2')
Calling .nodes() returns their names and models.
fhn.nodes().keys()
dict_keys(['FHN2', 'RK44'])
.nodes() example 2: In the above net instance, there are five nodes:
net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')
Calling .nodes() also returns the names and instances of all models.
net.nodes().keys()
dict_keys(['FeedForwardCircuit0', 'FHN3', 'FHN4', 'RK45', 'RK46'])
.vars() example 1: In the above fhn instance, there are three variables: “V”, “w”, and “spike”. Calling .vars() returns a dict of <node_name + var_name, var_value>.
fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
False], dtype=bool),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
.vars() example 2: This also applies to the net instance:
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Relative paths#
Variables, integrators, and nodes can also be accessed by relative paths. For example, the pre instance in the net can be accessed by
net.pre
<__main__.FHN at 0x13bea139eb0>
Relative paths preserve the dependence relationship. For example, all nodes retrieved from the perspective of net are:
net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x13bea120dc0>,
'pre': <__main__.FHN at 0x13bea139eb0>,
'post': <__main__.FHN at 0x13bea120430>,
'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}
However, nodes retrieved from the start point of net.pre will be:
net.pre.nodes('relative')
{'': <__main__.FHN at 0x13bea139eb0>,
'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>}
Variables can also be relatively inferred from the model. For example, variables that can be relatively accessed from net include:
net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'pre.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'post.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'post.spike': Variable([False, False, False, False, False], dtype=bool),
'post.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
While variables relatively accessed from net.post are:
net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'spike': Variable([False, False, False, False, False], dtype=bool),
'w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Elements in containers#
One drawback of the collection functions is that they do not look for elements inside a list, dict, or any other container structure.
class ATest(bp.BrainPyObject):
    def __init__(self):
        super(ATest, self).__init__()
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6))]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()
The above class defines a list of variables and a dict of children nodes, but these variables and children nodes cannot be retrieved by the collection functions vars() and nodes().
t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x13bea269d00>}
To solve this problem, BrainPy provides implicit_vars and implicit_nodes (instances of “dict”) to hold variables and nodes in container structures. Variables registered in implicit_vars, as well as integrators and nodes registered in implicit_nodes, can be retrieved by the collection functions.
class AnotherTest(bp.BrainPyObject):
    def __init__(self):
        super(AnotherTest, self).__init__()
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6))]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})  # the input must be a dict
        self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})  # the input must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator.
# Therefore, there are five BrainPyObject objects.
t2.nodes()
{'AnotherTest0': <__main__.AnotherTest at 0x13bea271670>,
'T1': <__main__.FHN at 0x13bea271b80>,
'T2': <__main__.FHN at 0x13bea249b80>,
'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea271be0>,
'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea25d700>}
# This model has two FHN nodes, each of which has three variables.
# Moreover, this model has two implicit variables.
t2.vars()
{'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.], dtype=float32),
'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.], dtype=float32),
'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'T1.spike': Variable([False, False, False, False, False, False, False, False, False,
False], dtype=bool),
'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'T2.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
'T2.spike': Variable([False, False, False, False, False], dtype=bool),
'T2.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
Saving and loading#
Because BrainPyObject.vars() returns a Python dictionary object (a Collector), it can be easily saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each BrainPyObject object has standard exporting and loading methods (for more details, please see Saving and Loading). Specifically, they are implemented by BrainPyObject.save_states() and BrainPyObject.load_states().
Save#
BrainPyObject.save_states(PATH, [vars])
Models exported from BrainPy support various Python standard file formats, including:
- HDF5: .h5, .hdf5
- .npz (NumPy file format)
- .pkl (Python’s pickle utility)
- .mat (Matlab file format)
net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
Load#
BrainPyObject.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
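The save/load round trip can be illustrated with Python’s pickle on a plain dict of variable values. This is a simplified sketch, not what save_states()/load_states() actually do internally; the state keys and values here are made-up examples:

```python
import os
import pickle
import tempfile

# A made-up "state dict" standing in for the result of net.vars().
state = {'FHN3.V': [0.0] * 8, 'FHN3.w': [0.0] * 8}

path = os.path.join(tempfile.mkdtemp(), 'net.pkl')
with open(path, 'wb') as f:
    pickle.dump(state, f)       # analogue of save_states(path)

with open(path, 'rb') as f:
    restored = pickle.load(f)   # analogue of load_states(path)

print(restored == state)        # True: the round trip preserves values
```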
Collector#
Collection functions return a brainpy.Collector, which is a dictionary mapping names to elements. It has some useful methods.
subset()#
Collector.subset(cls) returns the subset of elements whose type matches the given cls. For example, BrainPyObject.nodes() returns all instances of the BrainPyObject class. If you are only interested in one type, like ODEIntegrator, you can use:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}
Actually, Collector.subset(cls) traverses all the elements in the collection and finds the elements whose type matches the given cls.
unique()#
It is common in machine learning that weights are shared among several objects, or that the same weight can be accessed through various dependence relationships. The collection functions of BrainPyObject usually return a collection in which the same value has multiple keys; the duplicate elements are not automatically excluded. However, it is important not to apply operations such as gradient descent twice or more to the same elements.
Therefore, the Collector provides Collector.unique() to handle this problem automatically: it returns a copy of the collection in which all elements are unique.
class ModelA(bp.BrainPyObject):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))


class SharedA(bp.BrainPyObject):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable


class Group(bp.BrainPyObject):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)
g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
'A_shared.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
'A_shared.source.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.vars('relative').unique() # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.nodes('relative') # "ModelA" is accessed twice
{'': <__main__.Group at 0x13bea5c6f40>,
'A': <__main__.ModelA at 0x13bea5c82e0>,
'A_shared': <__main__.SharedA at 0x13bea5c8550>,
'A_shared.source': <__main__.ModelA at 0x13bea5c82e0>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x13bea5c6f40>,
'A': <__main__.ModelA at 0x13bea5c82e0>,
'A_shared': <__main__.SharedA at 0x13bea5c8550>}
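The deduplication performed by unique() can be sketched with object identity (id()). This is a simplified illustration, not BrainPy's implementation; `unique` here is a standalone function:

```python
# Sketch of Collector.unique(): keep the first key seen for each
# distinct value, comparing values by object identity rather than
# by equality, so shared objects are returned only once.
def unique(collection):
    seen = set()
    result = {}
    for key, value in collection.items():
        if id(value) not in seen:
            seen.add(id(value))
            result[key] = value
    return result

shared = [0.0] * 5                       # one object reachable by two paths
col = {'A.a': shared, 'A_shared.a': shared}
print(unique(col))                       # only the first path 'A.a' survives
```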
update()#
The Collector can also catch potential conflicts during assignment. Both bracket assignment ([key]) and Collector.update() check whether the same key is being mapped to a different value. If it is, an error will occur.
tc = bp.Collector({'a': bm.zeros(10)})
tc
{'a': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
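The conflict check can be sketched as a dict subclass that refuses to rebind an existing key to a different object. This is a simplified illustration of the behavior, not the Collector's actual code; `CheckedDict` is a hypothetical class:

```python
# Sketch of the Collector's conflict check: assigning an existing
# key to a *different* object raises, while re-assigning the very
# same object is allowed.
class CheckedDict(dict):
    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts: '
                             f'same name for two different values.')
        super().__setitem__(key, value)

    def update(self, other):
        # Route update() through the checked item assignment.
        for k, v in other.items():
            self[k] = v

d = CheckedDict()
d['a'] = [0.0]
try:
    d['a'] = [1.0]  # same key "a", different object
except ValueError as e:
    print(type(e).__name__, ':', e)
```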
replace()#
Collector.replace(old_key, new_value) is used to update the value of a key.
tc
{'a': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
tc.replace('a', bm.ones(3))
tc
{'a': Array([1., 1., 1.], dtype=float32)}
__add__()#
Two Collectors can be merged.
a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})
a + b
{'a': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}
ArrayCollector#
ArrayCollector is a subclass of Collector, but it is specifically designed to collect arrays.