Source code for brainpy._src.encoding.stateless_encoding
# -*- coding: utf-8 -*-
from typing import Union, Optional
import jax
import brainpy.math as bm
from brainpy import check
from brainpy.types import ArrayType
from .base import Encoder
# Public names exported by ``from ... import *`` on this module.
__all__ = ['PoissonEncoder']
class PoissonEncoder(Encoder):
  r"""Encode the rate input as a Poisson spike train.

  Given the input :math:`x`, the Poisson encoder outputs spikes whose
  firing probability is :math:`x_{\text{normalize}}`. When both ``min_val``
  and ``max_val`` are provided, :math:`x` is first normalized into ``[0, 1]``:

  .. math::

     x_{\text{normalize}} = \frac{x - \mathrm{min\_val}}{\mathrm{max\_val} - \mathrm{min\_val}}.

  If either bound is ``None``, no normalization is applied and :math:`x`
  itself is compared against uniform random numbers, so it should already be
  a probability in ``[0, 1]``.

  Parameters
  ----------
  min_val: float, optional
    The minimal value in the given data ``x``, used for data normalization.
    Ignored unless ``max_val`` is also given.
  max_val: float, optional
    The maximum value in the given data ``x``, used for data normalization.
    Ignored unless ``min_val`` is also given. Must be strictly greater than
    ``min_val`` when both are given.
  seed: int, ArrayType, optional
    The seed or key for random generation.

  Raises
  ------
  ValueError
    If both ``min_val`` and ``max_val`` are given and ``max_val <= min_val``.
  """

  def __init__(self,
               min_val: Optional[float] = None,
               max_val: Optional[float] = None,
               seed: Optional[Union[int, ArrayType]] = None):
    super().__init__()
    self.min_val = check.is_float(min_val, 'min_val', allow_none=True)
    self.max_val = check.is_float(max_val, 'max_val', allow_none=True)
    # Reject degenerate ranges up front: equal bounds would divide by zero in
    # ``__call__`` (producing inf/nan probabilities), and reversed bounds
    # would silently invert the encoding.
    if (self.min_val is not None and self.max_val is not None
        and self.max_val <= self.min_val):
      raise ValueError(f'"max_val" must be greater than "min_val", but got '
                       f'min_val={self.min_val} and max_val={self.max_val}.')
    self.rng = bm.random.default_rng(seed)

  def __call__(self, x: ArrayType, num_step: Optional[int] = None):
    """Sample spikes from the (optionally normalized) rate input ``x``.

    Parameters
    ----------
    x: ArrayType
      The rate input.
    num_step: int, optional
      Number of time steps to encode.

      - If ``num_step=None``, encode the rate values at the current time step
        only; the output has the same shape as ``x``. Users should call it
        repeatedly to encode ``x`` as a spike train.
      - Otherwise, given ``x`` with shape ``(S, ...)``, the encoded spike
        train has shape ``(num_step, S, ...)``.

    Returns
    -------
    out: ArrayType
      The encoded spike train (1 where a spike occurs, 0 elsewhere), cast to
      the dtype of ``x`` after any normalization.
    """
    # ``num_step`` determines the output shape, so validate it eagerly even
    # when this call is being traced.
    with jax.ensure_compile_time_eval():
      check.is_integer(num_step, 'num_step', min_bound=1, allow_none=True)
    if self.min_val is not None and self.max_val is not None:
      x = (x - self.min_val) / (self.max_val - self.min_val)
    shape = x.shape if num_step is None else (num_step,) + x.shape
    # A uniform sample falling below the probability ``x`` emits a spike.
    spikes = bm.as_jax(self.rng.rand(*shape)) < x
    return spikes.astype(x.dtype)