# -*- coding: utf-8 -*-

import numbers
from typing import Optional, Union, Sequence, Tuple, Callable

import jax.numpy as jnp
from jax import vmap

import brainpy.math as bm
from brainpy._src.dynsys import Projection
from brainpy._src.initialize import Initializer
from brainpy.check import is_sequence
from brainpy.types import ArrayType

class DelayCoupling(Projection):
  """Delay coupling.

  Base class that validates the coupling inputs (delay variable, output
  variables, connection matrix and delay steps) and registers a delay for
  ``delay_var``.

  Parameters
  ----------
  delay_var: Variable
    The delay variable.
  var_to_output: Variable, sequence of Variable
    The target variables to output.
  conn_mat: ArrayType
    The connection matrix.
  required_shape: sequence of int
    The required shape of `(pre, post)`.
  delay_steps: int, ArrayType, Callable
    The matrix of delay time steps. Must be int. A callable is invoked
    with ``required_shape`` and must return an integer-typed array.
  initial_delay_data: Initializer, Callable
    The initializer of the initial delay data.
  name: str
    The name of the model. Forwarded to the parent class.
  mode: Mode
    Forwarded to the parent class.

  Raises
  ------
  ValueError
    If ``delay_var`` is not a ``brainpy.math.Variable``, if ``conn_mat`` or
    an array-valued ``delay_steps`` does not have ``required_shape``, if
    ``delay_steps`` is not integer typed, or if its type is not supported.
  """

  def __init__(
      self,
      delay_var: bm.Variable,
      var_to_output: Union[bm.Variable, Sequence[bm.Variable]],
      conn_mat: ArrayType,
      required_shape: Tuple[int, ...],
      delay_steps: Optional[Union[int, ArrayType, Callable]] = None,
      initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None,
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,
  ):
    super().__init__(name=name, mode=mode)

    # delay variable
    if not isinstance(delay_var, bm.Variable):
      raise ValueError(f'"delay_var" must be an instance of brainpy.math.Variable. '
                       f'But we got {type(delay_var)}')
    self.delay_var = delay_var

    # output variables: a single Variable is normalized to a one-element list
    if isinstance(var_to_output, bm.Variable):
      var_to_output = [var_to_output]
    is_sequence(var_to_output, 'output_var', elem_type=bm.Variable, allow_none=False)
    self.output_var = var_to_output

    # Connection matrix
    self.conn_mat = conn_mat
    if self.conn_mat.shape != required_shape:
      raise ValueError(f'we expect the structural connection matrix has the shape of '
                       f'(pre.num, post.num), i.e., {required_shape}, '
                       f'while we got {self.conn_mat.shape}.')

    # Delay matrix. ``delay_type`` records which retrieval strategy the
    # subclasses' ``update()`` must use: 'none', 'int' or 'array'.
    if delay_steps is None:
      self.delay_steps = None
      self.delay_type = 'none'
      num_delay_step = None
    elif callable(delay_steps):
      # build the delay matrix from the initializer
      delay_steps = delay_steps(required_shape)
      if delay_steps.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
        raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}')
      self.delay_steps = delay_steps
      self.delay_type = 'array'
      num_delay_step = self.delay_steps.max()
    elif isinstance(delay_steps, (bm.Array, jnp.ndarray)):
      if delay_steps.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
        raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}')
      if delay_steps.ndim == 0:
        # a 0-d array is a single homogeneous delay
        self.delay_type = 'int'
      else:
        self.delay_type = 'array'
        if delay_steps.shape != required_shape:
          raise ValueError(f'we expect the delay matrix has the shape of '
                           f'(pre.num, post.num), i.e., {required_shape}. '
                           f'While we got {delay_steps.shape}.')
      self.delay_steps = delay_steps
      num_delay_step = self.delay_steps.max()
    elif isinstance(delay_steps, int):
      self.delay_steps = delay_steps
      num_delay_step = delay_steps
      self.delay_type = 'int'
    else:
      raise ValueError(f'Unknown type of delay steps: {type(delay_steps)}')

    # delay variables
    # NOTE(review): the original call was truncated after its first argument;
    # the keyword arguments below follow the public ``register_delay`` API
    # (identifier, delay_step, delay_target, initial_delay_data) - confirm.
    _ = self.register_delay(f'delay_{id(delay_var)}',
                            delay_step=num_delay_step,
                            delay_target=delay_var,
                            initial_delay_data=initial_delay_data)

  def reset_state(self, batch_size=None):
    # NOTE(review): the original body is not visible; presumably there is no
    # state of its own to reset here - confirm against upstream.
    pass

