Source code for brainpy._src.dyn.others.input

# -*- coding: utf-8 -*-
import warnings
from functools import partial
from typing import Union, Sequence, Any, Optional, Callable

import jax
import jax.numpy as jnp

from brainpy import math as bm
from brainpy._src.context import share
from brainpy._src.dyn.utils import get_spk_type
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import ReturnInfo
from brainpy.types import Shape, ArrayType

__all__ = [
  'InputGroup',
  'OutputGroup',
  'SpikeTimeGroup',
  'PoissonGroup',
]


[docs] class InputGroup(NeuDyn): """Input neuron group for place holder. Args: size: int, tuple of int keep_size: bool mode: Mode name: str """ def __init__( self, size: Union[int, Sequence[int]], sharding: Any = None, keep_size: bool = False, mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): super().__init__(name=name, sharding=sharding, size=size, keep_size=keep_size, mode=mode)
[docs] def update(self, x): return x
def return_info(self): return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros) def reset_state(self, batch_or_mode=None, **kwargs): pass
[docs] class OutputGroup(NeuDyn): """Output neuron group for place holder. Args: size: int, tuple of int keep_size: bool mode: Mode name: str """ def __init__( self, size: Union[int, Sequence[int]], sharding: Any = None, keep_size: bool = False, mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): super().__init__(name=name, sharding=sharding, size=size, keep_size=keep_size, mode=mode)
[docs] def update(self, x): return x
def return_info(self): return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros) def reset_state(self, batch_size=None, **kwargs): pass
[docs] class SpikeTimeGroup(NeuDyn): """The input neuron group characterized by spikes emitting at given times. >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms. >>> SpikeTimeGroup(2, times=[10, 20]) >>> # or >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms. >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0]) >>> # or >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms. >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0]) >>> # or >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire; >>> # at 30 ms, neuron 1 fires. >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) Parameters ---------- size : int, tuple, list The neuron group geometry. indices : list, tuple, ArrayType The neuron indices at each time point to emit spikes. times : list, tuple, ArrayType The time points which generate the spikes. name : str, optional The name of the dynamic system. """ def __init__( self, size: Union[int, Sequence[int]], indices: Union[Sequence, ArrayType], times: Union[Sequence, ArrayType], spk_type: Optional[type] = None, name: Optional[str] = None, sharding: Optional[Sequence[str]] = None, keep_size: bool = False, mode: Optional[bm.Mode] = None, need_sort: bool = True, ): super().__init__(size=size, sharding=sharding, name=name, keep_size=keep_size, mode=mode) # parameters if keep_size: raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}') if len(indices) != len(times): raise ValueError(f'The length of "indices" and "times" must be the same. ' f'However, we got {len(indices)} != {len(times)}.') self.num_times = len(times) self.spk_type = get_spk_type(spk_type, self.mode) # data about times and indices self.times = bm.asarray(times) self.indices = bm.asarray(indices, dtype=bm.int_) if need_sort: sort_idx = bm.argsort(self.times) self.indices.value = self.indices[sort_idx] self.times.value = self.times[sort_idx] # variables self.reset_state(self.mode) def reset_state(self, batch_size=None, **kwargs): self.i = bm.Variable(bm.asarray(0)) self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type), self.varshape, batch_size, axis_names=self.sharding, batch_axis_name=bm.sharding.BATCH_AXIS)
[docs] def update(self): # self.spike.value = bm.sharding.partition(bm.zeros_like(self.spike), self.spike.sharding) self.spike.value = bm.zeros_like(self.spike) bm.while_loop(self._body_fun, self._cond_fun, ()) return self.spike.value
def return_info(self): return self.spike # functions def _cond_fun(self): i = self.i.value return bm.logical_and(i < self.num_times, share['t'] >= self.times[i]) def _body_fun(self): i = self.i.value if isinstance(self.mode, bm.BatchingMode): self.spike[:, self.indices[i]] = True else: self.spike[self.indices[i]] = True self.i += 1
[docs] class PoissonGroup(NeuDyn): """Poisson Neuron Group. """ def __init__( self, size: Shape, freqs: Union[int, float, jax.Array, bm.Array, Callable], keep_size: bool = False, sharding: Optional[Sequence[str]] = None, spk_type: Optional[type] = None, name: Optional[str] = None, mode: Optional[bm.Mode] = None, seed=None, ): super().__init__(size=size, sharding=sharding, name=name, keep_size=keep_size, mode=mode) if seed is not None: warnings.warn('') # parameters self.freqs = parameter(freqs, self.num, allow_none=False) self.spk_type = get_spk_type(spk_type, self.mode) # variables self.reset_state(self.mode)
[docs] def update(self): spikes = bm.random.rand_like(self.spike) <= (self.freqs * share['dt'] / 1000.) spikes = bm.asarray(spikes, dtype=self.spk_type) # spikes = bm.sharding.partition(spikes, self.spike.sharding) self.spike.value = spikes return spikes
def return_info(self): return self.spike def reset_state(self, batch_or_mode=None, **kwargs): self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode)