Source code for brainpy.dyn.others.input

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from functools import partial
from typing import Union, Sequence, Any, Optional, Callable

import jax
import jax.numpy as jnp

from brainpy import math as bm
from brainpy.context import share
from brainpy.dyn.base import NeuDyn
from brainpy.dyn.utils import get_spk_type
from brainpy.initialize import parameter, variable_
from brainpy.mixin import ReturnInfo
from brainpy.types import Shape, ArrayType

__all__ = [
    'InputGroup',
    'OutputGroup',
    'SpikeTimeGroup',
    'PoissonGroup',
]


class InputGroup(NeuDyn):
    """Placeholder neuron group that forwards its input unchanged.

    Parameters
    ----------
    size : int or sequence of int
        Geometry of the group.
    sharding : Any, optional
        Sharding specification forwarded to :class:`NeuDyn`.
    keep_size : bool
        Forwarded to :class:`NeuDyn`.
    mode : Mode, optional
        Computing mode forwarded to :class:`NeuDyn`.
    name : str, optional
        Name of this dynamical system.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        sharding: Any = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(
            size=size,
            keep_size=keep_size,
            mode=mode,
            sharding=sharding,
            name=name,
        )

    def update(self, x):
        """Identity step: hand ``x`` straight back to the caller."""
        passthrough = x
        return passthrough

    def return_info(self):
        """Describe the output as ``varshape``-shaped, zero-initialized data."""
        return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)

    def reset_state(self, batch_or_mode=None, **kwargs):
        """No-op: this group holds no state to reset."""
class OutputGroup(NeuDyn):
    """Placeholder neuron group that returns whatever it receives.

    Parameters
    ----------
    size : int or sequence of int
        Geometry of the group.
    sharding : Any, optional
        Sharding specification forwarded to :class:`NeuDyn`.
    keep_size : bool
        Forwarded to :class:`NeuDyn`.
    mode : Mode, optional
        Computing mode forwarded to :class:`NeuDyn`.
    name : str, optional
        Name of this dynamical system.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        sharding: Any = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(
            size=size,
            keep_size=keep_size,
            mode=mode,
            sharding=sharding,
            name=name,
        )

    def update(self, x):
        """Identity step: hand ``x`` straight back to the caller."""
        passthrough = x
        return passthrough

    def return_info(self):
        """Describe the output as ``varshape``-shaped, zero-initialized data."""
        return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)

    def reset_state(self, batch_size=None, **kwargs):
        """No-op: this group holds no state to reset."""
