# -*- coding: utf-8 -*-
import functools
from typing import Union, Optional, Dict, Sequence
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_map
from brainpy import tools, math as bm
from brainpy._src.context import share
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.helpers import clear_input
from brainpy.check import is_float, is_integer
from brainpy.types import PyTree
__all__ = [
'LoopOverTime',
]
class LoopOverTime(DynamicalSystem):
"""Transform a single step :py:class:`~.DynamicalSystem`
into a multiple-step forward propagation :py:class:`~.BrainPyObject`.
.. note::
This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`.
If the `target` has a batching mode, before sending the data into the wrapped object,
reset the state (``.reset_state(batch_size)``) with the same batch size as in the given data.
For more flexible customization, we recommend users to use :py:func:`~.for_loop`,
or :py:class:`~.DSRunner`.
Examples
--------
This model can be used for network training:
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> n_time, n_batch, n_in = 30, 128, 100
>>> model = bp.Sequential(l1=bp.layers.RNNCell(n_in, 20),
>>> l2=bm.relu,
>>> l3=bp.layers.RNNCell(20, 2))
>>> over_time = bp.LoopOverTime(model, data_first_axis='T')
>>> over_time.reset_state(n_batch)
>>>
>>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in))
>>> print(hist_l3.shape)
(30, 128, 2)
>>>
>>> # monitor the "l1" layer state
>>> over_time = bp.LoopOverTime(model, out_vars=model['l1'].state, data_first_axis='T')
>>> over_time.reset_state(n_batch)
>>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in))
>>> print(hist_l3.shape)
(30, 128, 2)
>>> print(hist_l1.shape)
(30, 128, 20)
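For stateless layers (see the ``no_state`` parameter below), all time steps
can be processed in one batched call. A minimal sketch; the ``Dense`` layer,
its output size, and the explicit ``mode`` argument are illustrative choices:
>>> over_time = bp.LoopOverTime(bp.layers.Dense(n_in, 10, mode=bm.batching_mode),
>>>                             no_state=True)  # illustrative: any stateless layer works
>>> hist = over_time(bm.random.rand(n_time, n_batch, n_in))
>>> print(hist.shape)
(30, 128, 10)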
It can also be used in brain simulation models:
.. plot::
:include-source: True
>>> import brainpy as bp
>>> import brainpy.math as bm
>>> import matplotlib.pyplot as plt
>>>
>>> hh = bp.neurons.HH(1)
>>> over_time = bp.LoopOverTime(hh, out_vars=hh.V)
>>>
>>> # running with a given duration
>>> _, potentials = over_time(100.)
>>> plt.plot(bm.as_numpy(potentials), label='with given duration')
>>>
>>> # running with the given inputs
>>> _, potentials = over_time(bm.ones(1000) * 5)
>>> plt.plot(bm.as_numpy(potentials), label='with given inputs')
>>> plt.legend()
>>> plt.show()
Parameters
----------
target: DynamicalSystem
The target to transform.
no_state: bool
Denoting whether the `target` is stateless, i.e., it does not depend on
the shared arguments or any internal state.
- For stateless ANN layers, like :py:class:`~.Dense` or :py:class:`~.Conv2d`,
setting `no_state=True` is highly efficient. This is because :math:`Y[t]` only relies on
:math:`X[t]`, so it is not necessary to compute :math:`Y[t]` step-by-step.
In this case, we reshape the input from `shape = [T, N, *]` to `shape = [TN, *]`,
send the data to the object, and reshape the output back to `shape = [T, N, *]`.
In this way, the computation over different time steps is parallelized
(see the stateless example above).
out_vars: PyTree
The variables to monitor over the time loop.
t0: float, optional
The start time to run the system. If None, ``t`` will no longer be generated in the loop.
i0: int, optional
The start index to run the system. If None, ``i`` will no longer be generated in the loop.
dt: float
The time step.
shared_arg: dict
The shared arguments across the nodes.
For instance, `shared_arg={'fit': False}` for the prediction phase.
data_first_axis: str
Denoting the type of the first axis of input data.
If ``'T'``, we treat the data as `(time, ...)`.
If ``'B'``, we treat the data as `(batch, time, ...)` when the `target` is in Batching mode.
Default is ``'T'``.
name: str
The transformed object name.
jit: bool
Whether to just-in-time compile the looped computation. Default is True.
remat: bool
Whether to use rematerialization (checkpointing) to save memory during
backpropagation through the loop. Default is False.
"""
def __init__(
self,
target: DynamicalSystem,
out_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None,
no_state: bool = False,
t0: Optional[float] = 0.,
i0: Optional[int] = 0,
dt: Optional[float] = None,
shared_arg: Optional[Dict] = None,
data_first_axis: str = 'T',
name: Optional[str] = None,
jit: bool = True,
remat: bool = False,
):
super().__init__(name=name)
assert data_first_axis in ['B', 'T']
is_integer(i0, 'i0', allow_none=True)
is_float(t0, 't0', allow_none=True)
is_float(dt, 'dt', allow_none=True)
dt = share.dt if dt is None else dt
if shared_arg is None:
shared_arg = dict(dt=dt)
else:
assert isinstance(shared_arg, dict)
shared_arg['dt'] = dt
self.dt = dt
self._t0 = t0
self._i0 = i0
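# `t0` and `i0` are kept both as plain initial values (`_t0`, `_i0`) and as
# Variables, so the loop can advance them across successive calls and
# `reset_state()` can restore the initial values.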
self.t0 = None if t0 is None else bm.Variable(bm.as_jax(t0))
self.i0 = None if i0 is None else bm.Variable(bm.as_jax(i0))
self.jit = jit
self.remat = remat
self.shared_arg = shared_arg
self.data_first_axis = data_first_axis
if not isinstance(target, DynamicalSystem):
raise TypeError(f'Must be an instance of {DynamicalSystem.__name__}, '
f'but we got {type(target)}.')
self.target = target
self.no_state = no_state
self.out_vars = out_vars
if out_vars is not None:
out_vars, _ = tree_flatten(out_vars, is_leaf=lambda s: isinstance(s, bm.Variable))
for v in out_vars:
if not isinstance(v, bm.Variable):
raise TypeError('out_vars must be a PyTree of Variable.')
def __call__(
self,
duration_or_xs: Union[float, PyTree],
):
"""Forward propagation along the time or inputs.
Parameters
----------
duration_or_xs: float, PyTree
If `float`, it indicates a running duration.
If a PyTree, it is the given inputs.
Returns
-------
out: PyTree
The accumulated outputs over time.
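Examples
--------
A short sketch, reusing the ``hh`` wrapper from the class-level example:
>>> _, v1 = over_time(100.)               # run for a given duration
>>> _, v2 = over_time(bm.ones(1000) * 5)  # run over explicit step inputs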
"""
# inputs
if isinstance(duration_or_xs, float):
shared = tools.DotDict()
times = jnp.arange(0, duration_or_xs, self.dt)
length = times.shape
if self.t0 is not None:
shared['t'] = times + self.t0.value
if self.i0 is not None:
shared['i'] = jnp.arange(0, length[0]) + self.i0.value
xs = None
if self.no_state:
raise ValueError('Under the `no_state=True` setting, input cannot be a duration.')
else:
inp_err_msg = ('\n'
'Input should be an Array PyTree with the shape '
'of (T, B, ...) when `data_first_axis="T"`, '
'or (B, T, ...) when `data_first_axis="B"`, '
'where T is the time length and B is the batch size.')
xs, tree = tree_flatten(duration_or_xs, lambda a: isinstance(a, bm.Array))
if self.target.mode.is_child_of(bm.BatchingMode):
b_idx, t_idx = (1, 0) if self.data_first_axis == 'T' else (0, 1)
try:
batch = tuple(set([x.shape[b_idx] for x in xs]))
except (AttributeError, IndexError) as e:
raise ValueError(inp_err_msg) from e
if len(batch) != 1:
raise ValueError('\n'
'Input should be an Array PyTree with the same batch dimension, '
f'but we got {tree_unflatten(tree, batch)}.')
try:
length = tuple(set([x.shape[t_idx] for x in xs]))
except (AttributeError, IndexError) as e:
raise ValueError(inp_err_msg) from e
if len(length) != 1:
raise ValueError('\n'
'Input should be an Array PyTree with the same time length, '
f'but we got {tree_unflatten(tree, length)}.')
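# For the stateless fast path, merge the time and batch axes into one;
# otherwise move the time axis to the front so the loop can scan over it.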
if self.no_state:
xs = [bm.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs]
else:
if self.data_first_axis == 'B':
xs = [jnp.moveaxis(x, 0, 1) for x in xs]
xs = tree_unflatten(tree, xs)
origin_shape = (length[0], batch[0]) if self.data_first_axis == 'T' else (batch[0], length[0])
else:
try:
length = tuple(set([x.shape[0] for x in xs]))
except (AttributeError, IndexError) as e:
raise ValueError(inp_err_msg) from e
if len(length) != 1:
raise ValueError('\n'
'Input should be an Array PyTree with the same time length, '
f'but we got {tree_unflatten(tree, length)}.')
xs = tree_unflatten(tree, xs)
origin_shape = (length[0],)
# computation
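# With `no_state=True`, each output depends only on the input at the same
# step, so the target is called once on the flattened data and the result
# is reshaped back to the original (time, batch) layout.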
if self.no_state:
share.save(**self.shared_arg)
outputs = self._run(self.shared_arg, dict(), xs)
results = tree_map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs)
if self.i0 is not None:
self.i0 += length[0]
if self.t0 is not None:
self.t0 += length[0] * self.dt
return results
else:
shared = tools.DotDict()
if self.t0 is not None:
shared['t'] = jnp.arange(0, length[0]) * self.dt + self.t0.value
if self.i0 is not None:
shared['i'] = jnp.arange(0, length[0]) + self.i0.value
assert not self.no_state
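# Scan `_run` over the leading (time) axis of `shared` and `xs`; at each
# step, the current 't' and 'i' values are written into the shared context.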
results = bm.for_loop(functools.partial(self._run, self.shared_arg),
(shared, xs),
jit=self.jit,
remat=self.remat)
if self.i0 is not None:
self.i0 += length[0]
if self.t0 is not None:
self.t0 += length[0] * self.dt
return results
def reset_state(self, batch_size=None):
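# Restore the step counter and the start time to their initial values.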
if self.i0 is not None:
self.i0.value = bm.as_jax(self._i0)
if self.t0 is not None:
self.t0.value = bm.as_jax(self._t0)
def _run(self, static_sh, dyn_sh, x):
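# Merge the static shared arguments with this step's dynamic ones (e.g.,
# 't' and 'i'), run the target for one step, collect monitored variables,
# and clear the accumulated inputs of the target.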
share.save(**static_sh, **dyn_sh)
outs = self.target(x)
if self.out_vars is not None:
outs = (outs, tree_map(bm.as_jax, self.out_vars))
clear_input(self.target)
return outs