class DiffusiveCoupling(DelayCoupling):
  """Diffusive coupling.

  This class simulates the model of::

     coupling = g * (delayed_coupling_var1 - coupling_var2)
     target_var += coupling

  In ``update()`` the connection matrix ``conn_mat`` plays the role of ``g``:
  the pairwise differences are multiplied by ``conn_mat`` and summed over the
  ``pre`` axis before being added to every output variable.

  Examples
  --------

  >>> import brainpy as bp
  >>> from brainpy import rates
  >>> areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
  >>> conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input,
  ...                                      conn_mat=Cmat, delay_steps=Dmat,
  ...                                      initial_delay_data=bp.init.Uniform(0, 0.05))
  >>> net = bp.Network(areas, conn)

  Parameters
  ----------
  coupling_var1: Variable
    The first coupling variable, used for delay.
  coupling_var2: Variable
    Another coupling variable.
  var_to_output: Variable, sequence of Variable
    The target variables to output.
  conn_mat: ArrayType
    The connection matrix.
  delay_steps: int, ArrayType
    The matrix of delay time steps. Must be int.
  initial_delay_data: Initializer, Callable
    The initializer of the initial delay data.
  name: str
    The name of the model.
  mode: Mode
    Forwarded to the parent class.
  """

  def __init__(
      self,
      coupling_var1: bm.Variable,
      coupling_var2: bm.Variable,
      var_to_output: Union[bm.Variable, Sequence[bm.Variable]],
      conn_mat: ArrayType,
      delay_steps: Optional[Union[int, ArrayType, Initializer, Callable]] = None,
      initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None,
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,
  ):
    # Both coupling variables must be 1-D brainpy Variables; their sizes
    # define the required (pre, post) shape checked by the parent class.
    if not isinstance(coupling_var1, bm.Variable):
      raise ValueError(f'"coupling_var1" must be an instance of brainpy.math.Variable. '
                       f'But we got {type(coupling_var1)}')
    if not isinstance(coupling_var2, bm.Variable):
      raise ValueError(f'"coupling_var2" must be an instance of brainpy.math.Variable. '
                       f'But we got {type(coupling_var2)}')
    if jnp.ndim(coupling_var1) != 1:
      raise ValueError(f'Only support 1d vector of coupling variable. '
                       f'But we got {jnp.ndim(coupling_var1)}')
    if jnp.ndim(coupling_var2) != 1:
      raise ValueError(f'Only support 1d vector of coupling variable. '
                       f'But we got {jnp.ndim(coupling_var2)}')

    super().__init__(
      delay_var=coupling_var1,
      var_to_output=var_to_output,
      conn_mat=conn_mat,
      required_shape=(coupling_var1.size, coupling_var2.size),
      delay_steps=delay_steps,
      initial_delay_data=initial_delay_data,
      name=name,
      mode=mode,
    )

    self.coupling_var1 = coupling_var1
    self.coupling_var2 = coupling_var2

  def update(self):
    """Compute the diffusive coupling and add it to every output variable.

    Raises
    ------
    ValueError
      If ``self.delay_type`` is not one of the supported delay types.
    """
    # delays
    axis = self.coupling_var1.ndim
    delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}')
    if self.delay_steps is None:
      # no delay: use the current value of the first coupling variable
      diffusive = (jnp.expand_dims(self.coupling_var1.value, axis=axis) -
                   jnp.expand_dims(self.coupling_var2.value, axis=axis - 1))
    elif self.delay_type == 'array':
      # heterogeneous delays: retrieve one delayed vector per column of
      # the delay matrix
      if isinstance(self.mode, bm.TrainingMode):
        indices = (slice(None, None, None), jnp.arange(self.coupling_var1.size),)
      else:
        indices = (jnp.arange(self.coupling_var1.size),)
      f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1)  # (..., pre.num)
      delays = f(self.delay_steps)  # (..., post.num, pre.num)
      diffusive = (jnp.moveaxis(bm.as_jax(delays), axis - 1, axis) -
                   jnp.expand_dims(self.coupling_var2.value, axis=axis - 1))  # (..., pre.num, post.num)
    elif self.delay_type == 'int':
      # homogeneous delay: a single retrieval for all connections
      delayed_data = delay_var.retrieve(self.delay_steps)  # (..., pre.num)
      diffusive = (jnp.expand_dims(delayed_data, axis=axis) -
                   jnp.expand_dims(self.coupling_var2.value, axis=axis - 1))  # (..., pre.num, post.num)
    else:
      raise ValueError(f'Unknown delay type {self.delay_type}')

    # weight by the connection matrix and sum over the ``pre`` axis
    diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)

    # output to target variable
    for target in self.output_var:
      target.value += diffusive
