Source code for brainpy.dyn.rates.reservoir

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Union, Callable, Tuple

import jax.numpy as jnp

import brainpy.math as bm
from brainpy import check
from brainpy.dnn.base import Layer
from brainpy.initialize import Normal, ZeroInit, Initializer, parameter, variable
from brainpy.tools import to_size
from brainpy.types import ArrayType

# Names exported by ``from brainpy.dyn.rates.reservoir import *``.
__all__ = ['Reservoir']


class Reservoir(Layer):
    r"""Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_.

    Parameters
    ----------
    input_shape: int, tuple of int
        The input shape. A leading ``None`` (batch) dimension is dropped.
    num_out: int
        The number of reservoir nodes.
    leaky_rate: float
        Leak rate :math:`\alpha`, a float between 0 and 1.
    activation : str, callable, optional
        Reservoir activation function :math:`f`.

        - If a str, it is looked up on :py:mod:`brainpy.math` by name
          (``getattr(brainpy.math, name)``, e.g. ``'tanh'``).
        - If a callable, it should be an element-wise operator.
    activation_type : str
        - If ``"internal"`` (default), leaky integration happens on states
          already transformed by the activation function:

          .. math::

              r[n+1] = (1 - \alpha) \cdot r[n]
                       + \alpha \cdot f(W_{in} \cdot u[n] + W_{rec} \cdot r[n] + b)

        - If ``"external"``, leaky integration happens before the activation
          function is applied:

          .. math::

              r[n+1] = f\big((1 - \alpha) \cdot r[n]
                       + \alpha \cdot (W_{in} \cdot u[n] + W_{rec} \cdot r[n] + b)\big)

          Only :math:`r` is stored (in ``state``); no separate pre-activation
          internal state is kept between steps.
    Win_initializer: Initializer
        The initialization method for the feedforward connections.
    Wrec_initializer: Initializer
        The initialization method for the recurrent connections.
    b_initializer: optional, ArrayType, Initializer
        The initialization method for the bias :math:`b`. ``None`` disables the bias.
    in_connectivity : float, optional
        Connectivity of input neurons, i.e. the ratio of input-to-reservoir
        connections that are kept. Must be in [0, 1]. Defaults to 0.1.
    rec_connectivity : float, optional
        Connectivity of the recurrent weight matrix, i.e. the ratio of
        reservoir-to-reservoir connections (self-connections included) that
        are kept. Must be in [0, 1]. Defaults to 0.1.
    comp_type: str
        The computation type: ``"dense"``, ``"sparse"`` or ``"jit"``.

        - ``"dense"`` means the connectivity matrix is a dense matrix.
        - ``"sparse"`` means connections are stored as (pre, post) index lists
          and multiplied with ``brainpy.math.sparse.seg_matmul``.
        - ``"jit"`` is accepted, but in this implementation it follows the
          dense code path.
    spectral_radius : float, optional
        If given, the recurrent weight matrix is rescaled to this spectral
        radius. Defaults to None (no rescaling).
    noise_in : float, optional
        Gain of noise added to the input :math:`u[n]`. Defaults to 0.0.
    noise_rec : float, optional
        Gain of noise added to the reservoir pre-leak signal. Defaults to 0.0.
    noise_type : str
        Distribution of the noise: ``"normal"`` (standard normal, default) or
        ``"uniform"`` (uniform on [-1, 1)).
    mode: optional, Mode
        The computing mode.
    name: optional, str
        The name of this node.

    References
    ----------
    .. [1] Lukoševičius, Mantas. "A practical guide to applying echo state
           networks." Neural networks: Tricks of the trade. Springer, Berlin,
           Heidelberg, 2012. 659-686.
    """

    def __init__(
        self,
        input_shape: Union[int, Tuple[int]],
        num_out: int,
        leaky_rate: float = 0.3,
        activation: Union[str, Callable] = 'tanh',
        activation_type: str = 'internal',
        Win_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1),
        Wrec_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1),
        b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
        in_connectivity: float = 0.1,
        rec_connectivity: float = 0.1,
        comp_type: str = 'dense',
        spectral_radius: Optional[float] = None,
        noise_in: float = 0.,
        noise_rec: float = 0.,
        noise_type: str = 'normal',
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None
    ):
        super(Reservoir, self).__init__(mode=mode, name=name)

        # parameters
        input_shape = to_size(input_shape)
        if input_shape[0] is None:
            input_shape = input_shape[1:]
        self.input_shape = input_shape
        self.output_shape = input_shape[:-1] + (num_out,)
        self.num_unit = num_out
        assert num_out > 0, f'Must be a positive integer, but we got {num_out}'
        self.leaky_rate = leaky_rate
        check.is_float(leaky_rate, 'leaky_rate', 0., 1.)
        self.activation = getattr(bm, activation) if isinstance(activation, str) else activation
        check.is_callable(self.activation, allow_none=False)
        self.activation_type = activation_type
        check.is_string(activation_type, 'activation_type', ['internal', 'external'])
        check.is_float(spectral_radius, 'spectral_radius', allow_none=True)
        self.spectral_radius = spectral_radius

        # initializations
        check.is_initializer(Win_initializer, 'ff_initializer', allow_none=False)
        check.is_initializer(Wrec_initializer, 'rec_initializer', allow_none=False)
        check.is_initializer(b_initializer, 'bias_initializer', allow_none=True)
        self._Win_initializer = Win_initializer
        self._Wrec_initializer = Wrec_initializer
        self._b_initializer = b_initializer

        # connectivity
        check.is_float(in_connectivity, 'ff_connectivity', 0., 1.)
        check.is_float(rec_connectivity, 'rec_connectivity', 0., 1.)
        self.ff_connectivity = in_connectivity
        self.rec_connectivity = rec_connectivity
        check.is_string(comp_type, 'conn_type', ['dense', 'sparse', 'jit'])
        self.comp_type = comp_type

        # noises
        check.is_float(noise_in, 'noise_ff')
        check.is_float(noise_rec, 'noise_rec')
        self.noise_ff = noise_in
        self.noise_rec = noise_rec
        self.noise_type = noise_type
        check.is_string(noise_type, 'noise_type', ['normal', 'uniform'])

        # initialize feedforward weights
        weight_shape = (input_shape[-1], self.num_unit)
        self.Wff_shape = weight_shape
        self.Win = parameter(self._Win_initializer, weight_shape)
        if self.ff_connectivity < 1.:
            # `conn_mat` marks the entries to DROP; roughly `ff_connectivity` of them survive.
            conn_mat = bm.random.random(weight_shape) > self.ff_connectivity
            self.Win[conn_mat] = 0.
        if self.comp_type == 'sparse' and self.ff_connectivity < 1.:
            # Keep only the surviving entries, addressed by (pre, post) index lists.
            self.ff_pres, self.ff_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat)))
            self.Win = self.Win[self.ff_pres, self.ff_posts]
        if isinstance(self.mode, bm.TrainingMode):
            self.Win = bm.TrainVar(self.Win)

        # initialize recurrent weights
        recurrent_shape = (self.num_unit, self.num_unit)
        self.Wrec = parameter(self._Wrec_initializer, recurrent_shape)
        if self.rec_connectivity < 1.:
            conn_mat = bm.random.random(recurrent_shape) > self.rec_connectivity
            self.Wrec[conn_mat] = 0.
        if self.spectral_radius is not None:
            # Rescale the (still dense) recurrent matrix to the requested spectral radius.
            # Vectorized max instead of Python `max`, which iterates element by element.
            eigenvalues = jnp.linalg.eigvals(bm.as_jax(self.Wrec))
            current_sr = jnp.max(jnp.abs(eigenvalues))
            if float(current_sr) == 0.:
                # Dividing by zero here would silently fill Wrec with NaN/inf.
                raise ValueError(
                    'Cannot rescale the recurrent weights to the requested spectral '
                    'radius: the recurrent weight matrix has a spectral radius of 0 '
                    '(e.g. all connections were pruned). Increase "rec_connectivity" '
                    'or "num_out", or set "spectral_radius=None".'
                )
            self.Wrec *= self.spectral_radius / current_sr
        if self.comp_type == 'sparse' and self.rec_connectivity < 1.:
            self.rec_pres, self.rec_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat)))
            self.Wrec = self.Wrec[self.rec_pres, self.rec_posts]
        # `parameter` may return None when no bias initializer is given.
        self.bias = parameter(self._b_initializer, (self.num_unit,))
        if isinstance(self.mode, bm.TrainingMode):
            self.Wrec = bm.TrainVar(self.Wrec)
            self.bias = None if (self.bias is None) else bm.TrainVar(self.bias)

        # initialize state
        self.state = variable(jnp.zeros, self.mode, self.output_shape)

    def reset_state(self, batch_or_mode=None, **kwargs):
        """Reset the reservoir state to zeros for the given batch size / mode."""
        self.state.value = variable(jnp.zeros, batch_or_mode, self.output_shape)

    def _sample_noise(self, shape):
        """Draw unit-gain noise of ``shape`` from the distribution selected by ``noise_type``."""
        if self.noise_type == 'normal':
            return bm.random.normal(0., 1., shape)
        return bm.random.uniform(-1., 1., shape)

    def update(self, x):
        """Advance the reservoir by one step.

        Parameters
        ----------
        x: ArrayType
            The input at the current step.

        Returns
        -------
        state: ArrayType
            The new reservoir state (also written to ``self.state``).
        """
        # inputs
        x = bm.as_jax(x)
        if self.noise_ff > 0:
            x += self.noise_ff * self._sample_noise(x.shape)

        # feedforward projection
        if self.comp_type == 'sparse' and self.ff_connectivity < 1.:
            sparse = {'data': self.Win,
                      'index': (self.ff_pres, self.ff_posts),
                      'shape': self.Wff_shape}
            hidden = bm.sparse.seg_matmul(x, sparse)
        else:
            hidden = x @ self.Win

        # recurrent projection
        if self.comp_type == 'sparse' and self.rec_connectivity < 1.:
            sparse = {'data': self.Wrec,
                      'index': (self.rec_pres, self.rec_posts),
                      'shape': (self.num_unit, self.num_unit)}
            hidden += bm.sparse.seg_matmul(self.state, sparse)
        else:
            hidden += self.state @ self.Wrec

        # bias (previously initialized but never applied)
        if self.bias is not None:
            hidden = hidden + self.bias

        if self.activation_type == 'internal':
            hidden = self.activation(hidden)
        if self.noise_rec > 0.:
            # Previously `uniform(-1, -1, ...)`, which is a constant -1 rather than noise.
            hidden += self.noise_rec * self._sample_noise(hidden.shape)

        # new state/output
        state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden
        if self.activation_type == 'external':
            state = self.activation(state)
        self.state.value = state
        return state