# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import time
import warnings
from functools import partial
from typing import Union, Dict, Sequence, Callable
import jax
import jax.numpy as jnp
import numpy as np
import tqdm.auto
from jax.tree_util import tree_flatten
from brainpy import math as bm
from brainpy._errors import RunningError
from brainpy.math.object_transform.base import Collector
from brainpy.running.runner import Runner
from .base import Integrator
# Names exported by ``from ... import *``.
__all__ = ['IntegratorRunner']
class IntegratorRunner(Runner):
    """Structural runner for numerical integrators in brainpy.

    Examples
    --------

    Example to run an ODE integrator,

    >>> import brainpy as bp
    >>> import brainpy.math as bm
    >>> a=0.7; b=0.8; tau=12.5
    >>> dV = lambda V, t, w, I: V - V * V * V / 3 - w + I
    >>> dw = lambda w, t, V, a, b: (V + a - b * w) / tau
    >>> integral = bp.odeint(bp.JointEq([dV, dw]), method='exp_auto')
    >>>
    >>> runner = bp.IntegratorRunner(
    >>>          integral,  # the simulation target
    >>>          monitors=['V', 'w'],  # the variables to monitor
    >>>          inits={'V': bm.random.rand(10),
    >>>                 'w': bm.random.normal(size=10)},  # the initial values
    >>> )
    >>> runner.run(100.,
    >>>            args={'a': 1., 'b': 1.},  # update arguments
    >>>            dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)},  # the input current at each time step
    >>> )
    >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=[0, 1, 4], show=True)

    Example to run an SDE integrator,

    >>> import brainpy as bp
    >>> import brainpy.math as bm
    >>> # stochastic Lorenz system
    >>> sigma=10; beta=8 / 3; rho=28
    >>> g = lambda x, y, z, t, p: (p * x, p * y, p * z)
    >>> f = lambda x, y, z, t, p: [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]
    >>> lorenz = bp.sdeint(f, g, method='milstein2')
    >>>
    >>> runner = bp.IntegratorRunner(
    >>>          lorenz,
    >>>          monitors=['x', 'y', 'z'],
    >>>          inits=[1., 1., 1.],  # initialize all variables to 1.
    >>>          dt=0.01
    >>> )
    >>> runner.run(100., args={'p': 0.1},)
    >>>
    >>> import matplotlib.pyplot as plt
    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(projection='3d')
    >>> ax.plot(runner.mon.x.squeeze(), runner.mon.y.squeeze(), runner.mon.z.squeeze())
    >>> ax.set_xlabel('x')
    >>> ax.set_ylabel('y')
    >>> ax.set_zlabel('z')
    >>> plt.show()
    """

    def __init__(
        self,
        target: Integrator,

        # IntegratorRunner specific arguments
        inits: Union[Sequence, Dict] = None,

        # regular/common arguments
        dt: Union[float, int] = None,
        monitors: Sequence[str] = None,
        dyn_vars: Dict[str, bm.Variable] = None,
        jit: Union[bool, Dict[str, bool]] = True,
        numpy_mon_after_run: bool = True,
        progress_bar: bool = True,

        # deprecated
        args: Dict = None,
        dyn_args: Dict[str, Union[bm.ndarray, jnp.ndarray]] = None,
        fun_monitors: Dict[str, Callable] = None,
    ):
        """Initialization of structural runner for integrators.

        Parameters
        ----------
        target: Integrator
          The target to run.
        inits: sequence, dict
          The initial value of variables. A sequence (list, tuple or array) gives
          one value per variable, in the order of ``target.variables``; a dict maps
          variable names to values. With this parameter, you can easily control the
          number of variables to simulate. Every variable is allocated as a
          :py:class:`brainpy.math.Variable` whose length is the largest size among
          the given initial values (1 when no initial value is given). For example,
          if one of the variables has the shape of 10, then all variables will be
          an instance of :py:class:`brainpy.math.Variable` with the shape of
          :math:`(10,)`.
        dt: float, int
          The numerical integration step. Defaults to ``brainpy.math.get_dt()``.
        monitors: sequence of str, dict
          The variables to monitor. ``None`` (the default) monitors nothing.
        dyn_vars: dict
          Extra dynamical variables, forwarded to the base :py:class:`Runner`.
        jit: bool, dict
          Whether to JIT compile; forwarded to the base :py:class:`Runner`.
        numpy_mon_after_run: bool
          Convert the monitored results to NumPy arrays after running.
        progress_bar: bool
          Display a ``tqdm`` progress bar while running.
        args: dict
          The equation arguments to update.
          Note that if one of the arguments are heterogeneous (i.e., a tensor),
          it means we should run multiple trials. However, you can set the number
          of the elements in the variables so that each pair of variables can
          correspond to one set of arguments.

          .. deprecated:: 2.3.1
             Will be removed after version 2.4.0.
        dyn_args: dict
          The dynamically changed arguments. This means this argument can control
          the argument dynamically changed. For example, if you want to inject a
          time varied currents into the HH neuron model, you can pack the currents
          into this ``dyn_args`` argument. All values must have the same length.

          .. deprecated:: 2.3.1
             Will be removed after version 2.4.0.
        fun_monitors: dict
          The monitors with callable functions.

          .. deprecated:: 2.3.1

        Raises
        ------
        TypeError
          If ``target`` is not an :py:class:`Integrator`.
        KeyError
          If ``inits`` names a variable the integrator does not have.
        ValueError
          If ``monitors`` is neither ``None``, a sequence, nor a dict.
        RunningError
          If ``dt`` is not a scalar, or the values of ``dyn_args`` differ in length.
        """
        if not isinstance(target, Integrator):
            raise TypeError(f'Target must be instance of {Integrator.__name__}, '
                            f'but we got {type(target)}')

        # get maximum size and initial variables
        if inits is None:
            inits = dict()
        elif isinstance(inits, (list, tuple, np.ndarray, bm.Array, jnp.ndarray)):
            # a sequence provides one initial value per integrator variable, in order
            assert len(target.variables) == len(inits), (
                f'"inits" must provide one value for each of the variables '
                f'{list(target.variables)}, but we got {len(inits)} values.')
            inits = {k: inits[i] for i, k in enumerate(target.variables)}
        assert isinstance(inits, dict), f'"inits" must be a dict, but we got {type(inits)}'
        # all variables share the largest size among the initial values;
        # ``default`` keeps an empty dict consistent with ``inits=None``
        max_size = max((np.size(v) for v in inits.values()), default=1)

        # initialize variables
        self.variables = {v: bm.Variable(bm.zeros(max_size)) for v in target.variables}
        for key, value in inits.items():
            if key not in self.variables:
                raise KeyError(f'Unknown variable "{key}" in "inits". '
                               f'The integrator only has {list(target.variables)}.')
            self.variables[key][:] = value

        # format string monitors
        if monitors is None:
            monitors = dict()
        elif isinstance(monitors, (tuple, list)):
            monitors = self._format_seq_monitors(monitors)
            monitors = {k: (self.variables[k], i) for k, i in monitors}
        elif isinstance(monitors, dict):
            monitors = self._format_dict_monitors(monitors)
            monitors = {k: ((self.variables[i], i) if isinstance(i, str) else i)
                        for k, i in monitors.items()}
        else:
            raise ValueError(f'"monitors" must be a sequence or a dict, but we got '
                             f'{type(monitors)}: {monitors}')

        # initialize super class
        super(IntegratorRunner, self).__init__(target=target,
                                               monitors=monitors,
                                               fun_monitors=fun_monitors,
                                               jit=jit,
                                               progress_bar=progress_bar,
                                               dyn_vars=dyn_vars,
                                               numpy_mon_after_run=numpy_mon_after_run)
        self.register_implicit_vars(self.variables)

        # parameters
        dt = bm.get_dt() if dt is None else dt
        if not isinstance(dt, (int, float)):
            raise RunningError(f'"dt" must be scalar, but got {dt}')
        self.dt = dt

        # arguments of the integral function
        if args is not None:
            warnings.warn('Set "args" in `IntegratorRunner.run()` function, instead of __init__ function. '
                          'Will be removed since 2.4.0',
                          UserWarning)
            assert isinstance(args, dict), (f'"args" must be a dict, but '
                                            f'we got {type(args)}: {args}')
            self._static_args = args
        else:
            self._static_args = dict()
        if dyn_args is not None:
            warnings.warn('Set "dyn_args" in `IntegratorRunner.run()` function, instead of __init__ function. '
                          'Will be removed since 2.4.0',
                          UserWarning)
            assert isinstance(dyn_args, dict), (f'"dyn_args" must be a dict, but we get '
                                                f'{type(dyn_args)}: {dyn_args}')
            sizes = np.unique([len(v) for v in dyn_args.values()])
            num_size = len(sizes)
            # an empty dict is valid; only differing lengths are an error
            if num_size > 1:
                raise RunningError(f'All values in "dyn_args" should have the same length. '
                                   f'But we got {num_size}: {sizes}')
            self._dyn_args = dyn_args
        else:
            self._dyn_args = dict()

        # start simulation time and index
        self.start_t = bm.Variable(bm.zeros(1))
        self.idx = bm.Variable(bm.zeros(1, dtype=bm.int_))

    def _run_fun_integration(self, static_args, dyn_args, times, indices):
        """Loop ``_step_fun_integrator`` over ``(dyn_args, times, indices)`` with ``bm.for_loop``.

        ``static_args`` is bound to every step; the loop is JIT compiled according
        to ``self.jit['predict']``. Returns the stacked per-step monitor dicts.
        """
        return bm.for_loop(partial(self._step_fun_integrator, static_args),
                           (dyn_args, times, indices),
                           jit=self.jit['predict'])

    def _step_fun_integrator(self, static_args, dyn_args, t, i):
        """Advance the integrator by one step and return the monitored values.

        Parameters
        ----------
        static_args: dict
          Arguments passed unchanged to every step.
        dyn_args: dict
          This step's slice of the time-varying arguments.
        t: float
          The current time.
        i: int
          The current step index; it is reported to function monitors.
        """
        # arguments
        kwargs = Collector(dt=self.dt, t=t)
        kwargs.update(static_args)
        kwargs.update(dyn_args)
        kwargs.update({k: v.value for k, v in self.variables.items()})

        # call integrator function
        update_values = self.target(**kwargs)
        if len(self.target.variables) == 1:
            self.variables[self.target.variables[0]].update(update_values)
        else:
            # use a distinct loop name: ``i`` is the step index needed below
            for var_idx, var_name in enumerate(self.target.variables):
                self.variables[var_name].update(update_values[var_idx])

        # progress bar
        if self.progress_bar:
            jax.debug.callback(lambda *args: self._pbar.update(), ())

        # return of function monitors
        shared = dict(t=t + self.dt, dt=self.dt, i=i)
        returns = dict()
        for k, v in self._monitors.items():
            if callable(v):
                returns[k] = bm.as_jax(v(shared))
            else:
                # NOTE(review): non-callable monitors are read by their key, which
                # assumes the key is a variable name — confirm for dict monitors
                # whose keys differ from variable names.
                returns[k] = self.variables[k].value
        return returns

    def run(
        self,
        duration: float,
        start_t: float = None,
        eval_time: bool = False,
        args: Dict = None,
        dyn_args: Dict = None,
    ):
        """The running function.

        Parameters
        ----------
        duration : float, int
          The running duration.
        start_t : float, optional
          The start time to simulate. Defaults to the end time of the previous run
          (0 for the first run).
        eval_time: bool
          Evaluate the running time or not?
        args: dict
          The equation arguments to update. Arguments given to the (deprecated)
          constructor ``args`` take precedence over the same keys given here.

          .. versionadded:: 2.3.1
        dyn_args: dict
          The dynamically changed arguments over time. The size of the first
          dimension of every value should be equal to the number of running steps
          (``duration / dt``).

          .. versionadded:: 2.3.1

        Returns
        -------
        float, None
          The running time in seconds when ``eval_time`` is True, otherwise None.

        Raises
        ------
        ValueError
          If a value in ``dyn_args`` does not match the number of running steps.
        """
        args = dict() if args is None else args
        dyn_args = dict() if dyn_args is None else dyn_args
        assert isinstance(args, dict), f'"args" must be a dict, but we got {type(args)}: {args}'
        assert isinstance(dyn_args, dict), f'"dyn_args" must be a dict, but we got {type(dyn_args)}: {dyn_args}'
        # merge into new dicts so the caller's dicts are never mutated
        args = {**args, **self._static_args}
        dyn_args = {**dyn_args, **self._dyn_args}

        # time step
        if start_t is None:
            start_t = self.start_t[0]
        end_t = start_t + duration
        # times
        times = bm.arange(start_t, end_t, self.dt).value
        indices = bm.arange(times.size).value + self.idx.value
        for leaf in tree_flatten(dyn_args)[0]:
            leaf_shape = jnp.shape(leaf)
            # a scalar leaf cannot be sliced along time, so it is rejected too
            if len(leaf_shape) == 0 or leaf_shape[0] != times.size:
                first_dim = leaf_shape[0] if leaf_shape else leaf_shape
                raise ValueError(f'The shape of `dyn_args` does not match the given duration. '
                                 f'{first_dim} != {times.size} (duration={duration}, dt={self.dt}).')

        # running
        if self.progress_bar:
            self._pbar = tqdm.auto.tqdm(total=times.size)
            self._pbar.set_description(f"Running a duration of {round(float(duration), 3)} ({times.size} steps)",
                                       refresh=True)
        try:
            if eval_time:
                t0 = time.time()
            hists = self._run_fun_integration(args, dyn_args, times, indices)
            if eval_time:
                running_time = time.time() - t0
        finally:
            # close the bar even if the integration fails
            if self.progress_bar:
                self._pbar.close()

        # post-running
        times += self.dt
        if self.numpy_mon_after_run:
            times = np.asarray(times)
            for key in list(hists.keys()):
                hists[key] = np.asarray(hists[key])
        self.mon.ts = times
        for key, value in hists.items():
            self.mon[key] = value
        self.start_t[0] = end_t
        self.idx[0] += times.size
        if eval_time:
            return running_time