# Source code for brainpy._src.measure.correlation

# -*- coding: utf-8 -*-

import numpy as onp
from jax import vmap, lax, numpy as jnp

from brainpy._src import math as bm
from brainpy.errors import UnsupportedError

# Names exported by a star-import of this module.
__all__ = [
'cross_correlation',
'voltage_fluctuation',
'matrix_correlation',
'weighted_correlation',
'functional_connectivity',
# 'functional_connectivity_dynamics',
# ^ not exported: in this file the body of that function is only a
#   placeholder (``pass``).
]

[docs]
def cross_correlation(spikes, bin, dt=None, numpy=True, method='loop'):
    r"""Population-averaged zero-lag coherence of binned spike trains.

    Following [1]_, the coherence between two neurons :math:`i` and
    :math:`j` is the cross-correlation of their spike trains at zero time
    lag, evaluated on a grid of bins of width :math:`\Delta t = \tau`.
    Suppose a long time interval :math:`T` is divided into :math:`K` bins
    (:math:`T / K = \tau`) and the two binned spike trains are
    :math:`X(l) = 0` or 1 and :math:`Y(l) = 0` or 1 for
    :math:`l = 1, 2, \ldots, K`. The pairwise coherence is

    .. math::

        \kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)}
        {\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}

    The population coherence :math:`\kappa(\tau)` returned here is the
    average of :math:`\kappa_{i j}(\tau)` over all neuron pairs with
    :math:`i > j`.

    .. note::
       To JIT compile this function, users should make ``bin``, ``dt``,
       ``numpy`` static. For example,
       ``partial(brainpy.measure.cross_correlation, bin=10, numpy=False)``.

    Parameters
    ----------
    spikes : ndarray
        The history of spike states of the neuron group. It is unpacked as
        a 2D array of shape ``(num_time, num_neuron)``.
    bin : float, int
        Width of one time bin, expressed in the same unit as ``dt``. A bin
        spans ``int(bin / dt)`` time steps.
    dt : float, optional
        The time precision. When ``None``, ``bm.get_dt()`` is used.
    numpy : bool
        Whether we use numpy array as the functional output.
        If ``False``, this function can be JIT compiled.
    method : str
        The method to calculate all pairs of cross correlation.
        Supports two kinds of methods: ``'loop'`` and ``'vmap'``.
        The ``'vmap'`` method needs much more memory.

        .. versionadded:: 2.2.3.4

    Returns
    -------
    cc_index : float
        The cross correlation value which represents the synchronization
        index.

    Raises
    ------
    UnsupportedError
        If ``method`` is neither ``'loop'`` nor ``'vmap'``.

    References
    ----------
    .. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic
           inhibition in a hippocampal interneuronal network model." Journal of
           neuroscience 16.20 (1996): 6402-6413.
    """
    spikes = bm.as_numpy(spikes) if numpy else bm.as_jax(spikes)
    xp = onp if numpy else jnp  # array namespace used for the pre/post-processing
    if dt is None:
        dt = bm.get_dt()

    # int() truncates, so a bin holds floor(bin / dt) time steps.
    steps_per_bin = int(bin / dt)
    num_hist, num_neu = spikes.shape
    num_bin = int(onp.ceil(num_hist / steps_per_bin))

    # Zero-pad the tail so the history splits into whole bins.
    missing_rows = num_bin * steps_per_bin - num_hist
    if missing_rows != 0:
        spikes = xp.append(spikes, xp.zeros((missing_rows, num_neu)), axis=0)

    # A bin is "active" (1.0) for a neuron when it holds at least one spike.
    binned = spikes.T.reshape((num_neu, num_bin, steps_per_bin))
    states = jnp.asarray(xp.sum(binned, axis=2) > 0., dtype=jnp.float_)

    # Every unordered neuron pair once: the strict lower triangle (i > j).
    pair_indices = jnp.tril_indices(num_neu, k=-1)

    def _pair_coherence(i, j):
        # A pair whose normalisation term is zero contributes a coherence of 0.
        norm = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
        return lax.cond(norm == 0.,
                        lambda _: 0.,
                        lambda _: jnp.sum(states[i] * states[j]) / norm,
                        None)

    if method == 'vmap':
        res = vmap(_pair_coherence)(*pair_indices)
    elif method == 'loop':
        res = bm.for_loop(_pair_coherence, operands=pair_indices)
    else:
        raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')

    return xp.mean(xp.asarray(res))

