Source code for brainpy.math.delayvars

# -*- coding: utf-8 -*-

from typing import Union, Callable, Tuple

import jax.numpy as jnp
from jax import vmap
from jax.lax import cond, stop_gradient

from brainpy import check
from brainpy.base.base import Base
from brainpy.errors import UnsupportedError
from brainpy.math import numpy_ops as bm
from brainpy.math.jaxarray import ndarray, Variable, JaxArray
from brainpy.math.setting import get_dt
from brainpy.tools.checking import check_float, check_integer
from brainpy.tools.errors import check_error_in_jit

__all__ = [
  'AbstractDelay',
  'TimeDelay', 'LengthDelay',
  'NeuTimeDelay', 'NeuLenDelay',
]


class AbstractDelay(Base):
  """Common parent of the delay variables defined in this module.

  Concrete delays are expected to override :py:meth:`update`.
  """

  def update(self, *args, **kwargs):
    """Placeholder that always raises ``NotImplementedError``."""
    raise NotImplementedError
# Tags recording how the history before ``t0`` is provided to ``TimeDelay``.
_FUNC_BEFORE = 'function'
_DATA_BEFORE = 'data'

# Names of the interpolation methods accepted by ``TimeDelay``.
_INTERP_LINEAR = 'linear_interp'
_INTERP_ROUND = 'round'
class TimeDelay(AbstractDelay):
  r"""Delay variable which has a fixed delay time length.

  For example, we create a delay variable which has a maximum delay length of 1 ms

  >>> import brainpy.math as bm
  >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1)
  >>> delay(-0.5)
  [-0. -0. -0.]

  This function supports multiple dimensions of the tensor. For example,

  1. the one-dimensional delay data

  >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t)
  >>> delay(-0.2)
  [-0.2 -0.2 -0.2]

  2. the two-dimensional delay data

  >>> delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t)
  >>> delay(-0.6)
  [[-0.6 -0.6]
   [-0.6 -0.6]
   [-0.6 -0.6]]

  3. the three-dimensional delay data

  >>> delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t)
  >>> delay(-0.8)
  [[[-0.8]
    [-0.8]]
   [[-0.8]
    [-0.8]]
   [[-0.8]
    [-0.8]]]

  Parameters
  ----------
  delay_target: JaxArray, ndarray, Variable
    The initial delay data.
  t0: float, int
    The zero time.
  delay_len: float, int
    The maximum delay length.
  dt: float, int
    The time precision.
  before_t0: callable, bm.ndarray, jnp.ndarray, float, int
    The delay data before :math:`t_0`.

    - when `before_t0` is a function, it should receive a time argument `t`
    - when `before_t0` is a tensor, it should be a tensor with shape
      of :math:`(num\_delay, ...)`, where the longest delay data is arranged in
      the first index.
  name: str
    The delay instance name.
  interp_method: str
    The way to deal with the delay at the time which is not integer times of the time step.
    For example, if the time step ``dt=0.1``, the time delay length ``delay_len=1.``,
    when users require the delay data at ``t-0.53``, we can deal this situation with
    the following methods:

    - ``"linear_interp"``: using linear interpolation to get the delay value
      at the required time (default).
    - ``"round"``: round the time to make it is the integer times of the time step. For
      the above situation, we will use the time at ``t-0.5`` to approximate the delay data
      at ``t-0.53``.

    .. versionadded:: 2.1.1

  See Also
  --------
  LengthDelay
  """

  def __init__(
      self,
      delay_target: Union[ndarray, jnp.ndarray],
      delay_len: Union[float, int],
      before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None,
      t0: Union[float, int] = 0.,
      dt: Union[float, int] = None,
      name: str = None,
      interp_method: str = 'linear_interp',
  ):
    super(TimeDelay, self).__init__(name=name)

    # shape
    if not isinstance(delay_target, (jnp.ndarray, JaxArray)):
      raise ValueError(f'Must be an instance of JaxArray or jax.numpy.ndarray. But we got {type(delay_target)}')

    # delay_len
    self.t0 = t0
    self.dt = get_dt() if dt is None else dt
    check_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.)
    self.delay_len = delay_len
    # one slot per time step spanned by the delay, plus one for the current value
    self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1

    # interp method
    if interp_method not in [_INTERP_LINEAR, _INTERP_ROUND]:
      raise UnsupportedError(f'Un-supported interpolation method {interp_method}, '
                             f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
    self.interp_method = interp_method

    # time variables
    self.idx = Variable(jnp.asarray([0]))
    check_float(t0, 't0', allow_none=False, allow_int=True, )
    self.current_time = Variable(jnp.asarray([t0]))

    # delay data
    batch_axis = None
    if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None):
      # the delay axis is prepended, so the batch axis shifts by one
      batch_axis = delay_target.batch_axis + 1
    self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                   dtype=delay_target.dtype),
                         batch_axis=batch_axis)
    if before_t0 is None:
      self._before_type = _DATA_BEFORE
    elif callable(before_t0):
      self._before_t0 = lambda t: bm.asarray(bm.broadcast_to(before_t0(t), delay_target.shape),
                                             dtype=delay_target.dtype).value
      self._before_type = _FUNC_BEFORE
    elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
      self._before_type = _DATA_BEFORE
      self.data[:-1] = before_t0
    else:
      raise ValueError(f'"before_t0" does not support {type(before_t0)}')
    # set initial data
    self.data[-1] = delay_target

    # interpolation function: ``jnp.interp`` mapped over every non-delay axis
    self._interp_fun = jnp.interp
    for dim in range(1, delay_target.ndim + 1, 1):
      self._interp_fun = vmap(self._interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)

  def reset(self,
            delay_target,
            delay_len,
            t0: Union[float, int] = 0.,
            before_t0=None):
    """Reset the delay variable.

    Parameters
    ----------
    delay_target: JaxArray, ndarray, Variable
      The delay target.
    delay_len: float, int
      The maximum delay length. The unit is the time.
    t0: int, float
      The zero time.
    before_t0: int, float, ndarray, JaxArray
      The data before t0.

    Raises
    ------
    ValueError
      If ``before_t0`` is given but is not a numerical value.
    """
    # keep the zero time in sync with ``current_time``; ``__call__`` compares
    # the requested time against ``self.t0`` when a history function is used
    self.t0 = t0
    self.delay_len = delay_len
    self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1
    self.data.value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                dtype=delay_target.dtype)
    self.data[-1] = delay_target
    self.idx = Variable(jnp.asarray([0]))
    self.current_time = Variable(jnp.asarray([t0]))
    if before_t0 is not None:
      if not isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
        raise ValueError('Only support numerical values.')
      self.data[:-1] = before_t0
      self._before_type = _DATA_BEFORE

  def _check_time1(self, times):
    """Error callback: the requested time lies after the current time."""
    prev_time, current_time = times
    raise ValueError(f'The request time should be less than the '
                     f'current time {current_time}. But we '
                     f'got {prev_time} > {current_time}')

  def _check_time2(self, times):
    """Error callback: the requested time lies before the stored window."""
    prev_time, current_time = times
    raise ValueError(f'The request time of the variable should be in '
                     f'[{current_time - self.delay_len}, {current_time}], '
                     f'but we got {prev_time}')

  def __call__(self, time, indices=None):
    """Get the delayed data at the given ``time``.

    Parameters
    ----------
    time: float, int
      The requested (past) time.
    indices: optional
      If given, the result is further indexed with ``res[indices]``.

    Returns
    -------
    The delay data at ``time``.
    """
    # check
    if check.is_checking():
      current_time = self.current_time[0]
      check_error_in_jit(time > current_time + 1e-6,
                         self._check_time1,
                         (time, current_time))
      check_error_in_jit(time < current_time - self.delay_len - self.dt,
                         self._check_time2,
                         (time, current_time))
    if self._before_type == _FUNC_BEFORE:
      # times before t0 are answered by the user-supplied history function
      res = cond(time < self.t0,
                 self._before_t0,
                 self._after_t0,
                 time)
    else:
      res = self._after_t0(time)
    if indices is not None:
      # TODO: indices is highly inefficient
      res = res[indices]
    return res

  def _after_t0(self, prev_time):
    """Read the value at ``prev_time`` from the stored buffer."""
    # time offset of the request measured from the oldest stored sample
    diff = self.delay_len - (self.current_time[0] - prev_time)
    if isinstance(diff, ndarray):
      diff = diff.value
    if self.interp_method == _INTERP_LINEAR:
      req_num_step = jnp.asarray(diff / self.dt, dtype=jnp.int32)
      extra = diff - req_num_step * self.dt
      # exact multiple of dt -> direct lookup, otherwise interpolate
      return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra))
    elif self.interp_method == _INTERP_ROUND:
      req_num_step = jnp.asarray(jnp.round(diff / self.dt), dtype=jnp.int32)
      return self._true_fn([req_num_step, 0.])
    else:
      raise UnsupportedError(f'Un-supported interpolation method {self.interp_method}, '
                             f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

  def _true_fn(self, div_mod):
    """Direct lookup of the sample ``req_num_step`` slots after ``idx``."""
    req_num_step, _ = div_mod
    # ``data`` is a ring buffer: wrap the index exactly as ``_false_fn`` does.
    # Without the modulo the index can reach ``num_delay_step`` or beyond, and
    # JAX clamps out-of-bounds reads, which would return a stale sample.
    return self.data[(self.idx[0] + req_num_step) % self.num_delay_step]

  def _false_fn(self, div_mod):
    """Linear interpolation between the two samples around the request."""
    req_num_step, extra = div_mod
    idx = jnp.asarray([self.idx[0] + req_num_step,
                       self.idx[0] + req_num_step + 1])
    idx %= self.num_delay_step
    return self._interp_fun(extra, jnp.asarray([0., self.dt]), self.data[idx])

  def update(self, time, value):
    """Store ``value`` as the data at ``time`` and advance the buffer.

    Parameters
    ----------
    time: float, int
      The new current time.
    value:
      The data to store; it overwrites the slot at the current index.
    """
    self.data[self.idx[0]] = value
    self.current_time[0] = time
    self.idx.value = (self.idx + 1) % self.num_delay_step
class NeuTimeDelay(TimeDelay):
  """Neutral Time Delay.

  An alias of :py:class:`~.TimeDelay`; it adds no behavior of its own.
  """
class LengthDelay(AbstractDelay):
  """Delay variable which has a fixed delay length.

  Parameters
  ----------
  delay_target: JaxArray, ndarray, Variable
    The initial delay data.
  delay_len: int
    The maximum delay length.
  initial_delay_data: float, int, bool, ndarray, callable
    The delay data used to fill the history slots. A callable is
    invoked as ``initial_delay_data(shape, dtype=dtype)``.
  name: str
    The delay object name.
  batch_axis: int
    The batch axis passed to the data ``Variable``. When it is None and
    ``delay_target`` is a ``Variable`` with a batch axis, that axis plus
    one is used (the delay axis is prepended).

  See Also
  --------
  TimeDelay
  """

  def __init__(
      self,
      delay_target: Union[ndarray, jnp.ndarray],
      delay_len: int,
      initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None,
      name: str = None,
      batch_axis: int = None,
  ):
    super(LengthDelay, self).__init__(name=name)

    # attributes and variables
    self.num_delay_step: int = None
    self.idx: Variable = None
    self.data: Variable = None

    # initialization
    self.reset(delay_target, delay_len, initial_delay_data, batch_axis)

  def reset(
      self,
      delay_target,
      delay_len=None,
      initial_delay_data=None,
      batch_axis=None
  ):
    """Reset (or first initialize) the delay variable.

    Parameters
    ----------
    delay_target: JaxArray, ndarray, Variable
      The delay target; it is stored in the last slot of the buffer.
    delay_len: int
      The maximum delay length. When None, the previous length is reused.
    initial_delay_data: float, int, bool, ndarray, callable
      The data used to fill every slot except the last one.
    batch_axis: int
      Only used when the data ``Variable`` is created for the first time.

    Raises
    ------
    ValueError
      If ``delay_target`` is not an array, if ``delay_len`` is None
      before any length was set, or if ``initial_delay_data`` has an
      unsupported type.
    """
    if not isinstance(delay_target, (ndarray, jnp.ndarray)):
      raise ValueError(f'Must be an instance of brainpy.math.ndarray '
                       f'or jax.numpy.ndarray. But we got {type(delay_target)}')

    # delay_len
    check_integer(delay_len, 'delay_len', allow_none=True, min_bound=0)
    if delay_len is None:
      if self.num_delay_step is None:
        raise ValueError('"delay_len" cannot be None.')
      delay_len = self.num_delay_step - 1
    self.num_delay_step = delay_len + 1

    # time variables
    if self.idx is None:
      self.idx = Variable(jnp.asarray([0], dtype=jnp.int32))
    else:
      self.idx.value = jnp.asarray([0], dtype=jnp.int32)

    # delay data
    if self.data is None:
      if batch_axis is None:
        if isinstance(delay_target, Variable) and (delay_target.batch_axis is not None):
          # the delay axis is prepended, so the batch axis shifts by one
          batch_axis = delay_target.batch_axis + 1
      self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                     dtype=delay_target.dtype),
                           batch_axis=batch_axis)
    else:
      self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                   dtype=delay_target.dtype)
    self.data[-1] = delay_target
    if initial_delay_data is None:
      pass
    elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)):
      self.data[:-1] = initial_delay_data
    elif callable(initial_delay_data):
      self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape,
                                          dtype=delay_target.dtype)
    else:
      raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}')

  def _check_delay(self, delay_len):
    """Error callback: the requested delay length is too large."""
    raise ValueError(f'The request delay length should be less than the '
                     f'maximum delay {self.num_delay_step}. '
                     f'But we got {delay_len}')

  def __call__(self, delay_len, *indices):
    """Get the data delayed by ``delay_len`` steps.

    Parameters
    ----------
    delay_len: int, array of int
      The delay length, in number of steps.
    *indices:
      Extra indices applied to the remaining axes of the data.

    Raises
    ------
    ValueError
      If the computed delay index is not of an integer dtype.
    """
    # check
    if check.is_checking():
      check_error_in_jit(bm.any(delay_len >= self.num_delay_step),
                         self._check_delay,
                         delay_len)
    # the delay length
    delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step
    delay_idx = stop_gradient(delay_idx)
    if not jnp.issubdtype(delay_idx.dtype, jnp.integer):
      raise ValueError(f'"delay_len" must be integer, but we got {delay_len}')
    # the delay data
    indices = (delay_idx,) + tuple(indices)
    return self.data[indices]

  def update(self, value: Union[float, JaxArray, jnp.ndarray]):
    """Store ``value`` at the current index and advance the buffer.

    Parameters
    ----------
    value: float, JaxArray, jnp.ndarray
      The newest data.
    """
    idx = stop_gradient(self.idx[0])
    self.data[idx] = value
    self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)
class NeuLenDelay(LengthDelay):
  """Neutral Length Delay.

  An alias of :py:class:`~.LengthDelay`; it adds no behavior of its own.
  """