BrainPyObject and Collector

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about:

  • The BrainPyObject class for the BrainPy ecosystem

  • The Collector to facilitate variable collection and manipulation.

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('cpu')

brainpy.BrainPyObject

The foundation of BrainPy is brainpy.BrainPyObject. A BrainPyObject instance is an object that has variables and methods. All methods of a BrainPyObject object can be JIT compiled or automatically differentiated. In other words, any class whose objects will be JIT compiled or automatically differentiated must inherit from brainpy.BrainPyObject.

A BrainPyObject object can have many variables, child BrainPyObject objects, integrators, and methods. Below is the implementation of a FitzHugh-Nagumo neuron model as an example.

class FHN(bp.BrainPyObject):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integral
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    V, w = self.integral(self.V, self.w, _t, x)
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V

Note that this model has three variables, self.V, self.w, and self.spike, as well as an integrator, self.integral.
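The integrator self.integral computes one numerical step of the model. The sketch below shows what that step involves, using plain Python. It applies one classical fourth-order Runge-Kutta (RK4) step to the same right-hand side as FHN.derivative. It then applies the threshold-crossing spike rule from FHN.update.

This is not BrainPy's integrator. The state is scalar, and the helper names fhn_derivative, rk4_step and update are hypothetical. The step size dt = 0.1 and the input Iext = 1.0 are arbitrary choices.

```python
# Illustrative sketch, not BrainPy's integrator: one classical RK4 step for the
# FHN equations of FHN.derivative, with scalar state and a fixed step dt.


def fhn_derivative(V, w, t, Iext, a=0.7, b=0.8, tau=12.5):
    """Right-hand side of the FHN model; returns (dV/dt, dw/dt)."""
    dw = (V + a - b * w) / tau
    dV = V - V ** 3 / 3 - w + Iext
    return dV, dw


def rk4_step(V, w, t, dt, Iext):
    """Advance (V, w) from t to t + dt with the classical RK4 scheme."""
    k1V, k1w = fhn_derivative(V, w, t, Iext)
    k2V, k2w = fhn_derivative(V + 0.5 * dt * k1V, w + 0.5 * dt * k1w, t + 0.5 * dt, Iext)
    k3V, k3w = fhn_derivative(V + 0.5 * dt * k2V, w + 0.5 * dt * k2w, t + 0.5 * dt, Iext)
    k4V, k4w = fhn_derivative(V + dt * k3V, w + dt * k3w, t + dt, Iext)
    V_new = V + dt / 6 * (k1V + 2 * k2V + 2 * k3V + k4V)
    w_new = w + dt / 6 * (k1w + 2 * k2w + 2 * k3w + k4w)
    return V_new, w_new


def update(V, w, t, dt, Iext, Vth=1.9):
    """Integrate one step, then flag a spike on an upward crossing of Vth."""
    V_new, w_new = rk4_step(V, w, t, dt, Iext)
    spike = V_new > Vth and V <= Vth  # same rule as FHN.update
    return V_new, w_new, spike


# Drive one neuron with a constant current (dt = 0.1 is an arbitrary choice).
V, w, t, dt = 0.0, 0.0, 0.0, 0.1
n_spikes = 0
for _ in range(2000):
    V, w, spike = update(V, w, t, dt, Iext=1.0)
    n_spikes += spike
    t += dt
print(n_spikes)
```

The spike rule compares the new voltage with the previous one. As a result, a spike is flagged only on the step where V crosses Vth from below, not on every step during which V stays above Vth.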

The naming system

Every BrainPyObject object has a unique name. You can specify a name when instantiating a BrainPyObject class. Reusing a name that has already been taken raises an error.

FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)
UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x0000013BEA0DDF10> has a used name "Y". 
If you try to run multiple trials, you may need 

>>> brainpy.base.clear_name_cache() 

to clear all cached names. 

If no name is given to a BrainPyObject object, BrainPy assigns one automatically. The rule for generating object names is class_name + number_of_instances, for example FHN0, FHN1, etc.

FHN(10).name
'FHN0'
FHN(10).name
'FHN1'

Therefore, in BrainPy, you can access any object by its unique name, no matter how insignificant this object is.
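The naming rule can be sketched in plain Python. The registry below is hypothetical: NameRegistry, get_name and clear are invented for illustration and are not BrainPy APIs.

The sketch makes two assumptions:
- An explicit name only needs to be unused.
- An automatic name appends a per-class counter to the class name.

The second assumption is consistent with the FHN0/FHN1 and RK44 outputs in this tutorial.

```python
# Illustrative sketch of the naming rule, not BrainPy's implementation.
from collections import defaultdict


class UniqueNameError(Exception):
    """Raised when an explicitly requested name is already in use."""


class NameRegistry:
    def __init__(self):
        self.used = set()                 # every name handed out so far
        self.counters = defaultdict(int)  # per-class counter for automatic names

    def get_name(self, cls_name, name=None):
        if name is None:
            # Automatic name: class name followed by a per-class counter.
            name = f"{cls_name}{self.counters[cls_name]}"
            self.counters[cls_name] += 1
        if name in self.used:
            raise UniqueNameError(f'The name "{name}" has already been used.')
        self.used.add(name)
        return name

    def clear(self):
        """Forget all names, e.g. before re-running an experiment."""
        self.used.clear()
        self.counters.clear()


registry = NameRegistry()
print(registry.get_name("FHN", name="X"))  # X
print(registry.get_name("FHN"))            # FHN0
print(registry.get_name("FHN"))            # FHN1
print(registry.get_name("RK4"))            # RK40
```

Here clear() plays a role similar to the cache-clearing call suggested in the error message above.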

Collection functions

Two important collection functions are implemented for each BrainPyObject object. Specifically, they are:

  • nodes(): to collect all instances of BrainPyObject objects, including the child nodes of a node.

  • vars(): to collect all variables defined in the BrainPyObject node and in its child nodes.
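Here is a minimal sketch of how such recursive collection could work. It assumes that objects are found by scanning instance attributes. The classes Var, Node, Integrator, Neuron and Circuit are hypothetical stand-ins, not BrainPy classes.

Keys follow the absolute-path format used in the outputs below:
- nodes() uses each node's unique name.
- vars() uses node_name.attribute_name.

```python
# Illustrative sketch, not BrainPy's implementation: nodes() and vars()
# gathered by walking instance attributes, keyed by absolute names.


class Var:
    """Hypothetical stand-in for bm.Variable."""

    def __init__(self, value):
        self.value = value


class Node:
    """Hypothetical stand-in for BrainPyObject; each instance has a name."""

    def __init__(self, name):
        self.name = name

    def nodes(self, _found=None):
        """Map each reachable node's unique name to the node itself."""
        found = {} if _found is None else _found
        if self.name not in found:  # also guards against reference cycles
            found[self.name] = self
            for attr in self.__dict__.values():
                if isinstance(attr, Node):
                    attr.nodes(found)
        return found

    def vars(self):
        """Key every variable as node_name + '.' + attribute name."""
        found = {}
        for node in self.nodes().values():
            for attr_name, attr in node.__dict__.items():
                if isinstance(attr, Var):
                    found[f"{node.name}.{attr_name}"] = attr
        return found


class Integrator(Node):
    pass


class Neuron(Node):
    def __init__(self, name, n):
        super().__init__(name)
        self.V = Var([0.0] * n)
        self.w = Var([0.0] * n)
        self.integral = Integrator(f"{name}_int")


class Circuit(Node):
    def __init__(self, name):
        super().__init__(name)
        self.pre = Neuron("N0", 3)
        self.post = Neuron("N1", 2)


net = Circuit("C0")
print(list(net.nodes()))   # ['C0', 'N0', 'N0_int', 'N1', 'N1_int']
print(sorted(net.vars()))  # ['N0.V', 'N0.w', 'N1.V', 'N1.w']
```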

fhn = FHN(10)

All variables in a BrainPyObject object can be collected through BrainPyObject.vars(). The returned container is an ArrayCollector (a subclass of Collector).

vars = fhn.vars()

vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
           False], dtype=bool),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
type(vars)
brainpy.base.collector.ArrayCollector

All nodes in the model can also be collected with a single method, BrainPyObject.nodes(). The returned container is an instance of Collector.

nodes = fhn.nodes()

nodes  # note: integrator is also a node
{'FHN2': <__main__.FHN at 0x13bea0ddb20>,
 'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea1200a0>}
type(nodes)
brainpy.base.collector.Collector

All integrators can be collected by:

ints = fhn.nodes().subset(bp.integrators.Integrator)

ints
{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea1200a0>}
type(ints)
brainpy.base.collector.Collector

Now, let’s make a more complicated model by using the previously defined model FHN.

class FeedForwardCircuit(bp.BrainPyObject):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        
        self.conn = bm.ones((num1, num2), dtype=bool) * w
        bm.fill_diagonal(self.conn, 0.)

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)

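This model, FeedForwardCircuit, defines two layers, each modeled as a FitzHugh-Nagumo model (FHN). The first layer is connected all-to-all to the second layer, except for the diagonal entries of the connection matrix (pre-synaptic neuron i to post-synaptic neuron i), which bm.fill_diagonal sets to zero. The input to the second layer is the first layer's spike vector multiplied by the connection matrix self.conn. The remaining entries of that matrix equal the connection strength w.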
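The connectivity and input computation can be checked with plain NumPy, which is assumed to be installed. Here np.fill_diagonal stands in for bm.fill_diagonal, and the spike pattern is made up for illustration.

```python
import numpy as np

num1, num2, w = 8, 5, 0.1

# All-to-all weights of strength w, then zero the "diagonal" entries (i, i)
# for i < min(num1, num2), as bm.fill_diagonal does in FeedForwardCircuit.
conn = np.ones((num1, num2)) * w
np.fill_diagonal(conn, 0.0)

# Hypothetical pre-synaptic spikes at neurons 0, 2 and 7.
spike = np.zeros(num1, dtype=bool)
spike[[0, 2, 7]] = True

# Each post neuron j receives w * (number of spiking pre neurons i != j).
x2 = spike.astype(float) @ conn
print(x2)  # [0.2 0.3 0.2 0.3 0.3]
```

For this rectangular 8-by-5 matrix, only the entries (i, i) with i < 5 are zeroed. Post neurons 0 and 2 therefore receive one fewer spike's worth of input than the others.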

net = FeedForwardCircuit(8, 5)

We can retrieve all variables by .vars():

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
 'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

And retrieve all nodes (instances of the BrainPyObject class) by .nodes():

net.nodes()
{'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x13bea120dc0>,
 'FHN3': <__main__.FHN at 0x13bea139eb0>,
 'FHN4': <__main__.FHN at 0x13bea120430>,
 'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}

If we only care about a particular subtype, we can retrieve its instances through:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}

Absolute paths

It is worth noting that there are two ways to access variables, integrators, and nodes: by “absolute” paths and by “relative” paths. The default is the absolute path.

For absolute paths, every key in the resulting Collector (for example, the one returned by BrainPyObject.nodes()) has the format key = node_name [+ '.' + field_name].

.nodes() example 1: In the above fhn instance, there are two nodes: “fhn” and its integrator “fhn.integral”.

fhn.integral.name, fhn.name
('RK44', 'FHN2')

Calling .nodes() returns their names and models.

fhn.nodes().keys()
dict_keys(['FHN2', 'RK44'])

.nodes() example 2: In the above net instance, there are five nodes:

net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')

Calling .nodes() also returns the names and instances of all models.

net.nodes().keys()
dict_keys(['FeedForwardCircuit0', 'FHN3', 'FHN4', 'RK45', 'RK46'])

.vars() example 1: In the above fhn instance, there are three variables: “V”, “w”, and “spike”. Calling .vars() returns a dict that maps each key node_name.var_name to its variable.

fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
           False], dtype=bool),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}

.vars() example 2: This also applies in the net instance:

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
 'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

Relative paths

Variables, integrators, and nodes can also be accessed by relative paths. For example, the pre instance in the net can be accessed by

net.pre
<__main__.FHN at 0x13bea139eb0>

Relative paths preserve the dependence relationship. For example, all nodes retrieved from the perspective of net are:

net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x13bea120dc0>,
 'pre': <__main__.FHN at 0x13bea139eb0>,
 'post': <__main__.FHN at 0x13bea120430>,
 'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
 'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}

However, the nodes retrieved starting from net.pre are:

net.pre.nodes('relative')
{'': <__main__.FHN at 0x13bea139eb0>,
 'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>}

Variables can also be accessed by relative paths. For example, the variables that can be accessed relative to net include:

net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'pre.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'post.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'post.spike': Variable([False, False, False, False, False], dtype=bool),
 'post.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

While variables relatively accessed from net.post are:

net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'spike': Variable([False, False, False, False, False], dtype=bool),
 'w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
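Relative keys can be sketched as dotted attribute chains, collected by a depth-first walk from the chosen starting object. This is a hypothetical illustration: relative_nodes, relative_vars and the stand-in classes are not BrainPy APIs. Its key order may also differ from BrainPy's output above.

```python
# Illustrative sketch, not BrainPy's implementation: relative paths are the
# dotted attribute chains leading from a chosen starting object.


class Var:
    """Hypothetical stand-in for bm.Variable."""

    def __init__(self, value):
        self.value = value


class Node:
    """Hypothetical stand-in for BrainPyObject."""


def relative_nodes(root, prefix="", found=None, _on_path=()):
    """Map dotted paths (the root itself is '') to reachable nodes."""
    found = {} if found is None else found
    found[prefix] = root
    on_path = _on_path + (id(root),)  # ids on the current chain, to skip cycles
    for attr_name, attr in root.__dict__.items():
        if isinstance(attr, Node) and id(attr) not in on_path:
            path = f"{prefix}.{attr_name}" if prefix else attr_name
            relative_nodes(attr, path, found, on_path)
    return found


def relative_vars(root):
    """Map dotted paths such as 'pre.V' to the variables they reach."""
    found = {}
    for path, node in relative_nodes(root).items():
        for attr_name, attr in node.__dict__.items():
            if isinstance(attr, Var):
                found[f"{path}.{attr_name}" if path else attr_name] = attr
    return found


class Neuron(Node):
    def __init__(self, n):
        self.V = Var([0.0] * n)
        self.spike = Var([False] * n)
        self.integral = Node()  # stands in for the integrator


class Circuit(Node):
    def __init__(self):
        self.pre = Neuron(8)
        self.post = Neuron(5)


net = Circuit()
print(list(relative_nodes(net)))      # ['', 'pre', 'pre.integral', 'post', 'post.integral']
print(list(relative_vars(net.post)))  # ['V', 'spike']
```

Because the walk records paths rather than objects, a node reachable through two attributes shows up under both paths. Only cycles along a single chain are skipped.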

Elements in containers

One drawback of the collection functions is that they do not look for elements inside a list, dict, or any other container structure.

class ATest(bp.BrainPyObject):
    def __init__(self):
        super(ATest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()

The above class defines a list of variables, and a dict of children nodes, but the variables and children nodes cannot be retrieved from the collection functions vars() and nodes().

t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x13bea269d00>}

To solve this problem, BrainPy provides implicit_vars and implicit_nodes (instances of dict) to hold variables and nodes stored in container structures. They are filled through register_implicit_vars() and register_implicit_nodes(). Variables registered in implicit_vars, and integrators and nodes registered in implicit_nodes, can then be retrieved by the collection functions.

class AnotherTest(bp.BrainPyObject):
    def __init__(self):
        super(AnotherTest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        
        self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})  # the input must be a dict
        self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})  # the input must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, each "FHN" instance has one integrator. 
# Therefore, there are five BrainPyObject objects. 

t2.nodes()
{'AnotherTest0': <__main__.AnotherTest at 0x13bea271670>,
 'T1': <__main__.FHN at 0x13bea271b80>,
 'T2': <__main__.FHN at 0x13bea249b80>,
 'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea271be0>,
 'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea25d700>}
# This model has two FHN nodes, each of which has three variables.
# Moreover, this model has two implicit variables.

t2.vars()
{'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.], dtype=float32),
 'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'T1.spike': Variable([False, False, False, False, False, False, False, False, False,
           False], dtype=bool),
 'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'T2.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'T2.spike': Variable([False, False, False, False, False], dtype=bool),
 'T2.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}
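A minimal sketch shows why registration helps, assuming collection inspects only direct attributes. The stand-in Node class reuses the names register_implicit_vars and register_implicit_nodes for readability. It is still a hypothetical illustration, not BrainPy's implementation, and it assumes the objects form a tree with no sharing or cycles.

```python
# Illustrative sketch, not BrainPy's implementation: collection inspects only
# direct attributes, so objects inside lists or dicts must be registered.


class Var:
    """Hypothetical stand-in for bm.Variable."""

    def __init__(self, value):
        self.value = value


class Node:
    """Hypothetical stand-in for BrainPyObject (assumes a tree, no cycles)."""

    def __init__(self, name):
        self.name = name
        self.implicit_vars = {}   # registered variables, by key
        self.implicit_nodes = {}  # registered child nodes, by key

    def register_implicit_vars(self, variables):
        self.implicit_vars.update(variables)

    def register_implicit_nodes(self, nodes):
        self.implicit_nodes.update(nodes)

    def vars(self):
        found = {}
        # Variables stored directly as attributes, or registered explicitly.
        direct = {k: v for k, v in self.__dict__.items() if isinstance(v, Var)}
        for key, var in {**direct, **self.implicit_vars}.items():
            found[f"{self.name}.{key}"] = var
        # Child nodes stored directly as attributes, or registered explicitly.
        children = [v for v in self.__dict__.values() if isinstance(v, Node)]
        for child in children + list(self.implicit_nodes.values()):
            found.update(child.vars())
        return found


class Leaf(Node):
    def __init__(self, name):
        super().__init__(name)
        self.V = Var(0.0)


class Holder(Node):
    def __init__(self, name):
        super().__init__(name)
        self.all_vars = [Var(0.0), Var(1.0)]                 # hidden in a list
        self.sub_nodes = {"a": Leaf("T1"), "b": Leaf("T2")}  # hidden in a dict


h = Holder("H0")
print(h.vars())  # {}
h.register_implicit_vars({f"v{i}": v for i, v in enumerate(h.all_vars)})
h.register_implicit_nodes(h.sub_nodes)
print(sorted(h.vars()))  # ['H0.v0', 'H0.v1', 'T1.V', 'T2.V']
```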

Saving and loading

BrainPyObject.vars() returns a Collector, which is a Python dictionary. As a result, variables can be easily saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Accordingly, each BrainPyObject object has standard exporting and loading methods (for more details, please see Saving and Loading), implemented by BrainPyObject.save_states() and BrainPyObject.load_states().

Save

BrainPyObject.save_states(PATH, [vars])

Models can be exported from BrainPy to several common file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

import os
os.makedirs('./data', exist_ok=True)  # make sure the output directory exists
net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')

Load

BrainPyObject.load_states(PATH)

net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
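Since vars() is essentially a mapping from names to arrays, the underlying idea can be sketched with NumPy's .npz format. This is not how save_states() and load_states() are implemented. The keys are copied from the vars() output above, and NumPy is assumed to be installed.

```python
# Illustrative sketch, not BrainPy's save_states/load_states: a name -> array
# mapping round-trips through NumPy's .npz format.
import os
import tempfile

import numpy as np

state = {
    "FHN3.V": np.zeros(8, dtype=np.float32),
    "FHN3.spike": np.zeros(8, dtype=bool),
    "FHN4.V": np.ones(5, dtype=np.float32),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net.npz")
    np.savez(path, **state)        # save a snapshot, one array per key
    state["FHN4.V"][:] = -1.0      # ... a simulation changes the state ...
    with np.load(path) as data:    # restore the snapshot in place
        for key in data.files:
            state[key][...] = data[key]

print(state["FHN4.V"])  # [1. 1. 1. 1. 1.]
```

Writing loaded values back with state[key][...] = value updates the existing arrays in place instead of rebinding the dictionary entries. Any other object holding a reference to those arrays therefore sees the restored values.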

Collector

Collection functions return a brainpy.Collector, which is a dictionary mapping names to elements. It has some useful methods.

subset()

Collector.subset(cls) returns the elements that are instances of the given cls. For example, BrainPyObject.nodes() returns all instances of the BrainPyObject class. If you are only interested in one type, like ODEIntegrator, you can use:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea139a00>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x13bea13bcd0>}

In effect, Collector.subset(cls) iterates over all elements in the collection and keeps those whose type matches the given cls. Subclasses count as matches, which is why the RK4 integrators above are returned for ODEIntegrator.
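In plain Python, this filtering amounts to an isinstance test over the values. The class hierarchy below is a hypothetical stand-in for BrainPy's integrator classes, and subset is not the library's code.

```python
# Illustrative sketch, not BrainPy's implementation: keep the entries whose
# value is an instance of cls, so subclasses match as well.


def subset(collection, cls):
    return {key: value for key, value in collection.items() if isinstance(value, cls)}


# Hypothetical stand-ins for the integrator hierarchy and a model class.
class Integrator: ...
class ODEIntegrator(Integrator): ...
class RK4(ODEIntegrator): ...
class FHN: ...


nodes = {"FHN3": FHN(), "RK45": RK4(), "RK46": RK4()}
print(list(subset(nodes, ODEIntegrator)))  # ['RK45', 'RK46']
```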

unique()

It is common in machine learning for weights to be shared among several objects, or for the same weight to be reachable through several dependency paths. The collection functions of BrainPyObject may therefore return a collection in which the same value appears under multiple keys, and these duplicates are not excluded automatically. However, it is important not to apply an operation such as gradient descent twice or more to the same element.

Therefore, the Collector provides Collector.unique() to handle this problem automatically. Collector.unique() returns a copy of collection in which all elements are unique.

class ModelA(bp.BrainPyObject):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))

        
class SharedA(bp.BrainPyObject):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable
        
        
class Group(bp.BrainPyObject):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)

g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'A_shared.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'A_shared.source.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.vars('relative').unique()  # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.nodes('relative')  # "ModelA" is accessed twice
{'': <__main__.Group at 0x13bea5c6f40>,
 'A': <__main__.ModelA at 0x13bea5c82e0>,
 'A_shared': <__main__.SharedA at 0x13bea5c8550>,
 'A_shared.source': <__main__.ModelA at 0x13bea5c82e0>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x13bea5c6f40>,
 'A': <__main__.ModelA at 0x13bea5c82e0>,
 'A_shared': <__main__.SharedA at 0x13bea5c8550>}
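One way to implement such deduplication is to compare values by identity and keep the first key seen for each object. This reproduces the outputs above. The sketch is a hypothetical plain-Python illustration, not BrainPy's code.

```python
# Illustrative sketch, not BrainPy's implementation: keep the first key under
# which each distinct object (compared by identity) appears.


def unique(collection):
    seen = set()
    result = {}
    for key, value in collection.items():
        if id(value) not in seen:
            seen.add(id(value))
            result[key] = value
    return result


shared = [0.0] * 5  # one object reachable through three paths
paths = {"A.a": shared, "A_shared.a": shared, "A_shared.source.a": shared}
print(unique(paths))  # {'A.a': [0.0, 0.0, 0.0, 0.0, 0.0]}
```

Identity rather than equality is the appropriate test here. Two distinct variables that happen to hold equal values must both be kept, and arrays are typically not hashable anyway.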

update()

The Collector can also catch potential conflicts during assignment. Bracket assignment (Collector[key] = value) and Collector.update() both check whether an existing key is being mapped to a different value. If it is, a ValueError is raised.

tc = bp.Collector({'a': bm.zeros(10)})

tc
{'a': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
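Here is a minimal sketch of this conflict check, assuming values are compared by identity. The class CheckedDict is hypothetical, not BrainPy's Collector. Note that dict.update does not call __setitem__, so update() has to be overridden as well. The replace() method is included as an explicit way to overwrite a key.

```python
# Illustrative sketch, not BrainPy's implementation: a dict that refuses to
# rebind an existing key to a different object.


class CheckedDict(dict):
    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
        super().__setitem__(key, value)

    def update(self, other=(), **kwargs):
        # dict.update bypasses __setitem__, so route every pair through it.
        for key, value in dict(other, **kwargs).items():
            self[key] = value

    def replace(self, key, new_value):
        """Deliberately overwrite a key, skipping the conflict check."""
        super().__setitem__(key, new_value)


zeros = [0.0] * 3
tc = CheckedDict({"a": zeros})
tc["a"] = zeros  # same object under the same key: allowed
try:
    tc.update({"a": [1.0]})  # same key, different object
except ValueError as e:
    print(type(e).__name__, ":", e)
tc.replace("a", [1.0, 1.0])
print(tc)  # {'a': [1.0, 1.0]}
```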

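replace()

Collector.replace(old_key, new_value) updates the value of an existing key. Unlike bracket assignment and update(), it does not raise an error when the new value differs from the old one.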

tc
{'a': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
tc.replace('a', bm.ones(3))

tc
{'a': Array([1., 1., 1.], dtype=float32)}

__add__()

Two Collectors can be merged with the + operator.

a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})

a + b
{'a': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}
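The merge can be sketched as copying the left operand and then adding the right operand's entries. Whether BrainPy's + also rejects conflicting keys is an assumption here, modeled on the update() behavior described earlier. MergeableDict is a hypothetical class.

```python
# Illustrative sketch, not BrainPy's implementation: "+" returns a new mapping
# holding the entries of both operands; conflicting keys raise an error.


class MergeableDict(dict):
    def __add__(self, other):
        merged = type(self)(self)  # copy the left operand
        for key, value in other.items():
            if key in merged and merged[key] is not value:
                raise ValueError(f'Name "{key}" conflicts.')
            merged[key] = value
        return merged


a = MergeableDict({"a": [0.0] * 3})
b = MergeableDict({"b": [1.0] * 3})
c = a + b
print(sorted(c))  # ['a', 'b']
```

The operands are left unchanged, and the merged mapping shares the same value objects with them rather than copying the arrays.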

ArrayCollector

ArrayCollector is a subclass of Collector that is specialized for collecting arrays.