class SpikeTimeGroup(NeuDyn):
    """The input neuron group characterized by spikes emitting at given times.

    ``times[k]`` and ``indices[k]`` together describe one spike event: neuron
    ``indices[k]`` fires once the shared clock ``share['t']`` reaches
    ``times[k]``. Every event whose time has been reached is emitted in the
    same ``update`` call.

    >>> # Get 2 neurons; neuron 0 fires at 10 ms and neuron 1 fires at 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 1])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters::

    size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron indices at each time point to emit spikes.
    times : list, tuple, ArrayType
        The time points which generate the spikes.
    spk_type : type, optional
        The spike dtype, resolved through ``get_spk_type`` with the mode.
    name : str, optional
        The name of the dynamic system.
    sharding : sequence of str, optional
        Axis names forwarded to :class:`NeuDyn` and used for the spike variable.
    keep_size : bool
        Must be ``False``; ``True`` is not supported.
    mode : Mode, optional
        The computing mode.
    need_sort : bool
        Sort ``times`` ascending (permuting ``indices`` alongside). Only pass
        ``False`` when ``times`` is already sorted: ``update`` consumes events
        in order and stops at the first one whose time has not been reached.

    Raises::

    NotImplementedError
        If ``keep_size`` is ``True``.
    ValueError
        If ``indices`` and ``times`` have different lengths.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        indices: Union[Sequence, ArrayType],
        times: Union[Sequence, ArrayType],
        spk_type: Optional[type] = None,
        name: Optional[str] = None,
        sharding: Optional[Sequence[str]] = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        need_sort: bool = True,
    ):
        super().__init__(size=size, sharding=sharding, name=name, keep_size=keep_size, mode=mode)

        # parameters
        if keep_size:
            raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}')
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = get_spk_type(spk_type, self.mode)

        # data about times and indices
        self.times = bm.asarray(times)
        self.indices = bm.asarray(indices, dtype=bm.int_)
        if need_sort:
            # ``update`` walks events in order, so times must be ascending;
            # indices is permuted with the same order to keep the pairs intact.
            sort_idx = jnp.argsort(self.times.value)
            self.indices.value = self.indices[sort_idx]
            self.times.value = self.times[sort_idx]

        # variables
        self.reset_state(self.mode)

    def reset_state(self, batch_size=None, **kwargs):
        """Rewind to the first event and recreate a zeroed spike variable."""
        # Position of the next not-yet-emitted event in ``times``/``indices``.
        self.i = bm.Variable(bm.asarray(0))
        self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type),
                               self.varshape,
                               batch_size,
                               axis_names=self.sharding,
                               batch_axis_name=bm.sharding.BATCH_AXIS)

    def update(self):
        """Emit every pending event with ``times[i] <= share['t']``.

        Returns:
            The spike array for the current step.
        """
        self.spike.value = bm.zeros_like(self.spike)
        # With no events there is nothing to emit; skipping the loop also keeps
        # ``_cond_fun`` from dynamically indexing into an empty ``times``.
        if self.num_times > 0:
            bm.while_loop(self._body_fun, self._cond_fun, ())
        return self.spike.value

    def return_info(self):
        return self.spike

    # functions
    def _cond_fun(self):
        """Continue while an event remains and its time has been reached."""
        i = self.i.value
        # ``logical_and`` evaluates both operands, so ``self.times[i]`` is also
        # read when ``i == num_times``; that out-of-range gather is clamped by
        # JAX and the left operand already makes the result False.
        return bm.logical_and(i < self.num_times, share['t'] >= self.times[i])

    def _body_fun(self):
        """Mark the neuron of event ``i`` as spiking and advance to the next event."""
        i = self.i.value
        if isinstance(self.mode, bm.BatchingMode):
            # The spike variable has a leading batch axis in batching mode.
            self.spike[:, self.indices[i]] = True
        else:
            self.spike[self.indices[i]] = True
        self.i += 1
class PoissonGroup(NeuDyn):
    """Poisson Neuron Group.

    At every ``update`` each neuron independently emits a spike with
    probability ``freqs * share['dt'] / 1000``. With ``dt`` in milliseconds
    this makes ``freqs`` a firing rate in Hz.

    Parameters::

    size : Shape
        The neuron group geometry.
    freqs : int, float, jax.Array, bm.Array, Callable
        Firing rates. Scalars are shared by all neurons. Arrays and initializer
        callables are resolved against the group's ``varshape``, which is the
        shape of the spike variable they are compared with.
    keep_size : bool
        Forwarded to :class:`NeuDyn`.
    sharding : sequence of str, optional
        Axis names forwarded to :class:`NeuDyn`.
    spk_type : type, optional
        The spike dtype, resolved through ``get_spk_type`` with the mode.
    name : str, optional
        The name of the dynamic system.
    mode : Mode, optional
        The computing mode.
    """

    def __init__(
        self,
        size: Shape,
        freqs: Union[int, float, jax.Array, bm.Array, Callable],
        keep_size: bool = False,
        sharding: Optional[Sequence[str]] = None,
        spk_type: Optional[type] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(size=size, sharding=sharding, name=name, keep_size=keep_size, mode=mode)

        # parameters
        # Resolve against ``varshape`` (not the flat ``num``). ``update`` compares
        # against a draw shaped like the spike variable, so a flat ``(num,)``
        # rate would not broadcast when ``keep_size=True`` with a multi-dim size.
        self.freqs = parameter(freqs, self.varshape, allow_none=False)
        self.spk_type = get_spk_type(spk_type, self.mode)

        # variables
        self.reset_state(self.mode)

    def update(self):
        """Draw this step's Poisson spikes, store them, and return them."""
        spikes = bm.random.rand_like(self.spike.value) <= (self.freqs * share['dt'] / 1000.)
        spikes = bm.asarray(spikes, dtype=self.spk_type)
        self.spike.value = spikes
        return spikes

    def return_info(self):
        return self.spike

    def reset_state(self, batch_or_mode=None, **kwargs):
        """Recreate the spike variable filled with zeros of ``spk_type``."""
        self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode)