# Source code for brainpy._src.dyn.rates.nvar

# -*- coding: utf-8 -*-

from itertools import combinations_with_replacement
from typing import Union, Sequence, List, Optional

import jax.numpy as jnp
import numpy as np

import brainpy.math as bm
from brainpy import check
from brainpy._src.dnn.base import Layer

# Public API of this module: only the NVAR node is exported.
__all__ = ['NVAR']


def _comb(N, k):
  r"""The number of combinations of N things taken k at a time.

  .. math::

     \frac{N!}{(N-k)! k!}

  Follows the usual convention that there are no combinations when
  ``k < 0`` or ``k > N``; ``0`` is returned in both cases.

  Parameters
  ----------
  N: int
    The number of things to choose from.
  k: int
    The number of things taken at a time.

  Returns
  -------
  int
    The binomial coefficient :math:`\binom{N}{k}`.
  """
  if k < 0 or k > N:
    # A negative ``k`` used to reach the loop below with an empty range and
    # wrongly return 1; ``k > N`` has always returned 0.
    return 0
  val = 1
  # C(N, k) == C(N, N - k): iterate over the smaller of the two.
  # After step ``j``, ``val`` equals C(N, j + 1), so each floor division is exact.
  # (``k == N`` and ``k == 0`` give an empty range and return 1.)
  for j in range(min(k, N - k)):
    val = (val * (N - j)) // (j + 1)
  return val


