# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import inspect
from brainpy._errors import DiffEqError
from brainpy.math.object_transform.base import Collector
# Names exported by ``from <this module> import *``.
__all__ = ['JointEq']
def _get_args(f):
    """Split the signature of a derivative function into its parts.

    Only ``POSITIONAL_OR_KEYWORD`` parameters are accepted. The parameters
    without a default value that come before the first one named ``t`` are
    treated as the state variables of the equation.

    Parameters
    ----------
    f : callable
        The derivative function to inspect.

    Returns
    -------
    variables : list of str
        Names of the parameters (without default) that precede ``t``.
    args : list of str
        Names of all parameters without a default value, in signature order
        (this includes the variables and ``t``).
    kwargs : dict
        Mapping from each parameter name that has a default value to that
        default.

    Raises
    ------
    DiffEqError
        If ``f`` declares ``*args``, keyword-only, positional-only or
        ``**kwargs`` parameters.
    ValueError
        If there is no parameter named ``t`` without a default value.
    """
    P = inspect.Parameter
    args = []
    kwargs = {}
    for par in inspect.signature(f).parameters.values():
        if par.kind is P.POSITIONAL_OR_KEYWORD:
            if par.default is P.empty:
                args.append(par.name)
            else:
                kwargs[par.name] = par.default
        elif par.kind is P.VAR_POSITIONAL:
            raise DiffEqError(f'{JointEq.__name__} does not support VAR_POSITIONAL parameters '
                              f'*{par.name} (error in {f}).')
        elif par.kind is P.KEYWORD_ONLY:
            raise DiffEqError(f'{JointEq.__name__} does not support KEYWORD_ONLY parameters, '
                              f'e.g., * (error in {f}).')
        elif par.kind is P.POSITIONAL_ONLY:
            # Both halves must be f-strings, otherwise "{f}" is printed literally.
            raise DiffEqError(f'{JointEq.__name__} does not support POSITIONAL_ONLY parameters, '
                              f'e.g., / (error in {f}).')
        elif par.kind is P.VAR_KEYWORD:
            raise DiffEqError(f'{JointEq.__name__} does not support VAR_KEYWORD '
                              f'arguments **{par.name} (error in {f}).')
        else:
            raise DiffEqError(f'Unknown argument type: {par.kind}')

    # State variables are everything listed before the (first) time variable "t".
    if 't' not in args:
        raise ValueError('Do not find time variable "t".')
    variables = args[:args.index('t')]
    return variables, args, kwargs
def _std_func(f, all_vars: list):
    """Adapt ``f`` so that it can be driven by a shared state vector.

    The returned function is called as ``call(t, *states, **extra)``, where
    ``states`` holds the values of ``all_vars`` in that order. It assembles
    the keyword arguments ``f`` expects and returns ``f(**params)``:

    * ``f``'s own state variables are taken from ``states``;
    * ``f``'s other required parameters are taken from ``extra`` when given,
      otherwise from ``states`` if they are one of ``all_vars``;
    * ``f``'s defaulted parameters are overridden only when present in
      ``extra``.

    Raises ``DiffEqError`` at call time if a required parameter can be found
    neither in ``extra`` nor in ``all_vars``.
    """
    own_vars, own_args, own_kwargs = _get_args(f)
    # Required parameters of ``f`` declared after the time variable ``t``.
    trailing_args = own_args[len(own_vars) + 1:]

    def call(t, *states, **extra):
        params = {'t': t}
        for name in own_vars:
            params[name] = states[all_vars.index(name)]
        for name in trailing_args:
            if name in extra:
                value = extra[name]
            elif name in all_vars:
                value = states[all_vars.index(name)]
            else:
                raise DiffEqError(f'Missing {name} during the functional call of {f}.')
            params[name] = value
        params.update({name: extra[name] for name in own_kwargs if name in extra})
        return f(**params)

    return call
