Base Class
In this section, we are going to talk about:
- Base: the base class of the BrainPy ecosystem,
- Collector: a dictionary structure to facilitate variable collection and manipulation.
import brainpy as bp
import brainpy.math as bm
Base
The foundation of BrainPy is brainpy.Base. A Base instance is an object that has variables and methods. All methods in a Base object can be JIT compiled or automatically differentiated. In other words, any class whose methods are to be JIT compiled or automatically differentiated must inherit from brainpy.Base.
A Base object can have many variables, children Base objects, integrators, and methods. For example, let’s implement a FitzHugh-Nagumo neuron model.
class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integral
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    V, w = self.integral(self.V, self.w, _t, x)
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V
Note that this model has three variables: self.V, self.w, and self.spike. It also has an integrator self.integral.
Naming system
Every Base object has a unique name. You can specify a custom name when you instantiate a Base class; reusing a name that is already taken will raise an error.
FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
  FHN(10, name='Y').name
except Exception as e:
  print(type(e).__name__, ':', e)
UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x7f4a7406bd60> has a used name "Y".
When you instantiate a Base class without specifying a name, BrainPy automatically assigns one. The rule for generating the name is class_name + number_of_instances, e.g., FHN0, FHN1, etc.
FHN(10).name
'FHN0'
FHN(10).name
'FHN1'
Therefore, in BrainPy you can access any object by its unique name, no matter how deeply nested that object is.
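The naming rule can be illustrated with a plain-Python sketch. This only mimics the behavior described above; the class AutoNamed and its bookkeeping attributes are hypothetical, not BrainPy's actual implementation:

```python
class AutoNamed:
  # Illustrative stand-in for brainpy.Base's naming bookkeeping.
  _counts = {}         # class name -> number of auto-named instances so far
  _used_names = set()  # all names taken so far

  def __init__(self, name=None):
    cls_name = type(self).__name__
    if name is None:
      # Rule: class_name + number_of_instances, e.g. 'FHN0', 'FHN1', ...
      idx = AutoNamed._counts.get(cls_name, 0)
      AutoNamed._counts[cls_name] = idx + 1
      name = f'{cls_name}{idx}'
    elif name in AutoNamed._used_names:
      # A user-specified name must not be taken already.
      raise ValueError(f'"{name}" is a used name.')
    AutoNamed._used_names.add(name)
    self.name = name

class FHN(AutoNamed):
  pass
```

With this sketch, FHN() yields 'FHN0', a second FHN() yields 'FHN1', and FHN(name='X') followed by another FHN(name='X') raises an error, matching the behavior shown above.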
Collection functions
Three important collection functions are implemented for each Base object. Specifically, they are:
- nodes(): collects all instances of the Base class, including the children nodes of a node.
- ints(): collects all integrators defined in the Base node and in its children nodes.
- vars(): collects all variables defined in the Base node and in its children nodes.
All integrators can be collected through the method Base.ints(). The result container is a Collector.
fhn = FHN(10)
ints = fhn.ints()
ints
{'FHN2.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>}
type(ints)
brainpy.base.collector.Collector
Similarly, all variables in a Base object can be collected through Base.vars(). The returned container is a TensorCollector (a subclass of Collector).
vars = fhn.vars()
vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
False]),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
type(vars)
brainpy.base.collector.TensorCollector
All nodes in the model can also be collected through the method Base.nodes(). The result container is an instance of Collector.
nodes = fhn.nodes()
nodes # note: integrator is also a node
{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>,
'FHN2': <__main__.FHN at 0x7f4a7406b5e0>}
type(nodes)
brainpy.base.collector.Collector
Now, let's make a more complicated model by using the previously defined model FHN.
class FeedForwardCircuit(bp.Base):
  def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FeedForwardCircuit, self).__init__(name=name)
    self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
    self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)

    conn = bm.ones((num1, num2), dtype=bool)
    self.conn = bm.fill_diagonal(conn, False) * w

  def update(self, _t, _dt, x):
    self.pre.update(_t, _dt, x)
    x2 = self.pre.spike @ self.conn
    self.post.update(_t, _dt, x2)
This model FeedForwardCircuit defines two layers, each modeled as a FitzHugh-Nagumo model (FHN). The first layer is densely connected to the second layer. The input to the second layer is the first layer's spikes times a connection strength w.
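The input computation, spikes times a dense connection matrix with the diagonal zeroed, can be checked with a tiny pure-Python sketch. The function name feedforward_input is hypothetical; the real model uses bm arrays and the @ operator:

```python
def feedforward_input(spikes, num_post, w):
  # Dense all-to-all connection with weight w; entries on the
  # diagonal (i == j) are removed, mimicking bm.fill_diagonal(conn, False).
  num_pre = len(spikes)
  x2 = []
  for j in range(num_post):
    total = 0.0
    for i in range(num_pre):
      if i != j:  # no connection on the diagonal
        total += spikes[i] * w
    x2.append(total)
  return x2
```

For example, with presynaptic spikes [1, 0, 1] and w = 0.1, the second layer receives [0.1, 0.2, 0.1]: each postsynaptic unit sums the spikes of all presynaptic units except the one sharing its index.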
net = FeedForwardCircuit(8, 5)
We can retrieve all integrators in the network with .ints():
net.ints()
{'FHN3.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'FHN4.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}
Or, retrieve all variables with .vars():
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN4.V': Variable([0., 0., 0., 0., 0.]),
'FHN4.spike': Variable([False, False, False, False, False]),
'FHN4.w': Variable([0., 0., 0., 0., 0.])}
Or, retrieve all nodes (instances of the Base class) with .nodes():
net.nodes()
{'FHN3': <__main__.FHN at 0x7f4a74077670>,
'FHN4': <__main__.FHN at 0x7f4a740771c0>,
'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>,
'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x7f4a74077790>}
Absolute path
It is worth noting that there are two ways to access variables, integrators, and nodes: the "absolute" path and the "relative" path. The default is the absolute path.
"Absolute" path means that every key in the resulting Collector (e.g., from Base.nodes()) has the format key = node_name [+ field_name].
.nodes() example 1: In the above fhn instance, there are two nodes: fhn itself and its integrator fhn.integral.
fhn.integral.name, fhn.name
('RK44', 'FHN2')
Calling .nodes() returns the models' names mapped to the model instances.
fhn.nodes().keys()
dict_keys(['RK44', 'FHN2'])
.nodes() example 2: In the above net instance, there are five nodes:
net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')
Calling .nodes() likewise returns the names and instances of all models.
net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0'])
.vars() example 1: In the above fhn instance, there are three variables: "V", "w", and "spike". Calling .vars() returns a dict of <node_name + var_name, var_value>.
fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
False]),
'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
.vars() example 2: This also applies to the net instance:
net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'FHN4.V': Variable([0., 0., 0., 0., 0.]),
'FHN4.spike': Variable([False, False, False, False, False]),
'FHN4.w': Variable([0., 0., 0., 0., 0.])}
Relative path
Variables, integrators, and nodes can also be accessed by relative path. For example, the pre instance in net can be accessed by
net.pre
<__main__.FHN at 0x7f4a74077670>
The relative path preserves the dependency relationships. For example, all nodes retrieved from the perspective of net are:
net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x7f4a74077790>,
'pre': <__main__.FHN at 0x7f4a74077670>,
'post': <__main__.FHN at 0x7f4a740771c0>,
'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}
However, nodes retrieved starting from net.pre will be:
net.pre.nodes('relative')
{'': <__main__.FHN at 0x7f4a74077670>,
'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>}
Variables can also be accessed relatively from a model. For example, all variables relatively accessible from net are:
net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'pre.spike': Variable([False, False, False, False, False, False, False, False]),
'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
'post.V': Variable([0., 0., 0., 0., 0.]),
'post.spike': Variable([False, False, False, False, False]),
'post.w': Variable([0., 0., 0., 0., 0.])}
In contrast, variables relatively accessed from the view of net.post are:
net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.]),
'spike': Variable([False, False, False, False, False]),
'w': Variable([0., 0., 0., 0., 0.])}
Elements in containers
To avoid surprising unintended behaviors, collection functions do not look for elements inside a list, dict, or any other container structure.
class ATest(bp.Base):
  def __init__(self):
    super(ATest, self).__init__()
    self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6))]
    self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()
The above class defines a list of variables and a dict of children nodes. However, they cannot be retrieved by the collection functions vars() and nodes().
t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x7f4a7401b2b0>}
Fortunately, BrainPy provides implicit_vars and implicit_nodes (instances of dict) to hold variables and nodes that live in container structures. Any variable registered in implicit_vars, and any integrator or node registered in implicit_nodes, can be retrieved by the collection functions. Let's try it.
class AnotherTest(bp.Base):
  def __init__(self):
    super(AnotherTest, self).__init__()
    self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6))]
    self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}

    self.implicit_vars = {f'v{i}': v for i, v in enumerate(self.all_vars)}  # must be a dict
    self.implicit_nodes = {k: v for k, v in self.sub_nodes.items()}  # must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator.
# Therefore, there are five Base objects.
t2.nodes()
{'T1': <__main__.FHN at 0x7f4a740777c0>,
'T2': <__main__.FHN at 0x7f4a74077a90>,
'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74077340>,
'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74011040>,
'AnotherTest0': <__main__.AnotherTest at 0x7f4a740157f0>}
# This model has five Base objects (seen above),
# each FHN node has three variables,
# moreover, this model has two implicit variables.
t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'T1.spike': Variable([False, False, False, False, False, False, False, False, False,
False]),
'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
'T2.V': Variable([0., 0., 0., 0., 0.]),
'T2.spike': Variable([False, False, False, False, False]),
'T2.w': Variable([0., 0., 0., 0., 0.]),
'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.]),
'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.])}
Saving and loading
Because Base.vars() returns a Python dictionary (a Collector), variables can be easily saved, updated, altered, and restored, which adds a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details, see Saving and Loading). Specifically, they are implemented by Base.save_states() and Base.load_states().
Save
Base.save_states(PATH, [vars])
Model exporting in BrainPy supports several standard file formats, including:
- HDF5: .h5, .hdf5
- .npz (NumPy file format)
- .pkl (Python's pickle utility)
- .mat (Matlab file format)
net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
# An unknown file format will cause an error
try:
  net.save_states('./data/net.xxx')
except Exception as e:
  print(type(e).__name__, ":", e)
BrainPyError : Unknown file format: ./data/net.xxx. We only supports ['.h5', '.hdf5', '.npz', '.pkl', '.mat']
Load
Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
Collector
Collection functions return a brainpy.Collector. This class is a dictionary mapping names to elements, and it provides several useful methods.
subset()
Collector.subset(cls) returns the subset of elements whose type matches the given cls. For example, Base.nodes() returns all instances of the Base class. If you are only interested in one type, such as ODEIntegrator, you can use:
net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}
In fact, Collector.subset(cls) traverses all elements in the collection and keeps the elements whose type matches the given cls.
unique()
It is common in machine learning that weights are shared by several objects, or that the same weight can be reached through several dependency relationships. The collection functions of Base may therefore return a collection in which the same value appears under multiple keys. These duplicate elements are not automatically excluded. However, it is important not to apply an operation twice or more to the same element (e.g., applying gradients and updating weights).
Therefore, Collector provides the method Collector.unique() to handle this automatically. Collector.unique() returns a copy of the collection in which all elements are unique.
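The deduplication has to compare object identity rather than keys, since the same element appears under several different keys. A sketch of the idea (the function name unique here is an illustrative stand-in, not the actual implementation):

```python
def unique(collection):
  # Keep only the first key seen for each distinct underlying object.
  seen_ids = set()
  result = {}
  for key, value in collection.items():
    if id(value) not in seen_ids:  # identity, not equality
      seen_ids.add(id(value))
      result[key] = value
  return result
```

With three keys all pointing at the same shared object, only the first key survives, mirroring the behavior shown in the example below.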
class ModelA(bp.Base):
  def __init__(self):
    super(ModelA, self).__init__()
    self.a = bm.Variable(bm.zeros(5))

class SharedA(bp.Base):
  def __init__(self, source):
    super(SharedA, self).__init__()
    self.source = source
    self.a = source.a  # shared variable

class Group(bp.Base):
  def __init__(self):
    super(Group, self).__init__()
    self.A = ModelA()
    self.A_shared = SharedA(self.A)
g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.]),
'A_shared.a': Variable([0., 0., 0., 0., 0.]),
'A_shared.source.a': Variable([0., 0., 0., 0., 0.])}
g.vars('relative').unique()  # only one unique path is returned
{'A.a': Variable([0., 0., 0., 0., 0.])}
g.nodes('relative') # "ModelA" is accessed twice
{'': <__main__.Group at 0x7f4a74049190>,
'A': <__main__.ModelA at 0x7f4a740490a0>,
'A_shared': <__main__.SharedA at 0x7f4a74049550>,
'A_shared.source': <__main__.ModelA at 0x7f4a740490a0>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x7f4a74049190>,
'A': <__main__.ModelA at 0x7f4a740490a0>,
'A_shared': <__main__.SharedA at 0x7f4a74049550>}
update()
Collector is a dict, but it has safeguards to catch potential conflicts during assignment. Both the bracket assignment of a Collector ([key]) and Collector.update() check whether the same key is being mapped to a different value. If so, an error is raised.
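The conflict check can be mimicked by overriding __setitem__ on a dict subclass. The class name CheckedDict is hypothetical, and the real Collector's comparison may differ in detail; this sketch uses object identity:

```python
class CheckedDict(dict):
  # A dict that refuses to silently rebind an existing key to a different object.
  def __setitem__(self, key, value):
    if key in self and self[key] is not value:
      raise ValueError(f'Name "{key}" conflicts: same name for two different values.')
    super().__setitem__(key, value)

  def update(self, other):
    for key, value in other.items():
      self[key] = value  # route through the checked __setitem__
```

Re-assigning the same object under the same key is allowed, while binding the key to a new object raises the conflict error, just as in the examples below.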
tc = bp.Collector({'a': bm.zeros(10)})
tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
try:
  tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
  print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
  tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
  print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
replace()
If you want to replace the value stored under an existing key, use the Collector.replace(old_key, new_value) function.
tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
tc.replace('a', bm.ones(3))
tc
{'a': array([1., 1., 1.])}
TensorCollector
TensorCollector is a subclass of Collector, but it is designed specifically to collect tensor variables.