# Source code for brainpy.dyn.layers.nvar

# -*- coding: utf-8 -*-

from itertools import combinations_with_replacement
from typing import Union, Sequence, List

import jax.numpy as jnp
import numpy as np

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check_mode
from brainpy.tools.checking import (check_integer, check_sequence)

# Names exported by ``from ... import *``.
__all__ = [
    'NVAR',
]

def _comb(N, k):
r"""The number of combinations of N things taken k at a time.

.. math::

\frac{N!}{(N-k)! k!}

"""
if N > k:
val = 1
for j in range(min(k, N - k)):
val = (val * (N - j)) // (j + 1)
return val
elif N == k:
return 1
else:
return 0

class NVAR(DynamicalSystem):
    """Nonlinear vector auto-regression (NVAR) node.

    This class has the following features:

    - it supports batch size,
    - it supports multiple orders,

    Parameters
    ----------
    num_in: int
        The number of input features per time step.
    delay: int
        The number of delay step.
    order: int, sequence of int
        The nonlinear order.
    stride: int
        The stride to sample linear part vector in the delays.
    constant: bool
        Whether to prepend a constant (bias) term to the output
        feature vector. (Default False)
    mode: Mode
        The computing mode (batching or normal).
    name: str
        The node name.

    References
    ----------
    .. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation
           reservoir computing. Nat Commun 12, 5564 (2021).
           https://doi.org/10.1038/s41467-021-25801-2

    """

    def __init__(
        self,
        num_in: int,
        delay: int,
        order: Union[int, Sequence[int]] = None,
        stride: int = 1,
        constant: bool = False,
        mode: Mode = batching,
        name: str = None,
    ):
        super(NVAR, self).__init__(mode=mode, name=name)
        # Only batching and normal modes are supported by this node.
        check_mode(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

        # parameters: normalize ``order`` to a tuple of integers, each >= 2
        order = tuple() if order is None else order
        if not isinstance(order, (tuple, list)):
            order = (order,)
        self.order = tuple(order)
        check_sequence(order, 'order', allow_none=False)
        for o in order:
            check_integer(o, 'order', allow_none=False, min_bound=2)
        check_integer(delay, 'delay', allow_none=False, min_bound=1)
        check_integer(stride, 'stride', allow_none=False, min_bound=1)
        assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.'
        self.delay = delay
        self.stride = stride
        self.constant = constant
        # Length of the circular buffer needed to cover ``delay`` samples
        # taken every ``stride`` steps.
        self.num_delay = 1 + (self.delay - 1) * self.stride
        self.num_in = num_in

        # delay variables: ``idx`` is the current write position inside the
        # circular buffer ``store`` of past inputs.
        self.idx = bm.Variable(jnp.asarray([0]))
        if isinstance(self.mode, BatchingMode):
            batch_size = 1  # first initialize the state with batch size = 1
            # Batch axis is axis 1 of ``store`` (axis 0 is the delay/time axis).
            self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1)
        else:
            self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in)))

        # linear dimension
        self.linear_dim = self.delay * num_in
        # For each monomial created in the non-linear part, indices
        # of the n components involved, n being the order of the
        # monomials. Precompute them to improve efficiency.
        self.comb_ids = []
        for order in self.order:
            assert order >= 2, f'"order" must be a integer >= 2, while we got {order}.'
            idx = np.array(list(combinations_with_replacement(np.arange(self.linear_dim), order)))
            self.comb_ids.append(jnp.asarray(idx))
        # number of non-linear components is (d + n - 1)! / (d - 1)! n!
        # i.e. number of all unique monomials of order n made from the
        # linear components.
        self.nonlinear_dim = sum([len(ids) for ids in self.comb_ids])
        # output dimension: linear + nonlinear (+1 for the constant term)
        self.num_out = int(self.linear_dim + self.nonlinear_dim)
        if self.constant:
            self.num_out += 1

    def reset_state(self, batch_size=None):
        """Reset the node state which depends on batch size."""
        self.idx[0] = 0
        # To store the last inputs.
        # Note, the batch axis is not in the first dimension, so we
        # manually handle the state of NVAR, rather return it.
        if batch_size is None:
            self.store.value = jnp.zeros((self.num_delay, self.num_in))
        else:
            self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in))

    def update(self, sha, x):
        """Compute the NVAR feature vector for the current input.

        Parameters
        ----------
        sha
            Shared arguments passed in by the framework; unused in this
            method.
        x
            The current input. In batching mode the leading axis is
            assumed to be the batch dimension — TODO confirm with callers.

        Returns
        -------
        The concatenation, along the last axis, of the optional constant
        term, the linear (delayed-input) part and the nonlinear
        (monomial) part.
        """
        all_parts = []
        # Indices of the ``delay`` samples to read from the circular
        # buffer, spaced ``stride`` apart, newest first.
        select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay
        # 1. Store the current input
        self.store[self.idx[0]] = x

        if isinstance(self.mode, BatchingMode):
            # 2. Linear part:
            # select all previous inputs, including the current, with strides
            linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1)  # (num_batch, num_time, num_feature)
            linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1))
            # 3. Constant part (one bias column per batch element)
            if self.constant:
                constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype)
                all_parts.append(constant)
            all_parts.append(linear_parts)
            # 4. Nonlinear part:
            # select monomial terms and compute them
            for ids in self.comb_ids:
                all_parts.append(jnp.prod(linear_parts[:, ids], axis=2))

        else:
            # 2. Linear part:
            # select all previous inputs, including the current, with strides
            linear_parts = self.store[select_ids].flatten()  # (num_time x num_feature,)
            # 3. Constant part (single bias entry)
            if self.constant:
                constant = jnp.ones((1,), dtype=x.dtype)
                all_parts.append(constant)
            all_parts.append(linear_parts)
            # 4. Nonlinear part:
            # select monomial terms and compute them
            for ids in self.comb_ids:
                all_parts.append(jnp.prod(linear_parts[ids], axis=1))

        # 5. Finally: advance the circular-buffer write position and
        # concatenate all feature groups along the last axis.
        self.idx.value = (self.idx + 1) % self.num_delay
        return jnp.concatenate(all_parts, axis=-1)

    def get_feature_names(self, for_plot=False) -> List[str]:
        """Get output feature names for transformation.

        Parameters
        ----------
        for_plot: bool
            Use the feature names for plotting or not? (Default False)
        """
        # Linear-part names: current input first, then each delayed copy.
        if for_plot:
            linear_names = [f'x{i}_t' for i in range(self.num_in)]
        else:
            linear_names = [f'x{i}(t)' for i in range(self.num_in)]
        for di in range(1, self.delay):
            linear_names.extend([((f'x{i}_' + r'{t-%d}' % (di * self.stride))
                                  if for_plot else f'x{i}(t-{di * self.stride})')
                                 for i in range(self.num_in)])
        # Nonlinear-part names: one monomial per precomputed index tuple,
        # with repeated factors collapsed into an exponent.
        nonlinear_names = []
        for ids in self.comb_ids:
            for id_ in np.asarray(ids):
                uniques, counts = np.unique(id_, return_counts=True)
                nonlinear_names.append(" ".join(
                    "%s^%d" % (linear_names[ind], exp) if (exp != 1) else linear_names[ind]
                    for ind, exp in zip(uniques, counts)
                ))
        if for_plot:
            all_names = [f'${n}$' for n in linear_names] + [f'${n}$' for n in nonlinear_names]
        else:
            all_names = linear_names + nonlinear_names
        # The constant term, when enabled, is the first output feature.
        if self.constant:
            all_names = ['1'] + all_names
        return all_names