# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Union, Callable, Tuple
import jax.numpy as jnp
import brainpy.math as bm
from brainpy import check
from brainpy.dnn.base import Layer
from brainpy.initialize import Normal, ZeroInit, Initializer, parameter, variable
from brainpy.tools import to_size
from brainpy.types import ArrayType
# Names exported by ``from <this module> import *``.
__all__ = [
'Reservoir',
]
[docs]
class Reservoir(Layer):
    r"""Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_.

    Parameters::

    input_shape: int, tuple of int
      The input shape. A leading ``None`` dimension is dropped; the last
      dimension is the number of input features.
    num_out: int
      The number of reservoir nodes. Must be positive.
    leaky_rate: float
      The leak rate :math:`\alpha`, a float between 0 and 1.
    activation : str, callable, optional
      Reservoir activation function :math:`f`.

      - If a str, it is looked up by name on :py:mod:`brainpy.math`.
      - If a callable, should be an element-wise operator.
    activation_type : str
      Where the activation is applied relative to the leaky integration.
      With :math:`u` the input, :math:`r` the reservoir state and :math:`b`
      the bias (noise terms omitted):

      - If "internal" (default), the activation is applied to the summed
        drive, then the result is leakily integrated:

        .. math::

            r[n+1] = (1 - \alpha) \cdot r[n] +
                     \alpha \cdot f(W_{in} \cdot u[n] + W_{rec} \cdot r[n] + b)

      - If "external", the raw drive is leakily integrated first, and the
        activation is applied to the integrated value:

        .. math::

            r[n+1] = f\left((1 - \alpha) \cdot r[n] +
                     \alpha \cdot (W_{in} \cdot u[n] + W_{rec} \cdot r[n] + b)\right)
    Win_initializer: Initializer
      The initialization method for the feedforward connections.
    Wrec_initializer: Initializer
      The initialization method for the recurrent connections.
    b_initializer: optional, ArrayType, Initializer
      The initialization method for the bias. ``None`` disables the bias.
    in_connectivity : float, optional
      Connectivity of the feedforward weights, i.e. the probability that
      each input-to-reservoir weight is kept non-zero.
      Must be in [0, 1], by default 0.1
    rec_connectivity : float, optional
      Connectivity of the recurrent weight matrix, i.e. the probability that
      each reservoir-to-reservoir weight (self-connections included) is kept
      non-zero. Must be in [0, 1], by default 0.1
    comp_type: str
      The computation type, one of "dense", "sparse" or "jit".

      - ``"dense"`` stores each weight matrix densely and uses a dense
        matrix product.
      - ``"sparse"`` stores only the kept weights together with their
        (pre, post) indices whenever the matching connectivity is below 1.
      - ``"jit"`` is accepted, but this class only special-cases
        ``"sparse"``; any other value takes the dense path.
    spectral_radius : float, optional
      If given, the recurrent weights are rescaled so that their spectral
      radius equals this value. By default None (no rescaling).
    noise_rec : float, optional
      Gain of the noise added to the reservoir drive (after the "internal"
      activation, before the leaky integration), by default 0.0
    noise_in : float, optional
      Gain of the noise added to the input signal, by default 0.0
    noise_type : str
      Distribution of the noise: ``"normal"`` (standard normal, the default)
      or ``"uniform"`` (uniform on [-1, 1)).
    mode: optional, Mode
      The computing mode. In training mode the weights and the bias are
      wrapped as trainable variables.
    name: optional, str
      The name of the layer.

    References::

    .. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks."
           Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686.
    """

    def __init__(
        self,
        input_shape: Union[int, Tuple[int]],
        num_out: int,
        leaky_rate: float = 0.3,
        activation: Union[str, Callable] = 'tanh',
        activation_type: str = 'internal',
        Win_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1),
        Wrec_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1),
        b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
        in_connectivity: float = 0.1,
        rec_connectivity: float = 0.1,
        comp_type: str = 'dense',
        spectral_radius: Optional[float] = None,
        noise_in: float = 0.,
        noise_rec: float = 0.,
        noise_type: str = 'normal',
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None
    ):
        super().__init__(mode=mode, name=name)

        # parameters
        input_shape = to_size(input_shape)
        if input_shape[0] is None:
            # drop the leading placeholder dimension
            input_shape = input_shape[1:]
        self.input_shape = input_shape
        self.output_shape = input_shape[:-1] + (num_out,)
        self.num_unit = num_out
        assert num_out > 0, f'Must be a positive integer, but we got {num_out}'
        self.leaky_rate = leaky_rate
        check.is_float(leaky_rate, 'leaky_rate', 0., 1.)
        self.activation = getattr(bm, activation) if isinstance(activation, str) else activation
        check.is_callable(self.activation, allow_none=False)
        self.activation_type = activation_type
        check.is_string(activation_type, 'activation_type', ['internal', 'external'])
        check.is_float(spectral_radius, 'spectral_radius', allow_none=True)
        self.spectral_radius = spectral_radius

        # initializations (names passed to ``check`` match the constructor
        # arguments so that error messages point at the right parameter)
        check.is_initializer(Win_initializer, 'Win_initializer', allow_none=False)
        check.is_initializer(Wrec_initializer, 'Wrec_initializer', allow_none=False)
        check.is_initializer(b_initializer, 'b_initializer', allow_none=True)
        self._Win_initializer = Win_initializer
        self._Wrec_initializer = Wrec_initializer
        self._b_initializer = b_initializer

        # connectivity
        check.is_float(in_connectivity, 'in_connectivity', 0., 1.)
        check.is_float(rec_connectivity, 'rec_connectivity', 0., 1.)
        self.ff_connectivity = in_connectivity
        self.rec_connectivity = rec_connectivity
        check.is_string(comp_type, 'comp_type', ['dense', 'sparse', 'jit'])
        self.comp_type = comp_type

        # noises
        check.is_float(noise_in, 'noise_in')
        check.is_float(noise_rec, 'noise_rec')
        self.noise_ff = noise_in
        self.noise_rec = noise_rec
        self.noise_type = noise_type
        check.is_string(noise_type, 'noise_type', ['normal', 'uniform'])

        # initialize feedforward weights
        weight_shape = (input_shape[-1], self.num_unit)
        self.Wff_shape = weight_shape
        self.Win = parameter(self._Win_initializer, weight_shape)
        if self.ff_connectivity < 1.:
            # ``conn_mat`` is True where a connection is pruned
            conn_mat = bm.random.random(weight_shape) > self.ff_connectivity
            self.Win[conn_mat] = 0.
            if self.comp_type == 'sparse':
                # keep only the surviving weights plus their (pre, post) indices
                self.ff_pres, self.ff_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat)))
                self.Win = self.Win[self.ff_pres, self.ff_posts]
        if isinstance(self.mode, bm.TrainingMode):
            self.Win = bm.TrainVar(self.Win)

        # initialize recurrent weights
        recurrent_shape = (self.num_unit, self.num_unit)
        self.Wrec = parameter(self._Wrec_initializer, recurrent_shape)
        if self.rec_connectivity < 1.:
            conn_mat = bm.random.random(recurrent_shape) > self.rec_connectivity
            self.Wrec[conn_mat] = 0.
        if self.spectral_radius is not None:
            # The rescaling must see the full (already pruned) square matrix,
            # so it happens before the sparse extraction below.
            eigenvalues = jnp.linalg.eigvals(bm.as_jax(self.Wrec))
            current_sr = jnp.max(jnp.abs(eigenvalues))
            self.Wrec *= self.spectral_radius / current_sr
        if self.comp_type == 'sparse' and self.rec_connectivity < 1.:
            self.rec_pres, self.rec_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat)))
            self.Wrec = self.Wrec[self.rec_pres, self.rec_posts]
        self.bias = parameter(self._b_initializer, (self.num_unit,))
        if isinstance(self.mode, bm.TrainingMode):
            self.Wrec = bm.TrainVar(self.Wrec)
            self.bias = None if (self.bias is None) else bm.TrainVar(self.bias)

        # initialize state
        self.state = variable(jnp.zeros, self.mode, self.output_shape)

    def reset_state(self, batch_or_mode=None, **kwargs):
        """Reset the reservoir state to zeros.

        ``batch_or_mode`` is forwarded to ``variable`` together with
        ``self.output_shape``; extra keyword arguments are ignored.
        """
        self.state.value = variable(jnp.zeros, batch_or_mode, self.output_shape)

    def _noise(self, shape):
        """Draw one noise sample of ``shape`` from the ``noise_type`` distribution."""
        if self.noise_type == 'uniform':
            return bm.random.uniform(-1., 1., shape)
        # 'normal' is the only other value accepted by the constructor
        return bm.random.normal(0., 1., shape)

    def update(self, x):
        """Feedforward output.

        Advances the reservoir by one step with input ``x``, stores the new
        state in ``self.state`` and returns it.
        """
        # inputs
        x = bm.as_jax(x)
        if self.noise_ff > 0:
            x = x + self.noise_ff * self._noise(x.shape)
        if self.comp_type == 'sparse' and self.ff_connectivity < 1.:
            sparse = {'data': self.Win,
                      'index': (self.ff_pres, self.ff_posts),
                      'shape': self.Wff_shape}
            hidden = bm.sparse.seg_matmul(x, sparse)
        else:
            hidden = x @ self.Win

        # recurrent
        if self.comp_type == 'sparse' and self.rec_connectivity < 1.:
            sparse = {'data': self.Wrec,
                      'index': (self.rec_pres, self.rec_posts),
                      'shape': (self.num_unit, self.num_unit)}
            hidden = hidden + bm.sparse.seg_matmul(self.state, sparse)
        else:
            hidden = hidden + self.state @ self.Wrec

        # bias (``None`` when ``b_initializer`` is None)
        if self.bias is not None:
            hidden = hidden + self.bias

        if self.activation_type == 'internal':
            hidden = self.activation(hidden)
        if self.noise_rec > 0.:
            hidden = hidden + self.noise_rec * self._noise(self.state.shape)

        # new state/output
        state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden
        if self.activation_type == 'external':
            state = self.activation(state)
        self.state.value = state
        return state