Source code for brainpy._src.measure.correlation

# -*- coding: utf-8 -*-


import numpy as onp
from jax import vmap, lax, numpy as jnp

from brainpy._src import math as bm
from brainpy.errors import UnsupportedError

# Names exported by ``from <this module> import *``.
__all__ = [
  'cross_correlation',
  'voltage_fluctuation',
  'matrix_correlation',
  'weighted_correlation',
  'functional_connectivity',
  # Not exported: ``functional_connectivity_dynamics`` is defined in this
  # module, but its body is still an empty stub.
  # 'functional_connectivity_dynamics',
]


def cross_correlation(spikes, bin, dt=None, numpy=True, method='loop'):
  r"""Calculate cross correlation index between neurons.

  The coherence [1]_ between two neurons i and j is measured by their
  cross-correlation of spike trains at zero time lag within a time bin
  of :math:`\Delta t = \tau`. More specifically, suppose that a long
  time interval T is divided into small bins of :math:`\Delta t` and
  that two spike trains are given by :math:`X(l)=` 0 or 1,
  :math:`Y(l)=` 0 or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`.
  Thus, we define a coherence measure for the pair as:

  .. math::

      \kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)}
      {\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}

  The population coherence measure :math:`\kappa(\tau)` is defined by the
  average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the
  network.

  .. note::
     To JIT compile this function, users should make ``bin``, ``dt``,
     ``numpy`` static. For example,
     ``partial(brainpy.measure.cross_correlation, bin=10, numpy=False)``.

  Parameters
  ----------
  spikes : ndarray
    The history of spike states of the neuron group, a 2D array that is
    unpacked as ``(num_time_step, num_neuron)``.
  bin : float, int
    The time bin to normalize spike states. The number of time steps per
    bin is ``bin / dt``; a ratio that is an integer up to floating-point
    error (e.g. ``0.3 / 0.1``) is snapped to that integer, any other ratio
    is truncated.
  dt : float, optional
    The time precision. Defaults to ``bm.get_dt()`` when ``None``.
  numpy: bool
    Whether we use numpy array as the functional output.
    If ``False``, this function can be JIT compiled.
  method: str
    The method to calculate all pairs of cross correlation.
    Supports two kinds of methods: `loop` and `vmap`.
    `vmap` method needs much more memory.

    .. versionadded:: 2.2.3.4

  Returns
  -------
  cc_index : float
    The cross correlation value which represents the synchronization index.

  Raises
  ------
  UnsupportedError
    If ``method`` is neither ``"loop"`` nor ``"vmap"``.

  References
  ----------
  .. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic
         inhibition in a hippocampal interneuronal network model." Journal of
         neuroscience 16.20 (1996): 6402-6413.
  """
  spikes = bm.as_numpy(spikes) if numpy else bm.as_jax(spikes)
  np = onp if numpy else jnp
  dt = bm.get_dt() if dt is None else dt

  # ``bin / dt`` can land just below an integer because of binary floating
  # point (0.3 / 0.1 == 2.9999999999999996); plain ``int()`` would then drop
  # a whole time step from every bin. Snap to the nearest integer only when
  # the ratio is within floating-point error of it, otherwise keep truncating.
  ratio = bin / dt
  nearest = onp.round(ratio)
  if onp.isclose(ratio, nearest, rtol=1e-9, atol=0.):
    bin_size = int(nearest)
  else:
    bin_size = int(ratio)

  num_hist, num_neu = spikes.shape
  num_bin = int(onp.ceil(num_hist / bin_size))
  if num_bin * bin_size != num_hist:
    # Zero-pad the tail so the history splits evenly into ``num_bin`` bins.
    padding = np.zeros((num_bin * bin_size - num_hist, num_neu))
    spikes = np.append(spikes, padding, axis=0)

  # states[n, b] is 1. when neuron ``n`` spiked at least once in bin ``b``.
  states = spikes.T.reshape((num_neu, num_bin, bin_size))
  states = jnp.asarray(np.sum(states, axis=2) > 0., dtype=jnp.float_)

  # Every unordered neuron pair (i > j), i.e. the strict lower triangle.
  indices = jnp.tril_indices(num_neu, k=-1)

  def _pair_coherence(i, j):
    """Coherence of neurons ``i`` and ``j``; 0 when the normalizer is 0."""
    sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
    return lax.cond(sqrt_ij == 0.,
                    lambda _: 0.,
                    lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij,
                    None)

  if method == 'loop':
    res = bm.for_loop(_pair_coherence, operands=indices)
  elif method == 'vmap':
    res = vmap(_pair_coherence)(*indices)
  else:
    raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')

  return np.mean(np.asarray(res))