class JointEq(object):
    """Make a joint equation from multiple derivative functions.

    For example, consider the Izhikevich neuron model:

    >>> a, b = 0.02, 0.20
    >>> dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext
    >>> du = lambda u, t, V: a * (b * V - u)

    If we make a numerical solver for each derivative function, they will be
    solved independently:

    >>> import brainpy as bp
    >>> bp.odeint(dV, method='rk2', show_code=True)
    def brainpy_itg_of_ode0(V, t, u, Iext, dt=0.1):
      dV_k1 = f(V, t, u, Iext)
      k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
      k2_t_arg = t + dt * 0.6666666666666666
      dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext)
      V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
      return V_new

    As the generated code shows, "dV_k2" is evaluated by :math:`f(V_{k2}, u)`.
    To solve the coupled equations jointly, i.e., to evaluate "dV_k2" with
    :math:`f(V_{k2}, u_{k2})`, use :py:class:`brainpy.JointEq` to merge the two
    derivative functions into a joint equation, so that they are numerically
    solved together. Let's see the difference:

    >>> eq = bp.JointEq(dV, du)
    >>> bp.odeint(eq, method='rk2', show_code=True)
    def brainpy_itg_of_ode0_joint_eq(V, u, t, Iext, dt=0.1):
      dV_k1, du_k1 = f(V, u, t, Iext)
      k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
      k2_u_arg = u + dt * du_k1 * 0.6666666666666666
      k2_t_arg = t + dt * 0.6666666666666666
      dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext)
      V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
      u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
      return V_new, u_new

    :py:class:`brainpy.JointEq` supports nesting: an instance of ``JointEq``
    can itself be an element of a new ``JointEq``.

    >>> dw = lambda w, t, V: a * (b * V - w)
    >>> eq2 = bp.JointEq(eq, dw)

    Every derivative function must have the form
    ``f(<state variables>, t, <other arguments>, <name>=<default>, ...)``,
    and each state variable may appear before ``t`` in exactly one function.

    Parameters
    ----------
    *eqs : callable, or list/tuple of callables
        The derivative functions to compose. Lists and tuples are flattened
        recursively.
    """

    def _check_eqs(self, eqs):
        """Recursively flatten ``eqs`` (lists/tuples allowed) into callables.

        Raises ``DiffEqError`` for any element that is neither callable nor a
        list/tuple.
        """
        for eq in eqs:
            if isinstance(eq, (list, tuple)):
                yield from self._check_eqs(eq)
            elif callable(eq):
                yield eq
            else:
                raise DiffEqError(f'Elements in "eqs" only supports callable function, but got {eq}.')

    def __init__(self, *eqs):
        eqs = list(self._check_eqs(eqs))
        if not eqs:
            # Without this check the failure surfaces later as an opaque
            # IndexError at ``eqs[0]``.
            raise DiffEqError(f'{type(self).__name__} requires at least one derivative function.')

        # -- state variables: parameters declared before "t" in each equation --
        self.vars_in_eqs = []  # one list of state-variable names per equation
        vars_in_eqs = []  # flat list of all state-variable names
        parsed = []  # cached ``_get_args`` result per equation
        for eq in eqs:
            variables, args, kwargs = _get_args(eq)
            for var in variables:
                if var in vars_in_eqs:
                    raise DiffEqError(
                        f'Variable "{var}" has been used, however we got a same '
                        f'variable name in {eq}.\n\n'
                        f'In JointEq, each state variable should appear as the first parameter '
                        f'before "t" in exactly one derivative function. If "{var}" is a state '
                        f'variable in another equation, it should be placed AFTER "t" in this '
                        f'function as a dependency.\n\n'
                        f'Correct signature pattern:\n'
                        f'  def d{var}({var}, t, <dependencies>): ...  # {var} is the state variable\n'
                        f'  def dOther(other, t, {var}): ...  # {var} is a dependency\n\n'
                        f'Current function signature: {inspect.signature(eq)}'
                    )
            vars_in_eqs.extend(variables)
            self.vars_in_eqs.append(variables)
            parsed.append((eq, variables, args, kwargs))

        # -- other arguments: required ones after "t", and defaulted ones --
        self.args_in_eqs = []  # names each equation must be called with
        all_arg_pars = []  # required (non-state) arguments, first-seen order
        all_kwarg_pars = dict()  # defaulted arguments -> default value
        for eq, variables, args, kwargs in parsed:
            self.args_in_eqs.append(args + list(kwargs.keys()))
            for par in args[len(variables) + 1:]:
                if (par not in vars_in_eqs) and (par not in all_arg_pars) and (par not in all_kwarg_pars):
                    all_arg_pars.append(par)
            for key, value in kwargs.items():
                if key in all_kwarg_pars and value != all_kwarg_pars[key]:
                    raise DiffEqError(f'We got two different default value of "{key}": '
                                      f'{all_kwarg_pars[key]} != {value}')
                elif (key not in vars_in_eqs) and (key not in all_arg_pars):
                    all_kwarg_pars[key] = value
                elif key in vars_in_eqs:
                    raise DiffEqError(f'"{key}" is a state variable of the joint equation, '
                                      f'so it cannot be given a default value (error in {eq}).')
                else:
                    raise DiffEqError(f'"{key}" has no default value in a previous equation, '
                                      f'but has a default value in {eq}.')

        # -- finally, expose a combined signature --
        self.eqs = eqs
        self.arg_keys = vars_in_eqs + ['t'] + all_arg_pars
        self.kwarg_keys = list(all_kwarg_pars.keys())
        self.kwargs = all_kwarg_pars
        parameters = [inspect.Parameter(vp, inspect.Parameter.POSITIONAL_OR_KEYWORD)
                      for vp in self.arg_keys]
        parameters.extend([inspect.Parameter(k,
                                             kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                             default=all_kwarg_pars[k])
                           for k in self.kwarg_keys])
        # Start from the first equation's signature so that e.g. its return
        # annotation is kept, but replace the parameter list.
        signature = inspect.signature(eqs[0])
        self.__signature__ = signature.replace(parameters=parameters)
        self.__name__ = 'joint_eq'

    def __call__(self, *args, **kwargs):
        """Evaluate all equations and concatenate their derivatives.

        Positional arguments are bound in the order ``self.arg_keys`` followed
        by ``self.kwarg_keys``; keyword arguments are bound by name. Defaulted
        parameters that are not supplied take the default recorded at
        construction time (as advertised by ``__signature__``).

        Returns
        -------
        list
            The derivatives of every equation, in order. An equation returning
            a list/tuple (e.g. a nested ``JointEq``) contributes each element.

        Raises
        ------
        TypeError
            If more positional arguments are given than parameters exist.
        """
        all_keys = self.arg_keys + self.kwarg_keys
        if len(args) > len(all_keys):
            raise TypeError(f'JointEq takes at most {len(all_keys)} positional arguments, '
                            f'but {len(args)} were given.')
        # format arguments
        params_in = Collector()
        for key, arg in zip(all_keys, args):
            params_in[key] = arg
        params_in.update(kwargs)
        # fill in defaults that were neither passed positionally nor by keyword
        for key, default in self.kwargs.items():
            if key not in params_in:
                params_in[key] = default

        # call equations
        results = []
        for arg_names, eq in zip(self.args_in_eqs, self.eqs):
            r = eq(**{arg: params_in[arg] for arg in arg_names})
            if isinstance(r, (list, tuple)):
                results.extend(r)
            else:
                results.append(r)
        return results