# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from functools import partial
from typing import Union, Sequence, Any, Optional, Callable
import jax
import jax.numpy as jnp
from brainpy import math as bm
from brainpy.context import share
from brainpy.dyn.base import NeuDyn
from brainpy.dyn.utils import get_spk_type
from brainpy.initialize import parameter, variable_
from brainpy.mixin import ReturnInfo
from brainpy.types import Shape, ArrayType
__all__ = [
'InputGroup',
'OutputGroup',
'SpikeTimeGroup',
'PoissonGroup',
]
class OutputGroup(NeuDyn):
    """Placeholder output neuron group that hands its input straight back.

    The group defines no parameters and no state of its own; every
    constructor argument is forwarded to ``NeuDyn``.

    Args:
      size: int, tuple of int. Geometry of the group.
      sharding: Any. Sharding specification, forwarded to ``NeuDyn``.
      keep_size: bool. Forwarded to ``NeuDyn``.
      mode: Mode. Forwarded to ``NeuDyn``.
      name: str. Forwarded to ``NeuDyn``.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        sharding: Any = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        # Nothing to configure locally: collect the options and let the
        # base class handle all of them.
        base_options = dict(
            size=size,
            keep_size=keep_size,
            sharding=sharding,
            mode=mode,
            name=name,
        )
        super().__init__(**base_options)

    def update(self, x):
        """Return the input ``x`` unchanged."""
        return x

    def return_info(self):
        """Build a ``ReturnInfo`` from this group's ``varshape``, ``sharding``
        and ``mode``, with ``bm.zeros`` as the initializer."""
        info = ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)
        return info

    def reset_state(self, batch_size=None, **kwargs):
        """No-op: the group holds no state to reset."""
        return None
class SpikeTimeGroup(NeuDyn):
    """Input neuron group whose spikes are emitted at prescribed times.

    ``times[k]`` and ``indices[k]`` together describe one spike event:
    neuron ``indices[k]`` is marked as spiking once the shared time ``t``
    has reached ``times[k]``. Both sequences must have the same length.

    >>> # 2 neurons; neuron 0 fires at 10 ms, neuron 1 fires at 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 1])
    >>> # or
    >>> # 2 neurons; neuron 0 fires at 10 ms and 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # 2 neurons; neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # 2 neurons; at 10 ms neuron 0 fires; at 20 ms neurons 0 and 1 fire;
    >>> # at 30 ms neuron 1 fires.
    >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters::

    size : int, tuple, list
      The neuron group geometry.
    indices : list, tuple, ArrayType
      The neuron index of each spike event.
    times : list, tuple, ArrayType
      The time point of each spike event.
    spk_type : type, optional
      Requested spike dtype; resolved together with ``mode`` through
      ``get_spk_type``.
    name : str, optional
      The name of the dynamic system.
    sharding : sequence of str, optional
      Axis names, used when allocating the spike variable.
    keep_size : bool
      Must be ``False``; ``True`` raises ``NotImplementedError``.
    mode : Mode, optional
      The computing mode.
    need_sort : bool
      If ``True`` (default), the events are sorted by time at construction.
      If ``False``, the events are consumed in the order given.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        indices: Union[Sequence, ArrayType],
        times: Union[Sequence, ArrayType],
        spk_type: Optional[type] = None,
        name: Optional[str] = None,
        sharding: Optional[Sequence[str]] = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        need_sort: bool = True,
    ):
        super().__init__(
            size=size,
            sharding=sharding,
            name=name,
            keep_size=keep_size,
            mode=mode,
        )

        # -- argument checks
        if keep_size:
            raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}')
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')

        # -- parameters
        self.num_times = len(times)
        self.spk_type = get_spk_type(spk_type, self.mode)

        # -- event table: one (time, neuron index) pair per event
        self.times = bm.asarray(times)
        self.indices = bm.asarray(indices, dtype=bm.int_)
        if need_sort:
            # The order is computed from the unsorted times, then applied to
            # ``indices`` first so both arrays are permuted identically.
            order = jnp.argsort(self.times.value)
            self.indices.value = self.indices[order]
            self.times.value = self.times[order]

        # -- variables
        self.reset_state(self.mode)

    def reset_state(self, batch_size=None, **kwargs):
        """(Re)create the event cursor ``i`` and the ``spike`` variable.

        ``batch_size`` is handed to ``variable_`` as-is; the constructor
        passes ``self.mode`` here.
        """
        # Cursor into ``times``/``indices``: index of the next pending event.
        self.i = bm.Variable(bm.asarray(0))
        zeros_init = partial(jnp.zeros, dtype=self.spk_type)
        self.spike = variable_(
            zeros_init,
            self.varshape,
            batch_size,
            axis_names=self.sharding,
            batch_axis_name=bm.sharding.BATCH_AXIS,
        )

    def update(self):
        """Clear the spike buffer, emit every event that is now due, and
        return the resulting spike array."""
        self.spike.value = bm.zeros_like(self.spike)
        # Consume pending events one at a time until the next one lies in
        # the future or no events remain.
        bm.while_loop(self._body_fun, self._cond_fun, ())
        return self.spike.value

    def return_info(self):
        """Return the ``spike`` variable as the output description."""
        return self.spike

    # functions
    def _cond_fun(self):
        """Loop condition: events remain and the next one is due at the
        current shared time ``t``."""
        cursor = self.i.value
        has_pending = cursor < self.num_times
        is_due = share['t'] >= self.times[cursor]
        return bm.logical_and(has_pending, is_due)

    def _body_fun(self):
        """Loop body: mark the neuron of the current event as spiking and
        advance the cursor."""
        cursor = self.i.value
        neuron = self.indices[cursor]
        # In batching mode the spike variable carries a leading batch axis,
        # so the neuron is set across the whole of that axis.
        if isinstance(self.mode, bm.BatchingMode):
            target = (slice(None), neuron)
        else:
            target = neuron
        self.spike[target] = True
        self.i += 1
