# Source code for brainpy._src.math.delayvars

# -*- coding: utf-8 -*-

from typing import Union, Callable
import numbers

import jax
import jax.numpy as jnp
from jax import vmap
from jax.lax import stop_gradient

from brainpy import check
from brainpy.check import is_float, is_integer, jit_error
from brainpy.errors import UnsupportedError
from .compat_numpy import broadcast_to, expand_dims, concatenate
from .environment import get_dt, get_float
from .interoperability import as_jax
from .ndarray import ndarray, Array
from .object_transform.base import BrainPyObject
from .object_transform.controls import cond
from .object_transform.variables import Variable

__all__ = [
  'AbstractDelay',
  'TimeDelay', 'LengthDelay',
  'NeuTimeDelay', 'NeuLenDelay',
  'ROTATE_UPDATE',
  'CONCAT_UPDATE',
]


def _as_jax_array(arr):
  """Unwrap ``arr`` to its ``.value`` when it is an ``Array``; otherwise return it unchanged."""
  if isinstance(arr, Array):
    return arr.value
  return arr


class AbstractDelay(BrainPyObject):
  """Common base class of the delay variables defined in this module."""


# Tags stored in ``TimeDelay._before_type``: whether the history before ``t0``
# comes from a user callable or from data held in the delay buffer.
_FUNC_BEFORE = 'function'
_DATA_BEFORE = 'data'
# Accepted values of the ``interp_method`` argument of ``TimeDelay``.
_INTERP_LINEAR = 'linear_interp'
_INTERP_ROUND = 'round'