class NVAR(Layer):
  """Nonlinear vector auto-regression (NVAR) node.

  This class has the following features:

  - it supports batch size,
  - it supports multiple orders,

  The output feature vector is laid out as
  ``[1 (only if constant), linear part, nonlinear part]``, where the linear
  part is ``x(t), x(t - stride), ..., x(t - (delay - 1) * stride)`` flattened
  in that order, and the nonlinear part holds every unique monomial of each
  requested order built from the linear features.

  Parameters
  ----------
  num_in: int
    The number of input features.
  delay: int
    The number of time steps (the current input included) collected in the
    linear part.
  order: int, sequence of int, optional
    The nonlinear order(s); each must be an integer >= 2. ``None`` (the
    default) means no nonlinear part.
  stride: int
    The stride to sample linear part vector in the delays.
  constant: bool
    Whether to prepend a constant feature equal to ``1`` to the output.
  mode: bm.Mode, optional
    The computing mode. With a ``bm.BatchingMode`` the delay buffer carries a
    batch axis (axis 1), initialized with batch size 1.
  name: str, optional
    The name of the node.

  References
  ----------
  .. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation
         reservoir computing. Nat Commun 12, 5564 (2021).
         https://doi.org/10.1038/s41467-021-25801-2

  """

  def __init__(
      self,
      num_in: int,
      delay: int,
      order: Optional[Union[int, Sequence[int]]] = None,
      stride: int = 1,
      constant: bool = False,
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None,
  ):
    super(NVAR, self).__init__(mode=mode, name=name)

    # parameters
    order = tuple() if order is None else order
    if not isinstance(order, (tuple, list)):
      order = (order,)
    self.order = tuple(order)
    check.is_sequence(order, 'order', allow_none=False)
    for o in order:
      check.is_integer(o, 'order', allow_none=False, min_bound=2)
    check.is_integer(delay, 'delay', allow_none=False, min_bound=1)
    check.is_integer(stride, 'stride', allow_none=False, min_bound=1)
    assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.'
    self.delay = delay
    self.stride = stride
    self.constant = constant
    # Length of the ring buffer: enough slots to reach back (delay - 1) * stride steps.
    self.num_delay = 1 + (self.delay - 1) * self.stride
    self.num_in = num_in

    # delay variables
    self.idx = bm.Variable(jnp.asarray([0]))
    if isinstance(self.mode, bm.BatchingMode):
      batch_size = 1  # first initialize the state with batch size = 1
      self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1)
    else:
      self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in)))

    # linear dimension
    self.linear_dim = self.delay * num_in
    # For each monomial created in the non-linear part, indices
    # of the n components involved, n being the order of the
    # monomials. Precompute them to improve efficiency.
    # (Every order was already validated to be an integer >= 2 above.)
    self.comb_ids = []
    for n_order in self.order:
      idx = np.array(list(combinations_with_replacement(np.arange(self.linear_dim), n_order)))
      self.comb_ids.append(jnp.asarray(idx))
    # number of non-linear components is (d + n - 1)! / (d - 1)! n!
    # i.e. number of all unique monomials of order n made from the
    # linear components.
    self.nonlinear_dim = sum(len(ids) for ids in self.comb_ids)
    # output dimension
    self.num_out = int(self.linear_dim + self.nonlinear_dim)
    if self.constant:
      self.num_out += 1

  def reset_state(self, batch_or_mode=None, **kwargs):
    """Reset the node state which depends on batch size.

    Parameters
    ----------
    batch_or_mode: int, bm.Mode, optional
      A batch size or a mode object. ``None`` or a ``bm.NonBatchingMode``
      resets the buffer to shape ``(num_delay, num_in)``; an int or a
      ``bm.BatchingMode`` resets it to ``(num_delay, batch_size, num_in)``.
    """
    self.idx[0] = 0
    # Resolve a Mode object to a batch size. Previously a Mode was passed
    # straight into ``jnp.zeros`` and failed.
    if isinstance(batch_or_mode, bm.NonBatchingMode):
      batch_size = None
    elif isinstance(batch_or_mode, bm.BatchingMode):
      batch_size = batch_or_mode.batch_size
    else:
      batch_size = batch_or_mode
    # To store the last inputs.
    # Note, the batch axis is not in the first dimension, so we
    # manually handle the state of NVAR, rather than returning it.
    if batch_size is None:
      self.store.value = jnp.zeros((self.num_delay, self.num_in))
    else:
      self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in))

  def update(self, x):
    """Store ``x`` in the delay buffer and return the NVAR feature vector.

    Parameters
    ----------
    x: ArrayType
      The current input. It is assigned into one slot of the delay buffer, so
      it must be compatible with shape ``(num_in,)`` (non-batching mode) or
      ``(batch, num_in)`` (batching mode).

    Returns
    -------
    ArrayType
      ``[1 (if constant), linear part, nonlinear part]`` concatenated along
      the last axis. The shape is ``(num_out,)``, or ``(batch, num_out)`` in
      batching mode.
    """
    # Ring-buffer slots holding x(t), x(t - stride), ..., x(t - (delay - 1) * stride).
    select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay

    # 1. Store the current input
    self.store[self.idx[0]] = x

    # 2. Linear part: select all previous inputs, including the current, with
    #    strides, flattened in delay-major order.
    if isinstance(self.mode, bm.BatchingMode):
      # (num_time, num_batch, num_feature) -> (num_batch, num_time, num_feature)
      linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1)
      linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1))
    else:
      linear_parts = self.store[select_ids].flatten()  # (num_time x num_feature,)

    # From here on, the last axis is the feature axis in both modes, so the
    # same code serves batched (B, D) and unbatched (D,) linear parts.
    all_parts = []
    # 3. Constant part
    if self.constant:
      all_parts.append(jnp.ones(linear_parts.shape[:-1] + (1,), dtype=x.dtype))
    all_parts.append(linear_parts)
    # 4. Nonlinear part: gather the factors of every monomial, shape
    #    (..., num_monomials, order), and multiply them together.
    for ids in self.comb_ids:
      all_parts.append(jnp.prod(linear_parts[..., ids], axis=-1))

    # 5. Advance the ring-buffer pointer
    self.idx.value = (self.idx + 1) % self.num_delay
    return jnp.concatenate(all_parts, axis=-1)

  def get_feature_names(self, for_plot=False) -> List[str]:
    """Get output feature names for transformation.

    Names follow the column order returned by :meth:`update`: ``'1'`` (if
    ``constant``), then the linear features, then the monomials of every
    order in ``self.order``.

    Parameters
    ----------
    for_plot: bool
      Use the feature names for plotting or not? (Default False) When True,
      every non-constant name is a LaTeX snippet wrapped in ``$...$``
      (e.g. ``$x0_{t-1}$``). Otherwise names are plain text (e.g. ``x0(t-1)``).

    Returns
    -------
    list of str
      One name per output feature.
    """

    def linear_name(i, di):
      # Name of input feature ``i`` delayed by ``di`` sampled steps.
      if di == 0:
        return f'x{i}_t' if for_plot else f'x{i}(t)'
      lag = di * self.stride
      return f'x{i}_{{t-{lag}}}' if for_plot else f'x{i}(t-{lag})'

    # Delay-major ordering, matching the flattening in ``update``.
    linear_names = [linear_name(i, di) for di in range(self.delay) for i in range(self.num_in)]

    nonlinear_names = []
    for ids in self.comb_ids:
      for id_ in np.asarray(ids):
        # Repeated factors are collapsed into powers, e.g. "x0(t)^2 x1(t)".
        uniques, counts = np.unique(id_, return_counts=True)
        nonlinear_names.append(" ".join(
          "%s^%d" % (linear_names[ind], exp) if (exp != 1) else linear_names[ind]
          for ind, exp in zip(uniques, counts)
        ))

    all_names = linear_names + nonlinear_names
    if for_plot:
      all_names = [f'${n}$' for n in all_names]
    if self.constant:
      all_names = ['1'] + all_names
    return all_names