def _f_signal(signal):
    """Return ``mean(signal ** 2) - mean(signal) ** 2`` of ``signal``.

    Both means are taken over every element of ``signal``.
    """
    mean_of_squares = jnp.mean(signal * signal)
    squared_mean = jnp.mean(signal) ** 2
    return mean_of_squares - squared_mean

[docs]
def voltage_fluctuation(potentials, numpy=True, method='loop'):
    r"""Measure neuronal synchrony from membrane-potential variances.

    The method comes from [1]_ [2]_ [3]_.

    First, average the membrane potential over the population:

    .. math::

        V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t)

    The variance of the time fluctuations of :math:`V(t)` is

    .. math::

        \sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t -
        \left[ \left\langle V(t) \right\rangle_t \right]^2

    where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots`
    denotes time-averaging over a large time, :math:`T_m`. After
    normalization of :math:`\sigma_V` to the average over the population of
    the single cell membrane potentials

    .. math::

        \sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t -
        \left[ \left\langle V_i(t) \right\rangle_t \right]^2

    one defines a synchrony measure, :math:`\chi (N)`, for the activity of a
    system of :math:`N` neurons by:

    .. math::

        \chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N
        \sigma_{V_i}^2}

    .. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
           inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
    .. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
           inhibitory neurons. Physica D 72:259-282.
    .. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.

    Args:
      potentials: The membrane potential matrix of the neuron group. The
        population average is taken over axis 1 and the per-neuron variances
        are mapped over axis 1, so axis 0 plays the role of time.
      numpy: Whether we use numpy array as the functional output. If False,
        this function can be JIT compiled.
      method: The method to calculate the per-neuron variances.
        Supports two kinds of methods: ``'loop'`` and ``'vmap'``.
        The ``'vmap'`` method will consume much more memory.

        .. versionadded:: 2.2.3.4

    Returns:
      sync_index: The synchronization index, i.e. the ratio
      :math:`\sigma_V^2 / \left(\frac{1}{N} \sum_i \sigma_{V_i}^2\right)`.
      When the mean single-cell variance is zero the value ``1.`` is returned.

    Raises:
      UnsupportedError: If ``method`` is neither ``'loop'`` nor ``'vmap'``.
    """
    potentials = bm.as_jax(potentials)

    # Variance over time of the population-averaged potential.
    population_trace = jnp.mean(potentials, axis=1)
    population_var = _f_signal(population_trace)

    # Variance over time of every single neuron's potential.
    if method == 'vmap':
        neuron_vars = vmap(_f_signal, in_axes=1)(potentials)
    elif method == 'loop':
        neuron_vars = bm.for_loop(_f_signal, operands=jnp.moveaxis(potentials, 0, 1))
    else:
        raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')

    mean_neuron_var = jnp.mean(neuron_vars)
    # A zero denominator is mapped to a synchrony of 1.
    sync = jnp.where(mean_neuron_var == 0., 1., population_var / mean_neuron_var)
    if numpy:
        return bm.as_numpy(sync)
    return sync

[docs]
def matrix_correlation(x, y, numpy=True):
    """Pearson correlation between the off-diagonal triangles of two matrices.

    From each matrix the entries strictly above the main diagonal are taken
    (``triu_indices_from`` with ``k=1``, so the diagonal itself is ignored),
    and the correlation coefficient of the two resulting vectors is returned.

    Parameters
    ----------
    x: ndarray
        First matrix.
    y: ndarray
        Second matrix.
    numpy: bool
        Whether we use numpy array as the functional output.
        If False, this function can be JIT compiled.

    Returns
    -------
    coef: ndarray
        Correlation coefficient.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is not a 2D array.
    """
    if numpy:
        xp = onp
        x = bm.as_numpy(x)
        y = bm.as_numpy(y)
    else:
        xp = jnp
        x = bm.as_device_array(x)
        y = bm.as_device_array(y)

    if x.ndim != 2:
        raise ValueError(f'Only support 2d array, but we got a array '
                         f'with the shape of {x.shape}')
    if y.ndim != 2:
        raise ValueError(f'Only support 2d array, but we got a array '
                         f'with the shape of {y.shape}')

    # Flatten the entries above the diagonal of each matrix into a vector.
    x_offdiag = x[xp.triu_indices_from(x, k=1)]
    y_offdiag = y[xp.triu_indices_from(y, k=1)]

    # corrcoef yields a 2x2 matrix; the off-diagonal entry is corr(x, y).
    return xp.corrcoef(x_offdiag, y_offdiag)[0, 1]