def _f_signal(signal):
  """Return the variance of ``signal`` computed as ``E[s^2] - E[s]^2``."""
  mean_of_square = jnp.mean(signal * signal)
  square_of_mean = jnp.mean(signal) ** 2
  return mean_of_square - square_of_mean
def voltage_fluctuation(potentials, numpy=True, method='loop'):
  r"""Calculate neuronal synchronization via voltage variance.

  The method comes from [1]_ [2]_ [3]_.

  First, average over the membrane potential :math:`V`

  .. math::

      V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t)

  The variance of the time fluctuations of :math:`V(t)` is

  .. math::

      \sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t -
      \left[ \left\langle V(t) \right\rangle_t \right]^2

  where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots`
  denotes time-averaging over a large time, :math:`T_m`. After normalization
  of :math:`\sigma_V` to the average over the population of the single cell
  membrane potentials

  .. math::

      \sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t -
      \left[ \left\langle V_i(t) \right\rangle_t \right]^2

  one defines a synchrony measure, :math:`\chi (N)`, for the activity of a
  system of :math:`N` neurons by:

  .. math::

      \chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N
      \sigma_{V_i}^2}

  .. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
         inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
  .. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
         inhibitory neurons. Physica D 72:259-282.
  .. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.

  Args:
    potentials: The membrane potential matrix of the neuron group. The
      population average is taken over axis 1 and the per-neuron variance
      is computed along axis 0 for each index of axis 1.
    numpy: Whether we use numpy array as the functional output.
      If ``False``, this function can be JIT compiled.
    method: How the per-neuron variances are computed.
      Supports two kinds of methods: `loop` and `vmap`.
      `vmap` method will consume much more memory.

      .. versionadded:: 2.2.3.4

  Returns:
    sync_index: The synchronization index. It is ``1.`` when the mean of the
      per-neuron variances is exactly zero.

  Raises:
    UnsupportedError: If ``method`` is neither ``"loop"`` nor ``"vmap"``.
  """
  potentials = bm.as_jax(potentials)

  # Variance over time of the population-averaged potential V(t).
  population_signal = jnp.mean(potentials, axis=1)
  population_var = _f_signal(population_signal)

  # Variance over time of each single-neuron potential V_i(t).
  if method == 'vmap':
    neuron_vars = vmap(_f_signal, in_axes=1)(potentials)
  elif method == 'loop':
    neuron_vars = bm.for_loop(_f_signal, operands=jnp.moveaxis(potentials, 0, 1))
  else:
    raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')

  mean_neuron_var = jnp.mean(neuron_vars)
  # Guard the normalization: a zero denominator maps to a value of 1.
  sync_index = jnp.where(mean_neuron_var == 0., 1., population_var / mean_neuron_var)
  if numpy:
    return bm.as_numpy(sync_index)
  return sync_index