class AdditiveCoupling(DelayCoupling):
  """Additive coupling.

  This class simulates the model of::

     coupling = g * delayed_coupling_var
     target_var += coupling

  In ``update()`` the connection matrix ``conn_mat`` plays the role of ``g``.

  Parameters
  ----------
  coupling_var: Variable
    The coupling variable, used for delay.
  var_to_output: Variable, sequence of Variable
    The target variables to output.
  conn_mat: ArrayType
    The connection matrix.
  delay_steps: int, ArrayType
    The matrix of delay time steps. Must be int.
  initial_delay_data: Initializer, Callable
    The initializer of the initial delay data.
  name: str
    The name of the model.
  mode: Mode
    Forwarded to the parent class.
  """

  def __init__(
      self,
      coupling_var: bm.Variable,
      var_to_output: Union[bm.Variable, Sequence[bm.Variable]],
      conn_mat: ArrayType,
      delay_steps: Optional[Union[int, ArrayType, Initializer, Callable]] = None,
      initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None,
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,
  ):
    # The coupling variable must be a 1-D brainpy Variable; its size is used
    # for both dimensions of the required (pre, post) shape.
    if not isinstance(coupling_var, bm.Variable):
      raise ValueError(f'"coupling_var" must be an instance of brainpy.math.Variable. '
                       f'But we got {type(coupling_var)}')
    if jnp.ndim(coupling_var) != 1:
      raise ValueError(f'Only support 1d vector of coupling variable. '
                       f'But we got {jnp.ndim(coupling_var)}')

    super().__init__(
      delay_var=coupling_var,
      var_to_output=var_to_output,
      conn_mat=conn_mat,
      required_shape=(coupling_var.size, coupling_var.size),
      delay_steps=delay_steps,
      initial_delay_data=initial_delay_data,
      name=name,
      mode=mode,
    )

    self.coupling_var = coupling_var

  def update(self):
    """Compute the additive coupling and add it to every output variable.

    Raises
    ------
    ValueError
      If ``self.delay_type`` is not one of the supported delay types.
    """
    # delay function
    axis = self.coupling_var.ndim
    delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}')
    if self.delay_steps is None:
      # no delay: project the current value through the connection matrix
      additive = self.coupling_var @ self.conn_mat
    elif self.delay_type == 'array':
      # heterogeneous delays: retrieve one delayed vector per column of
      # the delay matrix
      if isinstance(self.mode, bm.TrainingMode):
        indices = (slice(None, None, None), jnp.arange(self.coupling_var.size),)
      else:
        indices = (jnp.arange(self.coupling_var.size),)
      f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1)  # (.., pre.num,)
      delays = f(self.delay_steps)  # (..., post.num, pre.num)
      # convert to a JAX array first, as DiffusiveCoupling.update does
      additive = (self.conn_mat * jnp.moveaxis(bm.as_jax(delays), axis - 1, axis)).sum(axis=axis - 1)
    elif self.delay_type == 'int':
      # homogeneous delay: a single retrieval for all connections
      delayed_var = delay_var.retrieve(self.delay_steps)  # (..., pre.num)
      additive = delayed_var @ self.conn_mat
    else:
      raise ValueError(f'Unknown delay type {self.delay_type}')

    # output to target variable
    for target in self.output_var:
      target.value += additive