class TimeDelay(AbstractDelay):
  r"""Delay variable which has a fixed delay time length.

  For example, we create a delay variable which has a maximum delay length of 1 ms

  >>> import brainpy.math as bm
  >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1)
  >>> delay(-0.5)
  [-0. -0. -0.]

  This function supports multiple dimensions of the tensor. For example,

  1. the one-dimensional delay data

  >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t)
  >>> delay(-0.2)
  [-0.2 -0.2 -0.2]

  2. the two-dimensional delay data

  >>> delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t)
  >>> delay(-0.6)
  [[-0.6 -0.6]
   [-0.6 -0.6]
   [-0.6 -0.6]]

  3. the three-dimensional delay data

  >>> delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t)
  >>> delay(-0.8)
  [[[-0.8]
    [-0.8]]
   [[-0.8]
    [-0.8]]
   [[-0.8]
    [-0.8]]]

  Parameters
  ----------
  delay_target: ArrayType
    The initial delay data.
  t0: float, int
    The zero time.
  delay_len: float, int
    The maximum delay length.
  dt: float, int
    The time precision.
  before_t0: callable, bm.ndarray, jnp.ndarray, float, int
    The delay data before :math:`t_0`.

    - when `before_t0` is a function, it should receive a time argument `t`
    - when `before_t0` is a tensor, it should be a tensor with shape
      of :math:`(num\_delay, ...)`, where the longest delay data is
      arranged in the first index.
  name: str
    The delay instance name.
  interp_method: str
    The way to deal with the delay at the time which is not integer times of the time step.
    For example, if the time step ``dt=0.1``, the time delay length ``delay\_len=1.``,
    when users require the delay data at ``t-0.53``, we can deal this situation with
    the following methods:

    - ``"linear_interp"``: using linear interpolation to get the delay value
      at the required time (default).
    - ``"round"``: round the time to make it is the integer times of the time step. For
      the above situation, we will use the time at ``t-0.5`` to approximate the delay data
      at ``t-0.53``.

    .. versionadded:: 2.1.1

  See Also
  --------
  LengthDelay
  """

  def __init__(
      self,
      delay_target: Union[ndarray, jnp.ndarray],
      delay_len: Union[float, int],
      before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None,
      t0: Union[float, int] = 0.,
      dt: Union[float, int] = None,
      name: str = None,
      interp_method: str = 'linear_interp',
  ):
    super(TimeDelay, self).__init__(name=name)

    # shape
    if not isinstance(delay_target, (jnp.ndarray, Array)):
      raise ValueError(f'Must be an instance of Array or jax.numpy.ndarray. But we got {type(delay_target)}')

    # delay_len
    self.t0 = t0
    self.dt = get_dt() if dt is None else dt
    is_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.)
    self.delay_len = delay_len
    # one slot per time step covering ``delay_len``, plus one for the current value
    self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1

    # interp method
    if interp_method not in [_INTERP_LINEAR, _INTERP_ROUND]:
      raise UnsupportedError(f'Un-supported interpolation method {interp_method}, '
                             f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
    self.interp_method = interp_method

    # time variables
    self.idx = Variable(jnp.asarray([0]))
    is_float(t0, 't0', allow_none=False, allow_int=True, )
    self.current_time = Variable(jnp.asarray([t0], dtype=get_float()))

    # delay data: the leading axis is the time axis, so a batch axis of the
    # target (if any) is shifted by one
    batch_axis = None
    if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None):
      batch_axis = delay_target.batch_axis + 1
    self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                   dtype=delay_target.dtype),
                         batch_axis=batch_axis)
    if before_t0 is None:
      self._before_type = _DATA_BEFORE
    elif callable(before_t0):
      self._before_t0 = lambda t: as_jax(broadcast_to(before_t0(t), delay_target.shape),
                                         dtype=delay_target.dtype)
      self._before_type = _FUNC_BEFORE
    elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
      self._before_type = _DATA_BEFORE
      self.data[:-1] = before_t0
    else:
      raise ValueError(f'"before_t0" does not support {type(before_t0)}')

    # set initial data
    self.data[-1] = delay_target

    # interpolation function: ``jnp.interp`` is vmapped once per data dimension
    # so that only the leading (time) axis of the samples is interpolated
    self._interp_fun = jnp.interp
    for dim in range(1, delay_target.ndim + 1, 1):
      self._interp_fun = vmap(self._interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)

  def reset(self, delay_target, delay_len, t0: Union[float, int] = 0., before_t0=None):
    """Reset the delay variable.

    Parameters
    ----------
    delay_target: ArrayType
      The delay target.
    delay_len: float, int
      The maximum delay length. The unit is the time.
    t0: int, float
      The zero time.
    before_t0: int, float, ArrayType
      The data before t0.

    Raises
    ------
    ValueError
      If ``before_t0`` is given but is not a number or an array.
    """
    self.delay_len = delay_len
    self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1
    self.data.value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                dtype=delay_target.dtype)
    self.data[-1] = delay_target
    # keep the stored origin in sync: ``__call__`` compares against ``self.t0``
    self.t0 = t0
    self.idx = Variable(jnp.asarray([0]))
    # same float dtype as in ``__init__`` so that an integer ``t0`` does not
    # produce an integer-typed clock
    self.current_time = Variable(jnp.asarray([t0], dtype=get_float()))
    if before_t0 is not None:
      if not isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
        raise ValueError('Only support numerical values.')
      self.data[:-1] = before_t0
      self._before_type = _DATA_BEFORE

  def _check_time1(self, times):
    """Error callback for ``jit_error``: the requested time is after the current time."""
    prev_time, current_time = times
    raise ValueError(f'The request time should be less than the '
                     f'current time {current_time}. But we '
                     f'got {prev_time} > {current_time}')

  def _check_time2(self, times):
    """Error callback for ``jit_error``: the requested time is older than the stored window."""
    prev_time, current_time = times
    raise ValueError(f'The request time of the variable should be in '
                     f'[{current_time - self.delay_len}, {current_time}], '
                     f'but we got {prev_time}')

  def __call__(self, time, indices=None):
    """Return the delayed data at the absolute time ``time``.

    Parameters
    ----------
    time: float, ArrayType
      The time at which the data is requested.
    indices: Any, optional
      If given, the result is additionally indexed with ``indices``.
    """
    # check
    if check.is_checking():
      current_time = self.current_time[0]
      jit_error(time > current_time + 1e-6, self._check_time1, (time, current_time))
      jit_error(time < current_time - self.delay_len - self.dt, self._check_time2, (time, current_time))
    if self._before_type == _FUNC_BEFORE:
      res = cond(time < self.t0,
                 self._before_t0,
                 self._after_t0,
                 time)
    else:
      res = self._after_t0(time)
    if indices is not None:  # TODO: indices is highly inefficient
      res = res[indices]
    return res

  def _after_t0(self, prev_time):
    """Read the value at ``prev_time`` from the buffer, using ``interp_method``."""
    # offset of the requested time measured from ``current_time - delay_len``
    diff = self.delay_len - (self.current_time[0] - prev_time)
    if isinstance(diff, ndarray):
      diff = diff.value
    if self.interp_method == _INTERP_LINEAR:
      req_num_step = jnp.asarray(diff / self.dt, dtype=jnp.int32)
      extra = diff - req_num_step * self.dt
      # exact hit on a stored sample -> direct read, otherwise interpolate
      return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra))
    elif self.interp_method == _INTERP_ROUND:
      req_num_step = jnp.asarray(jnp.round(diff / self.dt), dtype=jnp.int32)
      return self._true_fn(req_num_step, 0.)
    else:
      raise UnsupportedError(f'Un-supported interpolation method {self.interp_method}, '
                             f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

  def _true_fn(self, req_num_step, extra):
    """Return the stored sample ``req_num_step`` slots after the write position."""
    # ``idx + req_num_step`` can run past the end of the buffer once ``idx`` has
    # advanced; wrap it exactly as ``_false_fn`` does (an out-of-range gather
    # index would otherwise not wrap around).
    return self.data[(self.idx[0] + req_num_step) % self.num_delay_step]

  def _false_fn(self, req_num_step, extra):
    """Interpolate between the two stored samples that bracket the requested time."""
    idx = jnp.asarray([self.idx[0] + req_num_step,
                       self.idx[0] + req_num_step + 1])
    idx %= self.num_delay_step
    return self._interp_fun(extra, jnp.asarray([0., self.dt]), self.data[idx])

  def update(self, value):
    """Store ``value`` at the write position, then advance the clock by ``dt`` and the position by one slot."""
    self.data[self.idx[0]] = value
    self.current_time += self.dt
    self.idx.value = (self.idx + 1) % self.num_delay_step
class NeuTimeDelay(TimeDelay):
  """Neutral Time Delay. Alias of :py:class:`~.TimeDelay`."""
# Accepted values of the ``update_method`` argument of ``LengthDelay``.
ROTATE_UPDATE = 'rotation'
CONCAT_UPDATE = 'concat'
class LengthDelay(AbstractDelay):
  """Delay variable which has a fixed delay length.

  Parameters
  ----------
  delay_target: ArrayType
    The initial delay data.
  delay_len: int
    The maximum delay length.
  initial_delay_data: Any
    The delay data. It can be a Python number, like float, int, boolean values.
    It can also be arrays. Or a callable function or instance of ``Connector``.
    Note that ``initial_delay_data`` should be arranged as the following way::

       delay = 1             [ data
       delay = 2               data
       ...                     ....
       ...                     ....
       delay = delay_len-1     data
       delay = delay_len       data ]

    .. versionchanged:: 2.2.3.2

       The data in the previous version of ``LengthDelay`` is::

          delay = delay_len     [ data
          delay = delay_len-1     data
          ...                     ....
          ...                     ....
          delay = 2               data
          delay = 1               data ]

  name: str
    The delay object name.
  batch_axis: int
    The batch axis. If not provided, it will be inferred from the `delay_target`.
  update_method: str
    The method used for updating delay.

  See Also
  --------
  TimeDelay
  """

  def __init__(
      self,
      delay_target: Union[ndarray, jax.Array],
      delay_len: int,
      initial_delay_data: Union[float, int, bool, ndarray, jax.Array, Callable] = None,
      name: str = None,
      batch_axis: int = None,
      update_method: str = ROTATE_UPDATE
  ):
    super(LengthDelay, self).__init__(name=name)

    assert update_method in [ROTATE_UPDATE, CONCAT_UPDATE]
    self.update_method = update_method

    # attributes and variables
    self.data: Variable = None
    self.num_delay_step: int = 0
    self.idx: Variable = None

    # a ``Variable`` target is remembered so ``update()`` can be called without a value
    self.delay_target = None
    if isinstance(delay_target, Variable):
      self.delay_target = delay_target

    # initialization
    self.reset(delay_target, delay_len, initial_delay_data, batch_axis)

  @property
  def delay_shape(self):
    """The data shape of this delay variable."""
    return self.data.shape

  @property
  def delay_target_shape(self):
    """The data shape of the delay target."""
    return self.data.shape[1:]

  def __repr__(self):
    name = self.__class__.__name__
    return (f'{name}(num_delay_step={self.num_delay_step}, '
            f'delay_target_shape={self.delay_target_shape}, '
            f'update_method={self.update_method})')

  def reset(
      self,
      delay_target,
      delay_len: int = None,
      initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None,
      batch_axis: int = None
  ):
    """(Re-)initialize the delay buffer.

    Parameters
    ----------
    delay_target: ArrayType
      The data stored at delay 0.
    delay_len: int, optional
      The maximum delay length. When omitted, the length of the previous
      initialization is reused.
    initial_delay_data: Any, optional
      Data for delays ``1 .. delay_len``: a number, an array, or a callable
      invoked as ``f(shape, dtype=...)``.
    batch_axis: int, optional
      The batch axis of the buffer; only used when the buffer is first created.

    Raises
    ------
    ValueError
      If ``delay_target`` is not an array, if ``delay_len`` is omitted before
      any previous initialization, or if ``initial_delay_data`` has an
      unsupported type.
    """
    if not isinstance(delay_target, (ndarray, jnp.ndarray)):
      raise ValueError(f'Must be an instance of brainpy.math.ndarray '
                       f'or jax.numpy.ndarray. But we got {type(delay_target)}')

    # delay_len
    is_integer(delay_len, 'delay_len', allow_none=True, min_bound=0)
    if delay_len is None:
      # ``num_delay_step`` starts at 0 (not None) and is >= 1 after any
      # successful reset, so a falsy value means "never initialized".
      if not self.num_delay_step:
        raise ValueError('"delay_len" cannot be None.')
      delay_len = self.num_delay_step - 1
    self.num_delay_step = delay_len + 1

    # initialize delay data
    if self.data is None:
      if batch_axis is None:
        if isinstance(delay_target, Variable) and (delay_target.batch_axis is not None):
          batch_axis = delay_target.batch_axis + 1
      self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                     dtype=delay_target.dtype),
                           batch_axis=batch_axis)
    else:
      # NOTE(review): the bare ``.value`` access looks intentional (presumably it
      # lets the Variable machinery see the access before ``_value`` is
      # overwritten directly, bypassing the setter) - confirm before removing.
      self.data.value
      self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
                                   dtype=delay_target.dtype)

    # update delay data
    self.data[0] = delay_target
    if initial_delay_data is None:
      pass
    elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)):
      self.data[1:] = initial_delay_data
    elif callable(initial_delay_data):
      self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape,
                                         dtype=delay_target.dtype)
    else:
      raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}')

    # time variables (only the rotation method keeps a moving index)
    if self.update_method == ROTATE_UPDATE:
      if self.idx is None:
        self.idx = Variable(stop_gradient(jnp.asarray([0], dtype=jnp.int32)))
      else:
        self.idx.value = stop_gradient(jnp.asarray([0], dtype=jnp.int32))

  def _check_delay(self, delay_len):
    """Error callback for ``jit_error``: the requested delay exceeds the buffer."""
    raise ValueError(f'The request delay length should be less than the '
                     f'maximum delay {self.num_delay_step}. '
                     f'But we got {delay_len}')

  def __call__(self, delay_len, *indices):
    """Shortcut for :py:meth:`retrieve`."""
    return self.retrieve(delay_len, *indices)

  def retrieve(self, delay_len, *indices):
    """Retrieve the delay data according to the delay length.

    Parameters
    ----------
    delay_len: int, ArrayType
      The delay length used to retrieve the data.
    *indices
      Extra indices applied to the data axes after the delay axis.

    Raises
    ------
    ValueError
      If the computed delay index has a non-integer dtype, or the update
      method is unknown.
    """
    if check.is_checking():
      jit_error(jnp.any(as_jax(delay_len >= self.num_delay_step)), self._check_delay, delay_len)

    if self.update_method == ROTATE_UPDATE:
      # slot ``idx`` holds delay 0; larger delays follow, wrapping around
      delay_idx = (self.idx[0] + delay_len) % self.num_delay_step
      delay_idx = stop_gradient(delay_idx)

    elif self.update_method == CONCAT_UPDATE:
      delay_idx = delay_len

    else:
      raise ValueError(f'Unknown updating method "{self.update_method}"')

    # the delay index
    if isinstance(delay_idx, int):
      pass
    elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
      raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
    indices = (delay_idx,) + tuple(indices)
    # the delay data
    return self.data[indices]

  def update(self, value: Union[numbers.Number, Array, jax.Array] = None):
    """Update delay variable with the new data.

    Parameters
    ----------
    value: Any
      The value of the latest data, used to update this delay variable.
      When omitted, the value of the ``Variable`` given as ``delay_target``
      at construction is used.

    Raises
    ------
    ValueError
      If no value is given and no ``Variable`` target was stored, or the
      update method is unknown.
    """
    if value is None:
      if self.delay_target is None:
        raise ValueError('Must provide value.')
      else:
        value = self.delay_target.value

    if self.update_method == ROTATE_UPDATE:
      # step the index backwards and overwrite that slot with the newest value
      self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.num_delay_step))
      self.data[self.idx[0]] = value

    elif self.update_method == CONCAT_UPDATE:
      if self.num_delay_step >= 2:
        # prepend the newest value and drop the last entry
        self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0)
      else:
        self.data[:] = value

    else:
      raise ValueError(f'Unknown updating method "{self.update_method}"')
class NeuLenDelay(LengthDelay):
  """Neutral Length Delay. Alias of :py:class:`~.LengthDelay`."""