def matrix_correlation(x, y, numpy=True):
  """Pearson correlation of the upper triangles of two matrices.

  The triangle is taken with an offset of ``k = 1`` so that the diagonal
  line is ignored.

  Parameters
  ----------
  x: ndarray
    First matrix.
  y: ndarray
    Second matrix.
  numpy: bool
    Whether we use numpy array as the functional output.
    If ``False``, this function can be JIT compiled.

  Returns
  -------
  coef: ndarray
    Correlation coefficient.

  Raises
  ------
  ValueError
    If ``x`` or ``y`` is not a 2D array.
  """
  if numpy:
    x = bm.as_numpy(x)
    y = bm.as_numpy(y)
    np = onp
  else:
    x = bm.as_device_array(x)
    y = bm.as_device_array(y)
    np = jnp

  for matrix in (x, y):
    if matrix.ndim != 2:
      raise ValueError(f'Only support 2d array, but we got a array '
                       f'with the shape of {matrix.shape}')

  # Flatten the strictly-upper-triangular entries of each matrix.
  upper_x = x[np.triu_indices_from(x, k=1)]
  upper_y = y[np.triu_indices_from(y, k=1)]

  # corrcoef returns the 2x2 correlation matrix; the off-diagonal is the
  # coefficient between the two flattened triangles.
  return np.corrcoef(upper_x, upper_y)[0, 1]
def functional_connectivity(activities, numpy=True):
  """Functional connectivity matrix of timeseries activities.

  Parameters
  ----------
  activities: ndarray
    The multidimensional array with the shape of ``(num_time, num_sample)``.
  numpy: bool
    Whether we use numpy array as the functional output.
    If ``False``, this function can be JIT compiled.

  Returns
  -------
  connectivity_matrix: ndarray
    ``num_sample x num_sample`` functional connectivity matrix, with NaN
    entries replaced through ``nan_to_num``.

  Raises
  ------
  ValueError
    If ``activities`` is not a 2D array.
  """
  if numpy:
    activities = bm.as_numpy(activities)
    np = onp
  else:
    activities = bm.as_device_array(activities)
    np = jnp

  if activities.ndim != 2:
    raise ValueError('Only support 2d array with shape of "(num_time, num_sample)". '
                     f'But we got a array with the shape of {activities.shape}')

  # corrcoef treats rows as variables, so put samples on the rows.
  correlation = np.corrcoef(activities.T)
  return np.nan_to_num(correlation)
def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
  """Computes functional connectivity dynamics (FCD) matrix.

  .. note::
     This is a placeholder: the body performs no computation and the
     function currently always returns ``None``.

  Parameters
  ----------
  activities: ndarray
    The time series with shape of ``(num_time, num_sample)``.
  window_size: int
    Size of each rolling window in time steps, defaults to 30.
  step_size: int
    Step size between each rolling window, defaults to 5.

  Returns
  -------
  fcd_matrix: ndarray
    FCD matrix (intended result; ``None`` is returned for now).
  """
  return None
def weighted_correlation(x, y, w, numpy=True):
  """Weighted Pearson correlation of two data series.

  Parameters
  ----------
  x: ndarray
    The data series 1.
  y: ndarray
    The data series 2.
  w: ndarray
    Weight vector, must have same length as x and y.
  numpy: bool
    Whether we use numpy array as the functional output.
    If ``False``, this function can be JIT compiled.

  Returns
  -------
  corr: ndarray
    Weighted correlation coefficient.

  Raises
  ------
  ValueError
    If ``x``, ``y`` or ``w`` is not a 1D array.
  """
  if numpy:
    x = bm.as_numpy(x)
    y = bm.as_numpy(y)
    w = bm.as_numpy(w)
    np = onp
  else:
    x = bm.as_device_array(x)
    y = bm.as_device_array(y)
    w = bm.as_device_array(w)
    np = jnp

  for series in (x, y, w):
    if series.ndim != 1:
      raise ValueError(f'Only support 1d array, but we got a array '
                       f'with the shape of {series.shape}')

  weight_sum = np.sum(w)

  # Deviations of each series from its weighted mean.
  x_centered = x - np.sum(x * w) / weight_sum
  y_centered = y - np.sum(y * w) / weight_sum

  # Weighted covariance and the two weighted variances.
  cov_xy = np.sum(w * x_centered * y_centered) / weight_sum
  var_x = np.sum(w * x_centered * x_centered) / weight_sum
  var_y = np.sum(w * y_centered * y_centered) / weight_sum

  return cov_xy / np.sqrt(var_x * var_y)