[docs]
def functional_connectivity(activities, numpy=True):
    """Correlation-based functional connectivity of time series.

    The correlation-coefficient matrix of the transposed input is computed
    and then passed through ``nan_to_num``.

    Parameters
    ----------
    activities: ndarray
        The multidimensional array with the shape of (num_time, num_sample).
    numpy: bool
        Whether we use numpy array as the functional output.
        If False, this function can be JIT compiled.

    Returns
    -------
    connectivity_matrix: ndarray
        num_sample x num_sample functional connectivity matrix.

    Raises
    ------
    ValueError
        If ``activities`` is not a 2D array.
    """
    if numpy:
        xp = onp
        activities = bm.as_numpy(activities)
    else:
        xp = jnp
        activities = bm.as_device_array(activities)

    if activities.ndim != 2:
        raise ValueError('Only support 2d array with shape of "(num_time, num_sample)". '
                         f'But we got a array with the shape of {activities.shape}')

    # Transpose so that each row handed to corrcoef is one sample's series.
    correlations = xp.corrcoef(activities.T)
    return xp.nan_to_num(correlations)

def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
    """Placeholder for the functional connectivity dynamics (FCD) matrix.

    The body is not implemented yet: the arguments are ignored and ``None``
    is returned. The parameters below describe the intended interface.

    Parameters
    ----------
    activities: ndarray
        The time series with shape of (num_time, num_sample).
    window_size: int
        Size of each rolling window in time steps, defaults to 30.
    step_size: int
        Step size between each rolling window, defaults to 5.

    Returns
    -------
    fcd_matrix: ndarray
        FCD matrix (intended result; currently always ``None``).
    """
    return None

[docs]
def weighted_correlation(x, y, w, numpy=True):
    """Weighted Pearson correlation of two data series.

    The result is ``cov_w(x, y) / sqrt(cov_w(x, x) * cov_w(y, y))``, where
    ``cov_w`` is the covariance around the ``w``-weighted means, normalised
    by ``sum(w)``.

    Parameters
    ----------
    x: ndarray
        The data series 1.
    y: ndarray
        The data series 2.
    w: ndarray
        Weight vector, must have same length as x and y.
    numpy: bool
        Whether we use numpy array as the functional output.
        If False, this function can be JIT compiled.

    Returns
    -------
    corr: ndarray
        Weighted correlation coefficient.

    Raises
    ------
    ValueError
        If ``x``, ``y`` or ``w`` is not a 1D array.
    """
    if numpy:
        xp = onp
        x = bm.as_numpy(x)
        y = bm.as_numpy(y)
        w = bm.as_numpy(w)
    else:
        xp = jnp
        x = bm.as_device_array(x)
        y = bm.as_device_array(y)
        w = bm.as_device_array(w)

    if x.ndim != 1:
        raise ValueError(f'Only support 1d array, but we got a array '
                         f'with the shape of {x.shape}')
    if y.ndim != 1:
        raise ValueError(f'Only support 1d array, but we got a array '
                         f'with the shape of {y.shape}')
    if w.ndim != 1:
        raise ValueError(f'Only support 1d array, but we got a array '
                         f'with the shape of {w.shape}')

    def _wmean(series):
        # Mean of ``series`` under the weights ``w``.
        return xp.sum(series * w) / xp.sum(w)

    def _wcov(a, b):
        # Covariance of ``a`` and ``b`` around their weighted means.
        return xp.sum(w * (a - _wmean(a)) * (b - _wmean(b))) / xp.sum(w)

    cov_xy = _wcov(x, y)
    var_x = _wcov(x, x)
    var_y = _wcov(y, y)
    return cov_xy / xp.sqrt(var_x * var_y)