class PoissonGroup(NeuDyn):
    """Poisson neuron group.

    On every ``update`` each neuron spikes when a fresh uniform random
    sample is ``<= freqs * dt / 1000``, with ``dt`` read from the shared
    context.

    Args:
      size: Shape. Geometry of the group.
      freqs: int, float, array or callable. Firing rate(s), processed by
        ``parameter`` against the group's variable shape; ``None`` is
        rejected. NOTE(review): the ``/ 1000`` factor suggests Hz with
        ``dt`` in ms -- confirm.
      keep_size: bool. Forwarded to ``NeuDyn``.
      sharding: sequence of str, optional. Forwarded to ``NeuDyn``.
      spk_type: type, optional. Requested spike dtype; resolved together
        with ``mode`` through ``get_spk_type``.
      name: str, optional. Forwarded to ``NeuDyn``.
      mode: Mode, optional. Forwarded to ``NeuDyn``.
    """

    def __init__(
        self,
        size: Shape,
        freqs: Union[int, float, jax.Array, bm.Array, Callable],
        keep_size: bool = False,
        sharding: Optional[Sequence[str]] = None,
        spk_type: Optional[type] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(size=size,
                         sharding=sharding,
                         name=name,
                         keep_size=keep_size,
                         mode=mode)

        # parameters
        # Initialize against ``varshape`` rather than the flat ``num`` so
        # that array/callable ``freqs`` has the same geometry as the spike
        # variable it is compared with in ``update``. NOTE(review): this is
        # expected to be identical to the flat size when keep_size=False --
        # confirm against ``NeuDyn.varshape``.
        self.freqs = parameter(freqs, self.varshape, allow_none=False)
        self.spk_type = get_spk_type(spk_type, self.mode)

        # variables
        self.reset_state(self.mode)

    def update(self):
        """Draw this step's spikes, store them in ``self.spike`` and return
        them."""
        # Per-step firing probability of each neuron.
        prob = self.freqs * share['dt'] / 1000.
        spikes = bm.random.rand_like(self.spike.value) <= prob
        spikes = bm.asarray(spikes, dtype=self.spk_type)
        self.spike.value = spikes
        return spikes

    def return_info(self):
        """Return the ``spike`` variable as the output description."""
        return self.spike

    def reset_state(self, batch_or_mode=None, **kwargs):
        """(Re)create the ``spike`` variable, zero-filled with ``spk_type``."""
        self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode)