Base Class

@Chaoming Wang @Xiaoyu Chen

In this section, we are going to talk about:

  • The Base class for the BrainPy ecosystem

  • The Collector to facilitate variable collection and manipulation.

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

brainpy.Base

The foundation of BrainPy is brainpy.Base. A Base instance is an object that has variables and methods. All methods of a Base object can be JIT compiled or automatically differentiated. In other words, any class whose objects will be JIT compiled or automatically differentiated must inherit from brainpy.Base.

A Base object can have many variables, child Base objects, integrators, and methods. Below is the implementation of a FitzHugh-Nagumo neuron model as an example.

class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integrator (fourth-order Runge-Kutta)
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    V, w = self.integral(self.V, self.w, _t, x)
    # a spike is registered when V crosses Vth from below during this step
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V

Note that this model has three variables (self.V, self.w, and self.spike) and an integrator, self.integral.
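For reference, derivative() and update() above implement the following equations, with Iext written as $I_{\text{ext}}$. Only the recovery variable $w$ is scaled by the time constant $\tau$, and a spike is registered when $V$ crosses $V_{th}$ from below within one integration step.

```latex
\frac{dV}{dt} = V - \frac{V^{3}}{3} - w + I_{\text{ext}}, \qquad
\tau \frac{dw}{dt} = V + a - b\,w,
\qquad
\text{spike}(t+\Delta t) = \big[V(t+\Delta t) > V_{th}\big] \wedge \big[V(t) \le V_{th}\big]
```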

The naming system

Every Base object has a unique name. You can specify a name when instantiating a Base class; reusing a name that is already taken causes an error.

FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)

If no name is specified for a Base object, BrainPy assigns one automatically. The generated name follows the rule class_name + number_of_instances, for example, FHN0, FHN1, etc.

FHN(10).name
'FHN0'
FHN(10).name
'FHN1'

Therefore, in BrainPy, you can access any object by its unique name, no matter how insignificant this object is.

Collection functions

Two important collection functions are implemented for each Base object. Specifically, they are:

  • nodes(): to collect all instances of Base objects, including the node itself and its children nodes.

  • vars(): to collect all variables defined in the Base node and in its children nodes.
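Conceptually, both functions walk the attributes of an object. The toy sketch below (hypothetical classes ToyVariable, ToyNode, Neuron, and Circuit, not BrainPy code) assumes collection simply scans each node's __dict__ recursively. It builds absolute keys of the form node_name.attribute_name and gives auto-named nodes a per-class counter; the real implementation handles more cases, such as integrators and shared objects.

```python
# Toy sketch of nodes()/vars() with absolute keys (not BrainPy's implementation).
class ToyVariable:
    def __init__(self, value):
        self.value = value


class ToyNode:
    _counts = {}  # per-class counter for automatic names

    def __init__(self, name=None):
        if name is None:
            cls_name = type(self).__name__
            n = ToyNode._counts.get(cls_name, 0)
            ToyNode._counts[cls_name] = n + 1
            name = f'{cls_name}{n}'
        self.name = name

    def nodes(self):
        # this node plus every ToyNode reachable through its attributes
        found = {self.name: self}
        for value in self.__dict__.values():
            if isinstance(value, ToyNode):
                found.update(value.nodes())
        return found

    def vars(self):
        # absolute key = owner node's name + '.' + attribute name
        found = {}
        for node in self.nodes().values():
            for attr, value in node.__dict__.items():
                if isinstance(value, ToyVariable):
                    found[f'{node.name}.{attr}'] = value
        return found


class Neuron(ToyNode):
    def __init__(self, name=None):
        super().__init__(name)
        self.V = ToyVariable([0.0, 0.0])
        self.tau = 12.5  # plain attribute: not collected


class Circuit(ToyNode):
    def __init__(self, name=None):
        super().__init__(name)
        self.pre = Neuron()
        self.post = Neuron()


net = Circuit()
print(sorted(net.nodes()))  # ['Circuit0', 'Neuron0', 'Neuron1']
print(sorted(net.vars()))   # ['Neuron0.V', 'Neuron1.V']
```

Only attributes that are themselves variables or nodes are picked up; plain parameters such as tau are ignored.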

fhn = FHN(10)

All variables in a Base object can be collected through Base.vars(). The returned container is a TensorCollector (a subclass of Collector).

fhn_vars = fhn.vars()  # avoid shadowing Python's built-in vars()

fhn_vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
           False, False], dtype=bool),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
type(fhn_vars)
brainpy.base.collector.TensorCollector

All nodes in the model can also be collected through Base.nodes(). The returned container is an instance of Collector.

nodes = fhn.nodes()

nodes  # note: integrator is also a node
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>,
 'FHN2': <__main__.FHN at 0x2155a7a65e0>}
type(nodes)
brainpy.base.collector.Collector

All integrators can be collected by:

ints = fhn.nodes().subset(bp.integrators.Integrator)

ints
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155a7a6a90>}
type(ints)
brainpy.base.collector.Collector

Now, let’s make a more complicated model by using the previously defined model FHN.

class FeedForwardCircuit(bp.Base):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        
        self.conn = bm.ones((num1, num2), dtype=bool) * w
        bm.fill_diagonal(self.conn, 0.)

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)

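The FeedForwardCircuit model defines two layers, each of which is a group of FitzHugh-Nagumo neurons (FHN). The first layer is densely connected to the second layer, except that the diagonal entries of the connection matrix are set to zero by bm.fill_diagonal. The input to the second layer is the product of the first layer’s spikes and the connection matrix, whose nonzero entries equal the connection strength w.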
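As a worked example of the input computation in update(), the snippet below repeats the same arithmetic in plain Python for a tiny, made-up case (3 pre-synaptic neurons, 2 post-synaptic neurons, w = 0.1). It assumes bm.fill_diagonal follows NumPy's semantics, which for a non-square matrix zero the entries (i, i) with i < min(num1, num2).

```python
# Plain-Python version of:
#   conn = ones((num1, num2)) * w; fill_diagonal(conn, 0.); x2 = spike @ conn
num1, num2, w = 3, 2, 0.1
conn = [[0.0 if i == j else w for j in range(num2)] for i in range(num1)]
# conn == [[0.0, 0.1], [0.1, 0.0], [0.1, 0.1]]

spike = [True, False, True]  # pre-synaptic spikes at one time step

# x2[j] = sum_i spike[i] * conn[i][j]: summed weights from the pre-neurons that spiked
x2 = [sum(conn[i][j] for i in range(num1) if spike[i]) for j in range(num2)]
print(x2)  # [0.1, 0.2]
```

Each post-synaptic neuron therefore receives w times the number of connected pre-synaptic neurons that spiked in that step.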

net = FeedForwardCircuit(8, 5)

We can retrieve all variables by .vars():

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
 'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

And retrieve all nodes (instances of the Base class) by .nodes():

net.nodes()
{'FHN3': <__main__.FHN at 0x2155ace3130>,
 'FHN4': <__main__.FHN at 0x2155a798d30>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>,
 'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x2155a743c40>}

If we only care about a particular subtype, we can retrieve only its instances through:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}

Absolute paths

It is worth noting that there are two ways to access variables, integrators, and nodes: through “absolute” paths and through “relative” paths. The default is the absolute path.

For absolute paths, all keys in the resulting Collector (e.g., from Base.nodes()) have the format key = node_name [+ field_name].

.nodes() example 1: In the above fhn instance, there are two nodes: “fhn” and its integrator “fhn.integral”.

fhn.integral.name, fhn.name
('RK45', 'FHN2')

Calling .nodes() returns their names and models.

fhn.nodes().keys()
dict_keys(['RK45', 'FHN2'])

.nodes() example 2: In the above net instance, there are five nodes:

net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0')

Calling .nodes() also returns the names and instances of all models.

net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK46', 'RK47', 'FeedForwardCircuit0'])

.vars() example 1: In the above fhn instance, there are three variables: “V”, “w”, and “spike”. Calling .vars() returns a dict of <node_name + var_name, var_value>.

fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False,
           False, False], dtype=bool),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}

.vars() example 2: This also applies in the net instance:

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'FHN4.spike': Variable([False, False, False, False, False], dtype=bool),
 'FHN4.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

Relative paths

Variables, integrators, and nodes can also be accessed by relative paths. For example, the pre instance in net can be accessed by:

net.pre
<__main__.FHN at 0x2155ace3130>

Relative paths preserve the dependence relationship. For example, all nodes retrieved from the perspective of net are:

net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x2155a743c40>,
 'pre': <__main__.FHN at 0x2155ace3130>,
 'post': <__main__.FHN at 0x2155a798d30>,
 'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}

However, the nodes retrieved starting from net.pre are:

net.pre.nodes('relative')
{'': <__main__.FHN at 0x2155ace3130>,
 'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>}
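The keys above follow a simple rule: the starting node gets the empty key '', and every other node is keyed by the chain of attribute names leading to it. A toy sketch of that rule is below, using a hypothetical Node class and relative_nodes() helper rather than BrainPy's code. It visits nodes depth-first, so its ordering may differ from BrainPy's output.

```python
# Toy sketch of relative-path collection (hypothetical helper, not BrainPy's code).
class Node:
    """A bare object whose attributes may hold child nodes."""


def relative_nodes(root, prefix=''):
    # the starting node gets the key ''; other keys are dotted attribute paths
    found = {prefix: root}
    for attr, value in root.__dict__.items():
        if isinstance(value, Node):
            path = f'{prefix}.{attr}' if prefix else attr
            found.update(relative_nodes(value, path))
    return found


net = Node()
net.pre, net.post = Node(), Node()
net.pre.integral, net.post.integral = Node(), Node()

print(list(relative_nodes(net)))
# ['', 'pre', 'pre.integral', 'post', 'post.integral']
print(list(relative_nodes(net.pre)))
# ['', 'integral']
```

Because keys are paths rather than names, a child reachable through two attributes appears under two keys.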

Variables can also be accessed through relative paths. For example, the variables that can be relatively accessed from net include:

net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'pre.spike': Variable([False, False, False, False, False, False, False, False], dtype=bool),
 'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'post.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'post.spike': Variable([False, False, False, False, False], dtype=bool),
 'post.w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

In contrast, the variables relatively accessed from net.post are:

net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'spike': Variable([False, False, False, False, False], dtype=bool),
 'w': Variable([0., 0., 0., 0., 0.], dtype=float32)}

Elements in containers

One drawback of the collection functions is that they do not look for elements inside a list, dict, or any other container structure.

class ATest(bp.Base):
    def __init__(self):
        super(ATest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()

The above class defines a list of variables, and a dict of children nodes, but the variables and children nodes cannot be retrieved from the collection functions vars() and nodes().

t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x2155ae60a00>}

To solve this problem, BrainPy provides register_implicit_vars() and register_implicit_nodes(), which store variables and nodes held in container structures in the dictionaries implicit_vars and implicit_nodes. Variables registered in implicit_vars, and integrators and nodes registered in implicit_nodes, can then be retrieved by the collection functions.

class AnotherTest(bp.Base):
    def __init__(self):
        super(AnotherTest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        
        self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})  # the input must be a dict
        self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})  # the input must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, and each "FHN" instance has one integrator.
# Therefore, together with the model itself, there are five Base objects.

t2.nodes()
{'T1': <__main__.FHN at 0x2155ae6bca0>,
 'T2': <__main__.FHN at 0x2155ae6b3a0>,
 'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae6f8e0>,
 'RK411': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ae66610>,
 'AnotherTest0': <__main__.AnotherTest at 0x2155ae6f0d0>}
# This model has two FHN nodes, each of which has three variables.
# Moreover, this model has two implicit variables.

t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'T1.spike': Variable([False, False, False, False, False, False, False, False,
           False, False], dtype=bool),
 'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'T2.V': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'T2.spike': Variable([False, False, False, False, False], dtype=bool),
 'T2.w': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.], dtype=float32)}
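Why the list in ATest is invisible, and why registration fixes it, can be seen in a toy sketch. Var, Model, and the dict-only register_implicit_vars() below are hypothetical stand-ins; they assume collection inspects only attributes that are themselves variables, plus an explicit registry.

```python
# Toy sketch: attribute scanning skips containers unless their items are registered.
class Var:
    def __init__(self, value):
        self.value = value


class Model:
    def __init__(self, name):
        self.name = name
        self.implicit_vars = {}  # hypothetical stand-in for BrainPy's registry

    def register_implicit_vars(self, variables):
        if not isinstance(variables, dict):
            raise TypeError('expected a dict mapping names to variables')
        self.implicit_vars.update(variables)

    def vars(self):
        found = {}
        for attr, value in self.__dict__.items():
            if isinstance(value, Var):  # a list or dict of Vars is not itself a Var
                found[f'{self.name}.{attr}'] = value
        for key, value in self.implicit_vars.items():
            found[f'{self.name}.{key}'] = value
        return found


m = Model('M0')
m.all_vars = [Var(0.0), Var(1.0)]
print(m.vars())  # {}
m.register_implicit_vars({f'v{i}': v for i, v in enumerate(m.all_vars)})
print(sorted(m.vars()))  # ['M0.v0', 'M0.v1']
```

Registration stores the same objects under explicit names, so the collected entries still refer to the variables inside the list.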

Saving and loading

Because Base.vars() returns a Collector, which is a Python dictionary, the variables can be easily saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details, please see Saving and Loading), namely Base.save_states() and Base.load_states().

Save

Base.save_states(PATH, [vars])

Model states exported from BrainPy can be stored in various standard file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')

Load


Base.load_states(PATH)
net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
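The on-disk layout BrainPy uses is not described here. As a hedged illustration of the idea only, the standard-library sketch below saves a dict of names to plain values as a .pkl file and writes the loaded values back into the existing containers. Var, save_states_sketch(), and load_states_sketch() are hypothetical; a temporary directory is used because, in this sketch, the target directory must already exist.

```python
import os
import pickle
import tempfile


class Var:
    """Stand-in for a mutable state container such as bm.Variable (hypothetical)."""
    def __init__(self, value):
        self.value = value


def save_states_sketch(path, variables):
    # store plain values keyed by name; the containers themselves are not pickled
    with open(path, 'wb') as f:
        pickle.dump({k: v.value for k, v in variables.items()}, f)


def load_states_sketch(path, variables):
    # write loaded values back *into* the existing containers, in place
    with open(path, 'rb') as f:
        data = pickle.load(f)
    missing = set(variables) - set(data)
    if missing:
        raise KeyError(f'missing states: {sorted(missing)}')
    for k, v in variables.items():
        v.value = data[k]


states = {'pre.V': Var([0.5, 1.0]), 'pre.w': Var([0.0, 0.0])}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'net.pkl')
    save_states_sketch(path, states)
    states['pre.V'].value = [9.9, 9.9]  # mutate after saving
    load_states_sketch(path, states)
print(states['pre.V'].value)  # [0.5, 1.0]
```

Writing values back in place, rather than replacing the containers, keeps every other reference to the same container consistent.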

Collector

Collection functions return a brainpy.Collector, which is a dictionary mapping names to elements. It has some useful methods.

subset()

Collector.subset(cls) returns the elements that are instances of the given cls. For example, Base.nodes() returns all instances of the Base class. If you are only interested in one type, like ODEIntegrator, you can use:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace35e0>,
 'RK47': <brainpy.integrators.ode.explicit_rk.RK4 at 0x2155ace38b0>}

Under the hood, Collector.subset(cls) iterates over all elements in the collection and keeps those that are instances of the given cls (including its subclasses, as the RK4 integrators above show).
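A toy equivalent of that filtering, written as a plain function over a dict, is shown below. The class hierarchy (RK4 deriving from ODEIntegrator) is a hypothetical stand-in chosen to mirror the output above; because the check uses isinstance, instances of subclasses are kept as well.

```python
# Toy subset(): keep entries whose value is an instance of cls (subclasses included).
def subset(collection, cls):
    return {key: value for key, value in collection.items() if isinstance(value, cls)}


# hypothetical class hierarchy mirroring the output above
class ODEIntegrator:
    pass


class RK4(ODEIntegrator):
    pass


class FHN:
    pass


nodes = {'FHN3': FHN(), 'RK46': RK4(), 'RK47': RK4()}
print(list(subset(nodes, ODEIntegrator)))  # ['RK46', 'RK47']
```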

unique()

It is common in machine learning for weights to be shared by several objects, or for the same weight to be reachable through different dependence paths. As a result, collection functions of Base can return a collection in which the same value appears under multiple keys. Such duplicates are not excluded automatically, yet it is important not to apply operations such as gradient descent to the same element twice or more.

Therefore, the Collector provides Collector.unique() to handle this problem automatically. Collector.unique() returns a copy of collection in which all elements are unique.

class ModelA(bp.Base):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))

        
class SharedA(bp.Base):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable
        
        
class Group(bp.Base):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)

g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'A_shared.a': Variable([0., 0., 0., 0., 0.], dtype=float32),
 'A_shared.source.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.vars('relative').unique()  # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.], dtype=float32)}
g.nodes('relative')  # "ModelA" is accessed twice
{'': <__main__.Group at 0x2155b13b550>,
 'A': <__main__.ModelA at 0x2155b13b460>,
 'A_shared': <__main__.SharedA at 0x2155a7a6580>,
 'A_shared.source': <__main__.ModelA at 0x2155b13b460>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x2155b13b550>,
 'A': <__main__.ModelA at 0x2155b13b460>,
 'A_shared': <__main__.SharedA at 0x2155a7a6580>}
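A toy version of unique() is sketched below. It assumes duplicates are detected by object identity (id), not by value equality: two distinct variables that happen to hold equal numbers must both survive, otherwise one of them would be dropped from, for example, a list of trainable variables.

```python
# Toy unique(): keep the first key for each distinct object (identity, not equality).
def unique(collection):
    seen = set()
    result = {}
    for key, value in collection.items():
        if id(value) not in seen:
            seen.add(id(value))
            result[key] = value
    return result


a = [0.0] * 5  # one shared "variable"
b = [0.0] * 5  # equal contents, but a different object
vs = {'A.a': a, 'A_shared.a': a, 'A_shared.source.a': a, 'B.b': b}
print(list(unique(vs)))  # ['A.a', 'B.b']
```

Using id() also works for unhashable values such as lists, which could not be put in a set directly.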

update()

The Collector can also catch potential conflicts during assignment. Bracket assignment (collector[key] = value) and Collector.update() check whether an existing key would be mapped to a different value. If so, a ValueError is raised.

tc = bp.Collector({'a': bm.zeros(10)})

tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
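The conflict check can be sketched as a dict subclass. CheckedDict below is a hypothetical toy that treats re-assigning the very same object as harmless and any different object as a conflict; the exact comparison BrainPy performs is an assumption here. Note that dict.update() does not call __setitem__, so update() has to be overridden too.

```python
# Toy conflict-checking collector (hypothetical; not BrainPy's Collector).
class CheckedDict(dict):
    def __setitem__(self, key, value):
        # re-assigning the same object is fine; a different object is a conflict
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts.')
        super().__setitem__(key, value)

    def update(self, other=(), **kwargs):
        # dict.update() bypasses __setitem__, so route every item through it
        for key, value in dict(other, **kwargs).items():
            self[key] = value


x = [0.0] * 10
tc = CheckedDict()
tc['a'] = x
tc['a'] = x  # same object: accepted
try:
    tc.update({'a': [1.0]})  # same key, different object
except ValueError as e:
    print(type(e).__name__, ':', e)  # ValueError : Name "a" conflicts.
```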

replace()

Collector.replace(old_key, new_value) is used to update the value of an existing key; unlike bracket assignment, it does not raise a conflict error.

tc
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
tc.replace('a', bm.ones(3))

tc
{'a': JaxArray([1., 1., 1.], dtype=float32)}

__add__()

Two Collectors can be merged with the + operator.

a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})

a + b
{'a': JaxArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'b': JaxArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}
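Merging could, in principle, silently overwrite a duplicated key. The toy merge() below assumes merging should be conflict-checked in the same way as update(); it is a hypothetical helper, not BrainPy's __add__.

```python
# Toy merge of two name -> value dicts (hypothetical helper, not BrainPy's __add__).
def merge(left, right):
    result = dict(left)  # copy, so neither input is modified
    for key, value in right.items():
        # a shared key must refer to the very same object, otherwise it is a conflict
        if key in result and result[key] is not value:
            raise ValueError(f'Name "{key}" conflicts.')
        result[key] = value
    return result


zeros, ones = [0.0] * 3, [1.0] * 3
print(merge({'a': zeros}, {'b': ones}))  # {'a': [0.0, 0.0, 0.0], 'b': [1.0, 1.0, 1.0]}
try:
    merge({'a': zeros}, {'a': ones})
except ValueError as e:
    print(type(e).__name__, ':', e)  # ValueError : Name "a" conflicts.
```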

TensorCollector

TensorCollector is a subclass of Collector that is specialized for collecting tensors.