Base Class

In this section, we are going to talk about:

  • the Base class of the BrainPy ecosystem;

  • the Collector, which facilitates variable collection and manipulation.

import brainpy as bp
import brainpy.math as bm

Base

The foundation of BrainPy is brainpy.Base. A Base instance is an object that has variables and methods. All methods in a Base object can be JIT compiled or automatically differentiated. Conversely, any class whose objects are to be JIT compiled or automatically differentiated must inherit from brainpy.Base.

A Base object can have many variables, child Base objects, integrators, and methods. For example, let's implement a FitzHugh-Nagumo neuron model.

class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

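    # integrator: solves self.derivative with the "rk4" method
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    # right-hand side of the FitzHugh-Nagumo equations
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    # advance (V, w) by one step with external input x
    V, w = self.integral(self.V, self.w, _t, x)
    # a spike is emitted when V crosses Vth from below (self.V is still the old value)
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V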

Note that this model has three variables (self.V, self.w, and self.spike) and one integrator (self.integral).
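To make the model concrete, here is a plain-Python sketch of what one update computes. It is not BrainPy code: `fhn_derivative`, `rk4_step`, and `run` are hypothetical helpers, the state is a single scalar neuron instead of a vector, and we assume `bp.odeint(method='rk4')` applies the classical fourth-order Runge-Kutta scheme to `derivative`.

```python
# Hypothetical helpers (not BrainPy API): a scalar version of FHN.derivative
# and one classical RK4 step, roughly what the "rk4" integrator performs.


def fhn_derivative(V, w, Iext, a=0.7, b=0.8, tau=12.5):
    """Return (dV/dt, dw/dt) of the FitzHugh-Nagumo model."""
    dV = V - V ** 3 / 3 - w + Iext
    dw = (V + a - b * w) / tau
    return dV, dw


def rk4_step(V, w, Iext, dt):
    """Advance (V, w) by one step of the classical fourth-order Runge-Kutta method."""
    k1V, k1w = fhn_derivative(V, w, Iext)
    k2V, k2w = fhn_derivative(V + 0.5 * dt * k1V, w + 0.5 * dt * k1w, Iext)
    k3V, k3w = fhn_derivative(V + 0.5 * dt * k2V, w + 0.5 * dt * k2w, Iext)
    k4V, k4w = fhn_derivative(V + dt * k3V, w + dt * k3w, Iext)
    V_new = V + dt / 6 * (k1V + 2 * k2V + 2 * k3V + k4V)
    w_new = w + dt / 6 * (k1w + 2 * k2w + 2 * k3w + k4w)
    return V_new, w_new


def run(Iext, dt=0.1, steps=1000, Vth=1.9):
    """Integrate one FHN unit; return the step indices at which V crosses Vth upward."""
    V, w = 0.0, 0.0
    spike_steps = []
    for i in range(steps):
        V_new, w_new = rk4_step(V, w, Iext, dt)
        # Same rule as FHN.update: a spike is V going from <= Vth to > Vth.
        if V_new > Vth and V <= Vth:
            spike_steps.append(i)
        V, w = V_new, w_new
    return spike_steps
```

The spike test compares the new V with the previous one, which is why FHN.update computes self.spike before overwriting self.V.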

Naming system

Every Base object has a unique name. You can specify the name when you instantiate a Base class; reusing a name that is already taken raises an error.

FHN(10, name='X').name
'X'
FHN(10, name='Y').name
'Y'
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)
UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x7f4a7406bd60> has a used name "Y".

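When you instantiate a Base class without specifying a name, BrainPy assigns one automatically. The generated name is the class name followed by a counter (class_name + number_of_instances), for example FHN0, FHN1, etc. As the output below shows, the counter starts at 0 even though FHN objects named "X" and "Y" were created earlier.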
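The naming rules can be sketched as a small registry. This is a hypothetical illustration, not BrainPy's code: `NameRegistry` and this `UniqueNameError` class are our own, and we assume the automatic counter advances only when a name is generated, which is consistent with the first unnamed FHN being FHN0.

```python
class UniqueNameError(Exception):
    """Raised when a requested name is already owned by another object."""


class NameRegistry:
    """Hypothetical registry mimicking the naming rules described above."""

    def __init__(self):
        self._owners = {}    # name -> object that owns it
        self._counters = {}  # class name -> index for the next automatic name

    def register(self, obj, name=None):
        if name is None:
            # Automatic name: class name + per-class counter (FHN0, FHN1, ...).
            cls = type(obj).__name__
            idx = self._counters.get(cls, 0)
            self._counters[cls] = idx + 1
            name = f'{cls}{idx}'
        owner = self._owners.get(name)
        if owner is not None and owner is not obj:
            raise UniqueNameError(f'{obj!r} has a used name "{name}".')
        self._owners[name] = obj
        return name


# Usage with a dummy class standing in for FHN.
class FHN:
    pass


registry = NameRegistry()
explicit = [registry.register(FHN(), 'X'), registry.register(FHN(), 'Y')]
automatic = [registry.register(FHN()), registry.register(FHN())]
```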

FHN(10).name
'FHN0'
FHN(10).name
'FHN1'

Therefore, in BrainPy every object, however minor, can be accessed by its unique name.

Collection functions

Each Base object implements three important collection functions:

  • nodes(): collects all Base instances in the node, including the node itself and its child nodes.

  • ints(): collects all integrators defined in the Base node and in its child nodes.

  • vars(): collects all variables defined in the Base node and in its child nodes.

All integrators can be collected with a single method, Base.ints(). The returned container is a Collector.

fhn = FHN(10)
ints = fhn.ints()

ints
{'FHN2.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>}
type(ints)
brainpy.base.collector.Collector

Similarly, all variables in a Base object can be collected through Base.vars(). The returned container is a TensorCollector (a subclass of Collector).

vars = fhn.vars()

vars
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
           False]),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
type(vars)
brainpy.base.collector.TensorCollector

All nodes in the model can also be collected with a single method, Base.nodes(). The returned container is an instance of Collector.

nodes = fhn.nodes()

nodes  # note: integrator is also a node
{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7406b430>,
 'FHN2': <__main__.FHN at 0x7f4a7406b5e0>}
type(nodes)
brainpy.base.collector.Collector

Now, let's build a more complicated model from the previously defined FHN model.

class FeedForwardCircuit(bp.Base):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        
        conn = bm.ones((num1, num2), dtype=bool)
        bm.fill_diagonal(conn, False)  # in place: remove the (i, i) connections
        self.conn = conn * w           # weight w on every remaining connection

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)

The FeedForwardCircuit model defines two layers, each modeled as a FitzHugh-Nagumo population (FHN). The first layer is connected all-to-all to the second layer, except that fill_diagonal removes the (i, i) pairs. The input to the second layer is the first layer's spike vector multiplied by this connection matrix, whose nonzero entries equal the connection strength w.
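The input computation can be checked with plain NumPy. In this sketch, `feedforward_input` is a hypothetical helper (not part of BrainPy), `np.fill_diagonal` (which modifies its argument in place) removes the (i, i) pairs, and spikes are cast to floats before the matrix product.

```python
import numpy as np


def feedforward_input(spike, num2, w=0.1):
    """Input to layer 2: layer-1 spikes times the weighted connection matrix."""
    num1 = spike.shape[0]
    conn = np.ones((num1, num2), dtype=bool)
    np.fill_diagonal(conn, False)       # in place: drop the (i, i) pairs
    weights = conn * w                  # bool * float -> 0.0 or w
    return spike.astype(float) @ weights


x2_all = feedforward_input(np.ones(8, dtype=bool), 5)
```

With every presynaptic neuron spiking, each of the 5 postsynaptic neurons receives 7 of the 8 possible inputs, i.e. 7 * 0.1 = 0.7.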

net = FeedForwardCircuit(8, 5)

We can retrieve all integrators in the network with .ints():

net.ints()
{'FHN3.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'FHN4.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}

Or, retrieve all variables by .vars():

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN4.V': Variable([0., 0., 0., 0., 0.]),
 'FHN4.spike': Variable([False, False, False, False, False]),
 'FHN4.w': Variable([0., 0., 0., 0., 0.])}

Or, retrieve all nodes (instances of Base class) with .nodes():

net.nodes()
{'FHN3': <__main__.FHN at 0x7f4a74077670>,
 'FHN4': <__main__.FHN at 0x7f4a740771c0>,
 'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>,
 'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x7f4a74077790>}

Absolute path

Note that there are two ways to access variables, integrators, and nodes: by "absolute" path and by "relative" path. The absolute path is the default.

"Absolute" path means that every key in the returned Collector (e.g., of Base.nodes()) has the format key = node_name [+ "." + field_name].

.nodes() example 1: In the above fhn instance, there are two nodes: "fhn" and its integrator "fhn.integral".

fhn.integral.name, fhn.name
('RK44', 'FHN2')

Calling .nodes() returns a mapping from each node's name to the node itself.

fhn.nodes().keys()
dict_keys(['RK44', 'FHN2'])

.nodes() example 2: In the above net instance, there are five nodes:

net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name
('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')

Calling .nodes() also returns the names and instances of all five nodes.

net.nodes().keys()
dict_keys(['FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0'])

.vars() example 1: In the above fhn instance, there are three variables: "V", "w", and "spike". Calling .vars() returns a dict of <node_name.var_name, variable>.

fhn.vars()
{'FHN2.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN2.spike': Variable([False, False, False, False, False, False, False, False, False,
           False]),
 'FHN2.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}

.vars() example 2: This also applies in the net instance:

net.vars()
{'FHN3.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN3.spike': Variable([False, False, False, False, False, False, False, False]),
 'FHN3.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'FHN4.V': Variable([0., 0., 0., 0., 0.]),
 'FHN4.spike': Variable([False, False, False, False, False]),
 'FHN4.w': Variable([0., 0., 0., 0., 0.])}
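The absolute-path rules above can be mimicked in a few lines of plain Python. This is a hypothetical sketch, not BrainPy's implementation: `Var`, `Node`, `Integrator`, `Layer`, and `Circuit` are stand-ins, and the traversal (a depth-first walk over each object's direct attributes, keyed by unique name) is our assumption about how such collection can work.

```python
class Var:
    """Hypothetical stand-in for bm.Variable: a slot holding a value."""

    def __init__(self, value):
        self.value = value


class Node:
    """Hypothetical stand-in for bp.Base: every node carries a unique name."""

    def __init__(self, name):
        self.name = name

    def nodes(self):
        """Absolute paths: map the unique name of every reachable node to the node."""
        found = {}

        def visit(node):
            if node.name in found:  # a shared child is collected only once
                return
            found[node.name] = node
            for attr in node.__dict__.values():  # direct attributes only
                if isinstance(attr, Node):
                    visit(attr)

        visit(self)
        return found

    def vars(self):
        """Absolute paths: keys are '<node name>.<attribute name>'."""
        return {f'{name}.{field}': attr
                for name, node in self.nodes().items()
                for field, attr in node.__dict__.items()
                if isinstance(attr, Var)}

    def ints(self):
        """Integrators are nodes too, so ints() is nodes() filtered by type."""
        return {name: node for name, node in self.nodes().items()
                if isinstance(node, Integrator)}


class Integrator(Node):
    """Hypothetical stand-in for an ODE integrator such as RK4."""


# Usage: a two-layer circuit in the spirit of FeedForwardCircuit.
class Layer(Node):
    def __init__(self, name):
        super().__init__(name)
        self.V = Var(0.0)
        self.w = Var(0.0)
        self.integral = Integrator('RK4_' + name)


class Circuit(Node):
    def __init__(self, name):
        super().__init__(name)
        self.pre = Layer('FHN3')
        self.post = Layer('FHN4')


net = Circuit('FeedForwardCircuit0')
```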

Relative path

Variables, integrators, and nodes can also be accessed by relative path. For example, the pre instance in net can be accessed as an attribute:

net.pre
<__main__.FHN at 0x7f4a74077670>

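Relative paths preserve the dependence relationship: each key is the chain of attribute names leading from the starting object to the element, and the starting object itself has the empty key ''. For example, all nodes retrieved from the perspective of net are: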

net.nodes(method='relative')
{'': <__main__.FeedForwardCircuit at 0x7f4a74077790>,
 'pre': <__main__.FHN at 0x7f4a74077670>,
 'post': <__main__.FHN at 0x7f4a740771c0>,
 'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}

In contrast, the nodes retrieved starting from net.pre are:

net.pre.nodes('relative')
{'': <__main__.FHN at 0x7f4a74077670>,
 'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>}

Variables can also be accessed relative to a model. For example, all variables reachable from net by relative path are:

net.vars('relative')
{'pre.V': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'pre.spike': Variable([False, False, False, False, False, False, False, False]),
 'pre.w': Variable([0., 0., 0., 0., 0., 0., 0., 0.]),
 'post.V': Variable([0., 0., 0., 0., 0.]),
 'post.spike': Variable([False, False, False, False, False]),
 'post.w': Variable([0., 0., 0., 0., 0.])}

Likewise, the variables accessed relative to net.post are:

net.post.vars('relative')
{'V': Variable([0., 0., 0., 0., 0.]),
 'spike': Variable([False, False, False, False, False]),
 'w': Variable([0., 0., 0., 0., 0.])}
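Relative paths can be sketched the same way: instead of unique names, keys are attribute names joined with dots, and the starting object gets the empty key ''. In this hypothetical sketch (`Var`, `Node`, `Layer`, `Circuit`, `nodes_relative`, and `vars_relative` are illustrative names, not BrainPy API), each parent-to-child edge is walked at most once so that reference cycles terminate, while an object reachable through two different attributes still appears under both paths.

```python
class Var:
    """Hypothetical stand-in for bm.Variable."""

    def __init__(self, value):
        self.value = value


class Node:
    """Hypothetical stand-in for bp.Base, showing relative-path collection only."""

    def nodes_relative(self, _edges=None):
        """Map attribute paths (relative to self) to nodes; self gets the key ''."""
        edges = set() if _edges is None else _edges
        found = {'': self}
        for key, child in self.__dict__.items():
            edge = (id(self), id(child))
            if isinstance(child, Node) and edge not in edges:
                edges.add(edge)  # walk each parent->child edge once, so cycles terminate
                for sub_key, sub_node in child.nodes_relative(edges).items():
                    found[f'{key}.{sub_key}' if sub_key else key] = sub_node
        return found

    def vars_relative(self):
        """Map '<attribute path>.<variable name>' (relative to self) to variables."""
        result = {}
        for path, node in self.nodes_relative().items():
            for field, attr in node.__dict__.items():
                if isinstance(attr, Var):
                    result[f'{path}.{field}' if path else field] = attr
        return result


class Layer(Node):
    def __init__(self):
        self.V = Var(0.0)
        self.w = Var(0.0)
        self.integral = Node()  # stands in for the integrator


class Circuit(Node):
    def __init__(self):
        self.pre = Layer()
        self.post = Layer()


net = Circuit()
```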

Elements in containers

To avoid surprising, unintended behavior, collection functions do not search for elements inside lists, dicts, or other container structures.

class ATest(bp.Base):
    def __init__(self):
        super(ATest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}
t1 = ATest()

The class above defines a list of variables and a dict of child nodes. However, neither can be retrieved by the collection functions vars() and nodes().

t1.vars()
{}
t1.nodes()
{'ATest0': <__main__.ATest at 0x7f4a7401b2b0>}

Fortunately, BrainPy provides implicit_vars and implicit_nodes (both instances of dict) to hold variables and nodes stored in container structures. Any variable registered in implicit_vars, and any integrator or node registered in implicit_nodes, can be retrieved by the collection functions. Let's try it.

class AnotherTest(bp.Base):
    def __init__(self):
        super(AnotherTest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        
        self.implicit_vars = {f'v{i}': v for i, v in enumerate(self.all_vars)}  # must be a dict
        self.implicit_nodes = {k: v for k, v in self.sub_nodes.items()}  # must be a dict
t2 = AnotherTest()
# This model has two "FHN" instances, and each has one integrator.
# Together with the AnotherTest instance itself, there are five Base objects.

t2.nodes()
{'T1': <__main__.FHN at 0x7f4a740777c0>,
 'T2': <__main__.FHN at 0x7f4a74077a90>,
 'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74077340>,
 'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74011040>,
 'AnotherTest0': <__main__.AnotherTest at 0x7f4a740157f0>}
# This model has five Base objects (see above).
# Each FHN node has three variables, and the model itself
# has two implicit variables: 2 * 3 + 2 = 8 variables in total.

t2.vars()
{'T1.V': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'T1.spike': Variable([False, False, False, False, False, False, False, False, False,
           False]),
 'T1.w': Variable([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'T2.V': Variable([0., 0., 0., 0., 0.]),
 'T2.spike': Variable([False, False, False, False, False]),
 'T2.w': Variable([0., 0., 0., 0., 0.]),
 'AnotherTest0.v0': Variable([0., 0., 0., 0., 0.]),
 'AnotherTest0.v1': Variable([1., 1., 1., 1., 1., 1.])}
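The container rule can be sketched as follows. In this hypothetical plain-Python version (`Var`, `Node`, `Unit`, and `Holder` are illustrative stand-ins, not BrainPy classes), traversal looks only at direct attributes plus the two registries implicit_vars and implicit_nodes, so the same list and dict become visible only after registration.

```python
class Var:
    """Hypothetical stand-in for bm.Variable."""

    def __init__(self, value):
        self.value = value


class Node:
    """Hypothetical stand-in for bp.Base with implicit_vars / implicit_nodes."""

    def __init__(self, name):
        self.name = name
        self.implicit_vars = {}   # variables kept in containers, registered by hand
        self.implicit_nodes = {}  # child nodes kept in containers, registered by hand

    def _children(self):
        # Only direct attributes are inspected; lists and dicts are not searched,
        # except for the explicitly registered implicit_nodes.
        direct = [v for v in self.__dict__.values() if isinstance(v, Node)]
        return direct + list(self.implicit_nodes.values())

    def nodes(self):
        found, stack = {}, [self]
        while stack:
            node = stack.pop()
            if node.name not in found:
                found[node.name] = node
                stack.extend(node._children())
        return found

    def vars(self):
        result = {}
        for name, node in self.nodes().items():
            here = {k: v for k, v in node.__dict__.items() if isinstance(v, Var)}
            here.update(node.implicit_vars)  # registered container entries
            for field, var in here.items():
                result[f'{name}.{field}'] = var
        return result


class Unit(Node):
    def __init__(self, name):
        super().__init__(name)
        self.V = Var(0.0)


class Holder(Node):
    def __init__(self, name, register):
        super().__init__(name)
        self.all_vars = [Var(0.0), Var(1.0)]                               # in a list
        self.sub_nodes = {'a': Unit(name + '_a'), 'b': Unit(name + '_b')}  # in a dict
        if register:
            self.implicit_vars = {f'v{i}': v for i, v in enumerate(self.all_vars)}
            self.implicit_nodes = dict(self.sub_nodes)


t1 = Holder('ATest0', register=False)
t2 = Holder('AnotherTest0', register=True)
```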

Saving and loading

Because Base.vars() returns a Collector, which is a Python dictionary, the variables can easily be saved, updated, altered, and restored, adding a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods (for more details, see Saving and Loading): Base.save_states() and Base.load_states().

Save

Base.save_states(PATH, [vars])

Model exporting in BrainPy supports several common file formats, including

  • HDF5: .h5, .hdf5

  • .npz (NumPy file format)

  • .pkl (Python’s pickle utility)

  • .mat (Matlab file format)

import os
os.makedirs('./data', exist_ok=True)  # the target directory must exist

net.save_states('./data/net.h5')
net.save_states('./data/net.pkl')
# An unknown file format raises an error

try:
    net.save_states('./data/net.xxx')
except Exception as e:
    print(type(e).__name__, ":", e)
BrainPyError : Unknown file format: ./data/net.xxx. We only supports ['.h5', '.hdf5', '.npz', '.pkl', '.mat']

Load

Base.load_states(PATH)

net.load_states('./data/net.h5')
net.load_states('./data/net.pkl')
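The idea behind save_states()/load_states() (a name-to-array mapping written to disk and later copied back into the same variables) can be sketched with NumPy and pickle. The helpers below are hypothetical, support only .npz and .pkl, and use plain NumPy arrays in place of Variables; they are not BrainPy's implementation.

```python
import os
import pickle
import tempfile

import numpy as np


def save_states(path, variables):
    """Write a {name: array} mapping to .npz or .pkl, chosen by file extension."""
    ext = os.path.splitext(path)[1]
    if ext == '.npz':
        np.savez(path, **variables)
    elif ext == '.pkl':
        with open(path, 'wb') as f:
            pickle.dump({k: np.asarray(v) for k, v in variables.items()}, f)
    else:
        raise ValueError(f'Unknown file format: {path}')


def load_states(path, variables):
    """Copy saved arrays back into the existing arrays in `variables`, in place."""
    ext = os.path.splitext(path)[1]
    if ext == '.npz':
        with np.load(path) as data:
            saved = {k: data[k] for k in data.files}
    elif ext == '.pkl':
        with open(path, 'rb') as f:
            saved = pickle.load(f)
    else:
        raise ValueError(f'Unknown file format: {path}')
    for name, array in variables.items():
        array[...] = saved[name]  # keep object identity instead of rebinding


# Usage: round-trip a small state dict through both formats.
state = {'FHN3.V': np.zeros(8), 'FHN3.w': np.zeros(8)}
restored = {}
with tempfile.TemporaryDirectory() as tmp:
    for ext in ('.npz', '.pkl'):
        path = os.path.join(tmp, 'net' + ext)
        state['FHN3.V'][:] = 1.5
        save_states(path, state)
        state['FHN3.V'][:] = 0.0  # clobber, then restore from disk
        load_states(path, state)
        restored[ext] = state['FHN3.V'].copy()
```

Copying into the existing arrays rather than rebinding the names keeps every other reference to those arrays valid, which matters when the same variable is reachable from several nodes.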

Collector

Collection functions return a brainpy.Collector. This class is a dictionary that maps names to elements, and it has some useful methods.

subset()

Collector.subset(cls) returns the elements that are instances of the given cls. For example, Base.nodes() returns all instances of the Base class. If you are only interested in one type, like ODEIntegrator, you can use:

net.nodes().subset(bp.ode.ODEIntegrator)
{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a74015670>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x7f4a7401b100>}

Internally, Collector.subset(cls) iterates over all elements in the collection and keeps those whose type matches the given cls (the RK4 integrators above match ODEIntegrator, so subclasses count as matches).
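A minimal sketch of the filtering, assuming an isinstance test (consistent with RK4 objects matching bp.ode.ODEIntegrator above). The Collector, ODEIntegrator, RK4, and FHN classes here are hypothetical stand-ins.

```python
class Collector(dict):
    """Hypothetical minimal stand-in for brainpy.Collector (only subset shown)."""

    def subset(self, cls):
        # Keep only the values that are instances of cls (subclasses included).
        return type(self)((k, v) for k, v in self.items() if isinstance(v, cls))


# Usage with dummy classes: RK4 is a subclass of the integrator base class.
class ODEIntegrator:
    pass


class RK4(ODEIntegrator):
    pass


class FHN:
    pass


nodes = Collector({'FHN3': FHN(), 'RK45': RK4(), 'RK46': RK4()})
integrators = nodes.subset(ODEIntegrator)
```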

unique()

In machine learning it is common for weights to be shared among several objects, or for the same weight to be reachable through several dependence paths. As a result, the collection functions of Base may return a collection in which the same value appears under multiple keys, and such duplicates are not excluded automatically. However, it is important not to apply an operation twice to the same element (e.g., applying gradients to update weights).

Therefore, Collector provides the method Collector.unique() to handle this. It returns a copy of the collection in which every element appears only once.

class ModelA(bp.Base):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))

        
class SharedA(bp.Base):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable
        
        
class Group(bp.Base):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)

g = Group()
g.vars('relative')  # the same Variable can be accessed by three paths
{'A.a': Variable([0., 0., 0., 0., 0.]),
 'A_shared.a': Variable([0., 0., 0., 0., 0.]),
 'A_shared.source.a': Variable([0., 0., 0., 0., 0.])}
g.vars('relative').unique()  # only return a unique path
{'A.a': Variable([0., 0., 0., 0., 0.])}
g.nodes('relative')  # "ModelA" is accessed twice
{'': <__main__.Group at 0x7f4a74049190>,
 'A': <__main__.ModelA at 0x7f4a740490a0>,
 'A_shared': <__main__.SharedA at 0x7f4a74049550>,
 'A_shared.source': <__main__.ModelA at 0x7f4a740490a0>}
g.nodes('relative').unique()
{'': <__main__.Group at 0x7f4a74049190>,
 'A': <__main__.ModelA at 0x7f4a740490a0>,
 'A_shared': <__main__.SharedA at 0x7f4a74049550>}
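One way to implement unique() is to keep the first key under which each object appears, comparing objects by identity rather than by value. The class below is a hypothetical stand-in, and identity comparison is our assumption about how duplicates are detected; it is consistent with the shared Variable above collapsing to the single key 'A.a'.

```python
class Collector(dict):
    """Hypothetical minimal stand-in for brainpy.Collector (only unique shown)."""

    def unique(self):
        # Keep the first key under which each object appears; later aliases of
        # the same object (compared by identity, not equality) are dropped.
        seen = set()
        result = type(self)()
        for key, value in self.items():
            if id(value) not in seen:
                seen.add(id(value))
                result[key] = value
        return result


shared = [0.0] * 5     # one object reachable through three paths
lookalike = [0.0] * 5  # equal contents, but a different object
aliases = Collector({'A.a': shared, 'A_shared.a': shared, 'A_shared.source.a': shared})
distinct = Collector({'x': shared, 'y': lookalike})
```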

update()

Collector is a dict, but it catches potential conflicts during assignment. Both bracket assignment (collector[key] = value) and Collector.update() check whether an existing key would be mapped to a different value; if so, an error is raised.

tc = bp.Collector({'a': bm.zeros(10)})

tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)
ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
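The conflict check can be sketched as a dict subclass. In this hypothetical version, "same value" means the same object (an identity check, which is our assumption), and update() is overridden because the built-in dict.update() does not call __setitem__.

```python
class Collector(dict):
    """Hypothetical stand-in for brainpy.Collector's conflict checking."""

    def __setitem__(self, key, value):
        # Rebinding an existing key to a different object is treated as an error.
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
        super().__setitem__(key, value)

    def update(self, other=(), **kwargs):
        # dict.update() does not call __setitem__, so route each pair through it.
        for key, value in dict(other, **kwargs).items():
            self[key] = value


zeros = [0.0] * 10
tc = Collector({'a': zeros})
tc['a'] = zeros          # same object under the same key: allowed
tc.update({'b': [1.0]})  # new key: allowed
```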

replace()

To replace the value stored under an existing key, use Collector.replace(key, new_value).

tc
{'a': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])}
tc.replace('a', bm.ones(3))

tc
{'a': array([1., 1., 1.])}
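replace() is the escape hatch for an intentional rebinding. In this hypothetical sketch, it simply bypasses the conflict check by writing through dict directly, while ordinary assignment still refuses to rebind the key.

```python
class Collector(dict):
    """Hypothetical stand-in: conflict-checked assignment plus an explicit replace."""

    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
        super().__setitem__(key, value)

    def replace(self, key, new_value):
        # Intentional rebinding: skip the conflict check by calling dict directly.
        dict.__setitem__(self, key, new_value)


tc = Collector({'a': [0.0] * 10})
tc.replace('a', [1.0, 1.0, 1.0])
```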

TensorCollector

TensorCollector is a subclass of Collector that is specialized for collecting tensors.