Base Class

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about:

  • The Base class for the BrainPy ecosystem

  • The Collector to facilitate variable collection and manipulation.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

brainpy.Base

The foundation of BrainPy is brainpy.Base. A Base instance is an object that holds variables and methods. All methods of a Base object can be JIT compiled or automatically differentiated. In other words, any class whose objects will be JIT compiled or automatically differentiated must inherit from brainpy.Base.

A Base object can have many variables, child Base objects, integrators, and methods. Below is the implementation of a FitzHugh-Nagumo neuron model as an example.

class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables: membrane potential, recovery variable, and spike flags
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integrator: 4th-order Runge-Kutta applied to self.derivative
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    # integrate one time step with the external input x
    V, w = self.integral(self.V, self.w, _t, x)
    # a spike is an upward crossing of Vth during this step
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V

Note that this model has three variables, self.V, self.w, and self.spike, and one integrator, self.integral.
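To make the numerics of derivative() and update() concrete, here is a dependency-free sketch of one FHN update step for a single neuron using plain floats. It assumes that method='rk4' selects the classical fourth-order Runge-Kutta scheme; the helper names fhn_derivative, rk4_step and update, and the step size dt=0.1, are hypothetical choices for illustration, not part of BrainPy.

```python
# Hypothetical, dependency-free sketch of one FHN.update() step for a single neuron.
# Not BrainPy code: it only mirrors the equations in FHN.derivative above.

def fhn_derivative(V, w, Iext, a=0.7, b=0.8, tau=12.5):
    """Right-hand side of the FitzHugh-Nagumo equations, as in FHN.derivative."""
    dV = V - V ** 3 / 3 - w + Iext
    dw = (V + a - b * w) / tau
    return dV, dw


def rk4_step(V, w, Iext, dt):
    """Advance (V, w) by one classical 4th-order Runge-Kutta step of size dt."""
    k1V, k1w = fhn_derivative(V, w, Iext)
    k2V, k2w = fhn_derivative(V + 0.5 * dt * k1V, w + 0.5 * dt * k1w, Iext)
    k3V, k3w = fhn_derivative(V + 0.5 * dt * k2V, w + 0.5 * dt * k2w, Iext)
    k4V, k4w = fhn_derivative(V + dt * k3V, w + dt * k3w, Iext)
    V_new = V + dt / 6 * (k1V + 2 * k2V + 2 * k3V + k4V)
    w_new = w + dt / 6 * (k1w + 2 * k2w + 2 * k3w + k4w)
    return V_new, w_new


def update(V, w, Iext, dt=0.1, Vth=1.9):
    """Mirror FHN.update: integrate, then flag an upward crossing of Vth as a spike."""
    V_new, w_new = rk4_step(V, w, Iext, dt)
    spike = V_new > Vth and V <= Vth   # new value above threshold, old value not
    return V_new, w_new, spike


# Usage: drive one neuron with a constant current and count threshold crossings.
V, w, n_spikes = 0.0, 0.0, 0
for _ in range(2000):
    V, w, spike = update(V, w, Iext=1.0)
    n_spikes += spike
print(n_spikes)
```

The spike test compares the new potential with the old one, so a neuron that stays above Vth for several steps is counted only once per crossing.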

The naming system

Every Base object has a unique name. You can specify a name when instantiating a Base class; reusing a name that is already taken raises an error.

FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)
UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x00000224FA317BB0> has a used name "Y".

If no name is given to a Base object, BrainPy assigns one automatically. The generated name is the class name followed by a per-class counter, for example FHN0, FHN1, etc.

FHN(10).name
'FHN0'
FHN(10).name
'FHN1'

Therefore, in BrainPy, every object, however minor, can be identified by its unique name.
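The naming behaviour above can be mimicked with a small registry. This is a hypothetical sketch, not BrainPy's implementation: the Named base class, its class-level name set and per-class counter are assumptions chosen to reproduce the outputs shown (a duplicate name raises, automatic names are class name plus counter).

```python
# Hypothetical sketch of a BrainPy-style naming scheme (not BrainPy's actual code).
from collections import defaultdict


class UniqueNameError(ValueError):
    """Raised when a name that is already in use is requested again."""


class Named:
    _used_names = set()            # every name handed out so far, across all classes
    _counters = defaultdict(int)   # per-class counter used for automatic names

    def __init__(self, name=None):
        if name is None:
            cls_name = type(self).__name__
            name = f'{cls_name}{Named._counters[cls_name]}'
            Named._counters[cls_name] += 1
        if name in Named._used_names:
            raise UniqueNameError(f'{self} has a used name "{name}".')
        Named._used_names.add(name)
        self.name = name


class FHN(Named):
    pass


print(FHN(name='Y').name)   # Y
print(FHN().name)           # FHN0
print(FHN().name)           # FHN1
try:
    FHN(name='Y')
except UniqueNameError as e:
    print(type(e).__name__)  # UniqueNameError
```

Keeping the used names in one shared set is what makes names unique across all classes, while the counter is kept per class so that automatic names stay short.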

Collection functions

Every Base object implements three important collection functions:

  • nodes(): to collect all Base instances in a node, including the node itself and its child nodes.

  • ints(): to collect all integrators defined in the Base node and in its children nodes.

  • vars(): to collect all variables defined in the Base node and in its children nodes.
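The recursive idea behind these functions can be sketched in plain Python. This is not BrainPy's implementation: Node, Variable and the counter-based naming are hypothetical stand-ins that only mimic the absolute-key format used in this section (the node name for nodes, node_name.attr_name for variables).

```python
# Hypothetical sketch of how nodes() and vars() could walk an object tree.
class Variable:
    def __init__(self, value):
        self.value = value


class Node:
    _count = {}  # per-class counter for automatic names

    def __init__(self):
        cls = type(self).__name__
        n = Node._count.get(cls, 0)
        Node._count[cls] = n + 1
        self.name = f'{cls}{n}'

    def nodes(self):
        """Collect self and every Node reachable through attributes, keyed by node name."""
        found = {}
        self._collect_nodes(found)
        return found

    def _collect_nodes(self, found):
        found[self.name] = self
        for value in self.__dict__.values():
            if isinstance(value, Node) and value.name not in found:
                value._collect_nodes(found)   # skip nodes already visited

    def vars(self):
        """Collect every Variable under the absolute key 'node_name.attr_name'."""
        return {f'{node.name}.{attr}': value
                for node in self.nodes().values()
                for attr, value in node.__dict__.items()
                if isinstance(value, Variable)}


class Neuron(Node):
    def __init__(self):
        super().__init__()
        self.V = Variable(0.0)
        self.w = Variable(0.0)


class Circuit(Node):
    def __init__(self):
        super().__init__()
        self.pre = Neuron()
        self.post = Neuron()


net = Circuit()
print(sorted(net.nodes()))  # ['Circuit0', 'Neuron0', 'Neuron1']
print(sorted(net.vars()))   # ['Neuron0.V', 'Neuron0.w', 'Neuron1.V', 'Neuron1.w']
```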

All integrators can be collected through one method Base.ints(). The result container is a Collector.

fhn = FHN(10)
ints = fhn.ints()

ints
{'FHN2.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa309370>}
type(ints)
brainpy.base.collector.Collector

Similarly, all variables in a Base object can be collected through Base.vars(). The returned container is a TensorCollector (a subclass of Collector).

variables = fhn.vars()  # avoid shadowing Python's built-in vars()

variables
{'FHN2.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN2.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False,
                       False, False], dtype=bool)),
 'FHN2.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}
type(variables)
brainpy.base.collector.TensorCollector

All nodes in the model can also be collected through one method Base.nodes(). The result container is an instance of Collector.

nodes = fhn.nodes()

nodes  # note: integrator is also a node
{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa309370>,
 'FHN2': <__main__.FHN at 0x224fa317a60>}
type(nodes)
brainpy.base.collector.Collector

Now, let’s make a more complicated model by using the previously defined model FHN.

class FeedForwardCircuit(bp.Base):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        
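        # all-to-all mask with the diagonal (i == j) entries removed, scaled by weight w
        conn = bm.ones((num1, num2), dtype=bool)
        self.conn = bm.fill_diagonal(conn, False) * w

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        # input to the post layer: pre-layer spikes weighted by the connection matrix
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)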

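The model FeedForwardCircuit defines two layers, each modeled as a population of FitzHugh-Nagumo neurons (FHN). Every neuron in the first layer is connected to every neuron in the second layer except the one with the same index, because the diagonal of the connection matrix is set to zero. The input to the second layer is the first layer's spike vector multiplied by this connection matrix, whose nonzero entries all equal the connection strength w.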
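The connectivity and the spike propagation step can be checked by hand with a small pure-Python sketch. The helpers build_conn and propagate are hypothetical; they reproduce what bm.ones, bm.fill_diagonal and the matrix product do in FeedForwardCircuit, assuming that fill_diagonal on a non-square matrix only clears the entries with i == j.

```python
# Dependency-free sketch of FeedForwardCircuit's connectivity (hypothetical helpers).
def build_conn(num1, num2, w=0.1):
    """num1 x num2 weight matrix: w everywhere except the (i, i) entries, which are 0."""
    return [[0.0 if i == j else w for j in range(num2)] for i in range(num1)]


def propagate(spike, conn):
    """Compute spike @ conn: each post neuron receives the sum of w over spiking pre neurons."""
    num2 = len(conn[0])
    return [sum(conn[i][j] for i, s in enumerate(spike) if s) for j in range(num2)]


conn = build_conn(8, 5)
spike = [True, False, True] + [False] * 5   # pre neurons 0 and 2 fire
x2 = propagate(spike, conn)
print(x2)   # [0.1, 0.2, 0.1, 0.2, 0.2]
```

Post neurons 0 and 2 receive only one contribution each, because the connection from the pre neuron with the same index was removed.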

net = FeedForwardCircuit(8, 5)

We can retrieve all integrators in the network with .ints():

net.ints()
{'FHN3.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'FHN4.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

Retrieve all variables by .vars():

net.vars()
{'FHN3.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN3.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False], dtype=bool)),
 'FHN3.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'FHN4.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

And retrieve all nodes (instances of the Base class) by .nodes():

net.nodes()
{'FHN3': <__main__.FHN at 0x224fb8780a0>,
 'FHN4': <__main__.FHN at 0x224fa3173a0>,
 'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>,
 'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x224fb878070>}

If we only care about nodes of a particular type, we can filter them with subset():

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

Absolute paths

It is worth noting that there are two ways to access variables, integrators, and nodes: “absolute” paths and “relative” paths. The default is the absolute path.

For absolute paths, every key in the resulting Collector is the node name for Base.nodes(), and node_name.field_name for Base.vars() and Base.ints().

.nodes() example 1: The fhn instance above has two nodes: “fhn” itself and its integrator “fhn.integral”.

fhn.integral.name, fhn.name
('RK44', 'FHN2')

Calling .nodes() returns their names and instances.

fhn.nodes().keys()
dict_keys(['RK44', 'FHN2'])

.nodes() example 2: In the above net instance, there are five nodes:

net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')

Calling .nodes() also returns the names and instances of all models.

net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0'])

.vars() example 1: The fhn instance above has three variables: “V”, “w” and “spike”. Calling .vars() returns a dict mapping node_name.var_name to the corresponding Variable.

fhn.vars()
{'FHN2.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN2.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False,
                       False, False], dtype=bool)),
 'FHN2.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}

.vars() example 2: This also applies in the net instance:

net.vars()
{'FHN3.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN3.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False], dtype=bool)),
 'FHN3.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'FHN4.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

Relative paths

Variables, integrators, and nodes can also be accessed by relative paths. For example, the pre instance in net can be accessed by:

net.pre
<__main__.FHN at 0x224fb8780a0>

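Relative paths preserve the dependence relationship: each key is the chain of attribute names leading from the starting object to the element, and the empty key '' denotes the starting object itself. For example, all nodes retrieved from the perspective of net are: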

net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x224fb878070>,
 'pre': <__main__.FHN at 0x224fb8780a0>,
 'post': <__main__.FHN at 0x224fa3173a0>,
 'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

In contrast, nodes retrieved starting from net.pre are:

net.pre.nodes('relative')
{'': <__main__.FHN at 0x224fb8780a0>,
 'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>}

Variables can also be retrieved by relative paths. For example, the variables that can be accessed relative to net include:

net.vars('relative')
{'pre.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'pre.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False], dtype=bool)),
 'pre.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'post.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'post.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'post.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

while the variables accessed relative to net.post are:

net.post.vars('relative')
{'V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}
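To illustrate how relative keys are formed, here is a minimal, dependency-free sketch. Node, Variable, nodes_relative and vars_relative are hypothetical stand-ins rather than BrainPy's API, and the sketch assumes the object graph contains no reference cycles.

```python
# Hypothetical sketch of relative-path collection (not BrainPy's implementation).
class Variable:
    def __init__(self, value):
        self.value = value


class Node:
    def nodes_relative(self, prefix=''):
        """Map attribute paths (relative to self) to nodes; '' denotes self.

        Assumes there are no reference cycles between nodes.
        """
        found = {prefix: self}
        for attr, value in self.__dict__.items():
            if isinstance(value, Node):
                path = f'{prefix}.{attr}' if prefix else attr
                found.update(value.nodes_relative(path))
        return found

    def vars_relative(self):
        """Map 'path.attr' (relative to self) to every reachable Variable."""
        result = {}
        for path, node in self.nodes_relative().items():
            for attr, value in node.__dict__.items():
                if isinstance(value, Variable):
                    result[f'{path}.{attr}' if path else attr] = value
        return result


class Neuron(Node):
    def __init__(self):
        super().__init__()
        self.V = Variable(0.0)
        self.w = Variable(0.0)


class Circuit(Node):
    def __init__(self):
        super().__init__()
        self.pre = Neuron()
        self.post = Neuron()


net = Circuit()
print(list(net.nodes_relative()))     # ['', 'pre', 'post']
print(list(net.vars_relative()))      # ['pre.V', 'pre.w', 'post.V', 'post.w']
print(list(net.pre.vars_relative()))  # ['V', 'w']
```

The same Variable object gets a different key depending on where the traversal starts, which is exactly the difference between net.vars('relative') and net.post.vars('relative').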

Elements in containers

One drawback of the collection functions is that they do not look for elements inside lists, dicts, or any other container structures.

class ATest(bp.Base):
    def __init__(self):
        super(ATest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()

The class above defines a list of variables and a dict of child nodes, but neither can be retrieved with the collection functions vars() and nodes().

t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x224fa309430>}

To solve this problem, BrainPy provides implicit_vars and implicit_nodes (instances of dict) to hold variables and nodes stored in container structures. Variables registered in implicit_vars (via register_implicit_vars()) and integrators and nodes registered in implicit_nodes (via register_implicit_nodes()) can be retrieved by the collection functions.

class AnotherTest(bp.Base):
    def __init__(self):
        super(AnotherTest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        
        # the input must be a dict
        self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})
        # the input must be a dict
        self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})
t2 = AnotherTest()
# This model has two "FHN" instances, and each "FHN" instance has one integrator.
# Together with the AnotherTest object itself, there are five Base objects.

t2.nodes()
{'T1': <__main__.FHN at 0x224fb8a51c0>,
 'T2': <__main__.FHN at 0x224fb8a5c70>,
 'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb9c7190>,
 'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb9c7400>,
 'AnotherTest0': <__main__.AnotherTest at 0x224fb8a5250>}
# This model has two FHN nodes, each of which has three variables.
# Moreover, this model has two implicit variables.

t2.vars()
{'T1.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'T1.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False,
                       False, False], dtype=bool)),
 'T1.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'T2.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'T2.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'T2.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'AnotherTest0.v0': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'AnotherTest0.v1': Variable(DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32))}
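Why containers are skipped, and how registration fixes it, can be seen in a reduced sketch. The Base, Variable and ContainerModel classes here are hypothetical stand-ins that only borrow the names implicit_vars and register_implicit_vars from BrainPy; the assumption is that collection inspects direct attributes with isinstance checks and then adds whatever was registered.

```python
# Hypothetical sketch: containers are not searched, but registered entries are collected.
class Variable:
    def __init__(self, value):
        self.value = value


class Base:
    def __init__(self):
        self.implicit_vars = {}   # name -> Variable, filled by register_implicit_vars

    def register_implicit_vars(self, variables):
        if not isinstance(variables, dict):
            raise TypeError('the input must be a dict')
        self.implicit_vars.update(variables)

    def vars(self):
        """Direct Variable attributes plus registered implicit ones; containers are not searched."""
        found = {k: v for k, v in self.__dict__.items() if isinstance(v, Variable)}
        found.update(self.implicit_vars)
        return found


class ContainerModel(Base):
    def __init__(self, register):
        super().__init__()
        self.all_vars = [Variable(0.0), Variable(1.0)]  # hidden inside a list
        if register:
            self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})


print(ContainerModel(register=False).vars())         # {}
print(sorted(ContainerModel(register=True).vars()))  # ['v0', 'v1']
```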

Saving and loading

Because Base.vars() returns a Collector, which is a Python dictionary, the variables can easily be saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details, please see Saving and Loading): Base.save_states() and Base.load_states().

Save

Base.save_states(PATH, [vars])

Models exported from BrainPy support several common file formats, including:

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

import os
os.makedirs('./data', exist_ok=True)  # make sure the output directory exists

net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')

Load


Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
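Conceptually, saving and loading is a round trip of a name-to-value mapping. The sketch below uses only the standard library's pickle module; the functions save_states and load_states here are hypothetical stand-ins, not BrainPy's implementation, and plain lists replace Variables. Loading writes into the existing containers in place, in the same spirit as assigning to Variable[:].

```python
# Hypothetical sketch of a state-dict save/load round trip using only the standard library.
import os
import pickle
import tempfile


def save_states(path, states):
    """Write a {name: list_of_values} mapping to a pickle file."""
    with open(path, 'wb') as f:
        pickle.dump(states, f)


def load_states(path, states):
    """Read the file back and copy each stored value into the matching entry of `states`."""
    with open(path, 'rb') as f:
        saved = pickle.load(f)
    missing = set(states) - set(saved)
    if missing:
        raise KeyError(f'missing states: {sorted(missing)}')
    for key in states:
        states[key][:] = saved[key]   # in-place update keeps the same container objects


states = {'FHN3.V': [0.5, -0.2], 'FHN3.w': [0.1, 0.0]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'net.pkl')
    save_states(path, states)
    states['FHN3.V'][:] = [9.9, 9.9]   # perturb the live state
    load_states(path, states)
print(states['FHN3.V'])   # [0.5, -0.2]
```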

Collector

Collection functions return a brainpy.Collector, which is a dictionary mapping names to elements. It provides several useful methods.

subset()

Collector.subset(cls) returns the elements that are instances of the given cls. For example, Base.nodes() returns all instances of the Base class; if you are only interested in one type, such as ODEIntegrator, you can use:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

Internally, Collector.subset(cls) iterates over all elements in the collection and keeps those that are instances of cls. Subclasses count as matches: the RK4 integrators above are selected by ODEIntegrator.
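A minimal sketch of these semantics, assuming subset() is an isinstance filter that returns a collection of the same type. The Collector, ODEIntegrator, RK4 and FHN classes below are hypothetical stand-ins, not BrainPy's classes.

```python
# Hypothetical sketch of Collector.subset() semantics using isinstance filtering.
class Collector(dict):
    def subset(self, cls):
        """Return a new Collector with only the values that are instances of cls (or a subclass)."""
        return type(self)((k, v) for k, v in self.items() if isinstance(v, cls))


class ODEIntegrator:
    pass


class RK4(ODEIntegrator):
    pass


class FHN:
    pass


nodes = Collector(FHN3=FHN(), RK45=RK4(), RK46=RK4())
print(sorted(nodes.subset(ODEIntegrator)))   # ['RK45', 'RK46']
```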

unique()

In machine learning it is common for weights to be shared among several objects, or for the same weight to be reachable through several dependence relationships. Collection functions of Base may therefore return a collection in which the same value appears under multiple keys, and such duplicates are not removed automatically. However, it is important not to apply an operation such as gradient descent more than once to the same element.

Therefore, the Collector provides Collector.unique() to handle this problem automatically. Collector.unique() returns a copy of the collection in which all elements are unique.

class ModelA(bp.Base):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))

        
class SharedA(bp.Base):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable
        
        
class Group(bp.Base):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)

g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'A_shared.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'A_shared.source.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}
g.vars('relative').unique()  # only return a unique path
{'A.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}
g.nodes('relative')  # "ModelA" is accessed twice
{'': <__main__.Group at 0x224fb9e8130>,
 'A': <__main__.ModelA at 0x224fb9e8040>,
 'A_shared': <__main__.SharedA at 0x224fb9e8280>,
 'A_shared.source': <__main__.ModelA at 0x224fb9e8040>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x224fb9e8130>,
 'A': <__main__.ModelA at 0x224fb9e8040>,
 'A_shared': <__main__.SharedA at 0x224fb9e8280>}
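The deduplication can be sketched in a few lines. The assumption here is that uniqueness is judged by object identity and the first key encountered is kept, which matches the outputs above; the Collector class is a hypothetical stand-in.

```python
# Hypothetical sketch of Collector.unique(): deduplicate by object identity, keep the first key.
class Collector(dict):
    def unique(self):
        """Keep only the first key seen for each distinct object (compared by identity)."""
        seen = set()
        result = type(self)()
        for key, value in self.items():
            if id(value) not in seen:
                seen.add(id(value))
                result[key] = value
        return result


shared = [0.0] * 5
c = Collector({'A.a': shared, 'A_shared.a': shared, 'A_shared.source.a': shared})
print(c.unique())   # {'A.a': [0.0, 0.0, 0.0, 0.0, 0.0]}
```

Comparing identity rather than equality matters: two distinct weight arrays that happen to hold equal values are still separate parameters and must both be kept.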

update()

The Collector can also catch potential conflicts during assignment. Both bracket assignment (collector[key] = value) and Collector.update() check whether an existing key is being mapped to a different value. If it is, a ValueError is raised.

tc = bp.Collector({'a': bm.zeros(10)})

tc
{'a': JaxArray(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
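The conflict check can be sketched by overriding __setitem__ and routing update() through it. This is a hypothetical stand-in for brainpy.Collector; the assumption that a "different value" means a different object (identity, not equality) is ours.

```python
# Hypothetical sketch of a conflict-checking Collector (not BrainPy's implementation).
class Collector(dict):
    def __setitem__(self, key, value):
        # Assumption: a conflict means the key already maps to a *different object*.
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
        super().__setitem__(key, value)

    def update(self, other=(), **kwargs):
        # Route every pair through __setitem__ so update() gets the same check.
        for key, value in dict(other, **kwargs).items():
            self[key] = value


zeros = [0.0] * 3
tc = Collector({'a': zeros})
tc['a'] = zeros                 # same object: allowed
try:
    tc.update({'a': [1.0]})     # same key, different object
except ValueError as e:
    print(e)   # Name "a" conflicts: same name for [1.0] and [0.0, 0.0, 0.0].
```

dict.update() and dict.__init__ do not call __setitem__, which is why the sketch overrides update() explicitly.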

replace()

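Collector.replace(old_key, new_value) overwrites the value stored under an existing key. Unlike bracket assignment, it does not raise an error when the new value differs from the old one.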

tc
{'a': JaxArray(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}
tc.replace('a', bm.ones(3))

tc
{'a': JaxArray(DeviceArray([1., 1., 1.], dtype=float32))}
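A sketch of how replace() can coexist with the conflict check: it writes through dict.__setitem__ directly and so bypasses the overridden method. The Collector class is a hypothetical stand-in, not BrainPy's code.

```python
# Hypothetical sketch: replace() deliberately bypasses the conflict check of __setitem__.
class Collector(dict):
    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts.')
        super().__setitem__(key, value)

    def replace(self, key, new_value):
        """Overwrite the value of `key` without the conflict check."""
        dict.__setitem__(self, key, new_value)


tc = Collector({'a': [0.0] * 10})
tc.replace('a', [1.0] * 3)
print(tc)   # {'a': [1.0, 1.0, 1.0]}
```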

__add__()

Two Collectors can be merged with the + operator.

a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})

a + b
{'a': JaxArray(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'b': JaxArray(DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32))}
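Merging can be sketched as copying the left operand and inserting the right operand's items one by one. This is a hypothetical stand-in; the assumption that a merge reuses the same conflict check as bracket assignment is ours, not confirmed by the text above.

```python
# Hypothetical sketch of Collector.__add__: merge into a new Collector with conflict checks.
class Collector(dict):
    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts.')
        super().__setitem__(key, value)

    def __add__(self, other):
        """Return a new Collector; a key bound to two different objects raises ValueError."""
        merged = type(self)(self)       # copy; the left operand is left unchanged
        for key, value in other.items():
            merged[key] = value         # goes through the conflict check
        return merged


a = Collector({'a': [0.0] * 2})
b = Collector({'b': [1.0] * 2})
print(a + b)   # {'a': [0.0, 0.0], 'b': [1.0, 1.0]}
```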

TensorCollector

TensorCollector is a subclass of Collector designed specifically for collecting tensors, such as the Variables returned by Base.vars().