# -*- coding: utf-8 -*-
from functools import partial

import numpy as np
from jax import vmap, jit, numpy as jnp

from brainpy import math as bm
from brainpy.tools import to_size, size2num
from .base import _IntraLayerInitializer

__all__ = [
  'GaussianDecay',
  'DOGDecay',
]


@jit
@partial(vmap, in_axes=(0, None, None))
def gaussian_decay_dist_cal1(i_value, post_values, sigma):
  # Gaussian decay of the Euclidean distance between the encoded value of one
  # presynaptic neuron (``i_value``) and the encoded values of all postsynaptic
  # neurons (``post_values``), assuming an open (non-periodic) boundary.
  dists = jnp.abs(i_value - post_values)
  exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
  return bm.asarray(exp_dists)


@jit
@partial(vmap, in_axes=(0, None, None, None))
def gaussian_decay_dist_cal2(i_value, post_values, value_sizes, sigma):
  # Same as ``gaussian_decay_dist_cal1``, but distances are wrapped around the
  # value ranges (``value_sizes``) to implement a periodic boundary.
  dists = jnp.abs(i_value - post_values)
  dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists)
  exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
  return bm.asarray(exp_dists)
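

# Example sketch (hypothetical helper, not called by the library): how the vmapped
# kernels above can be invoked directly. The 1-D grid of five neurons, the sigma
# value, and the helper name ``_example_gaussian_kernels`` are arbitrary
# assumptions for illustration; the shapes mirror how ``GaussianDecay`` calls them.
def _example_gaussian_kernels():
  n = 5
  # encoded values of all postsynaptic neurons: shape (n_dims, n_post)
  post_values = jnp.arange(n, dtype=jnp.float32).reshape(1, n)
  # one (n_dims, 1) block of encoded values per presynaptic neuron
  i_values = jnp.arange(n, dtype=jnp.float32).reshape(n, 1, 1)
  # (n, n) decay matrix with an open boundary
  open_mat = gaussian_decay_dist_cal1(i_values, post_values, 1.0)
  # (n, n) decay matrix with a periodic boundary of size n along the single axis
  wrap_mat = gaussian_decay_dist_cal2(i_values, post_values, jnp.array([[float(n)]]), 1.0)
  return open_mat, wrap_mat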


class GaussianDecay(_IntraLayerInitializer):
  r"""Builds a Gaussian connectivity pattern within a population of neurons,
  where the weights decay with a Gaussian function.

  Specifically, for any pair of neurons :math:`(i, j)`, the weight is computed as

  .. math::

     w(i, j) = w_{max} \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma^2})

  where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`.

  Parameters
  ----------
  sigma : float
    Width of the Gaussian function.
  max_w : float
    The weight amplitude of the Gaussian function.
  min_w : float, None
    The minimum weight value below which synapses are not created
    (default: :math:`0.005 * max\_w`).
  include_self : bool
    Whether to create the connection at the same position (self-connection).
  encoding_values : optional, list, tuple, int, float
    The value ranges to encode for neurons at each axis.

    - If `values` is not provided, the neuron only encodes positional
      information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is the
      index in the high-dimensional space.
    - If `values` is a single tuple/list of int/float, neurons at each dimension
      will encode the same range of values. For example, with ``values=(0, np.pi)``,
      neurons at each dimension will encode a continuous value space ``[0, np.pi]``.
    - If `values` is a tuple/list of list/tuple, the value space will be
      different for each dimension. For example,
      ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``.
  periodic_boundary : bool
    Whether the neurons encode the value space with a periodic boundary.
  normalize : bool
    Whether to normalize the connection probability.
  """

  def __init__(self, sigma, max_w, min_w=None, encoding_values=None,
               periodic_boundary=False, include_self=True, normalize=False):
    super(GaussianDecay, self).__init__()
    self.sigma = sigma
    self.max_w = max_w
    self.min_w = max_w * 0.005 if min_w is None else min_w
    self.encoding_values = encoding_values
    self.periodic_boundary = periodic_boundary
    self.include_self = include_self
    self.normalize = normalize

  def __call__(self, shape, dtype=None):
    """Build the weights.

    Parameters
    ----------
    shape : tuple of int, list of int, int
      The network shape. Note, this is not the weight shape.
    """
    shape = to_size(shape)
    net_size = size2num(shape)

    # value ranges to encode
    if self.encoding_values is None:
      value_ranges = tuple([(0, s) for s in shape])
    elif isinstance(self.encoding_values, (tuple, list)):
      if len(self.encoding_values) == 0:
        raise ValueError
      elif isinstance(self.encoding_values[0], (int, float)):
        assert len(self.encoding_values) == 2
        assert self.encoding_values[0] < self.encoding_values[1]
        value_ranges = tuple([self.encoding_values for _ in shape])
      elif isinstance(self.encoding_values[0], (tuple, list)):
        if len(self.encoding_values) != len(shape):
          raise ValueError(f'The network size has {len(shape)} dimensions, while '
                           f'the encoded values provided only has {len(self.encoding_values)}-D. \n'
                           f'Error in {str(self)}.')
        for v in self.encoding_values:
          assert isinstance(v[0], (int, float))
          assert len(v) == 2
        value_ranges = tuple(self.encoding_values)
      else:
        raise ValueError(f'Unsupported encoding values: {self.encoding_values}')
    else:
      raise ValueError(f'Unsupported encoding values: {self.encoding_values}')

    # values
    values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, shape)]
    post_values = np.stack([v.flatten() for v in np.meshgrid(*values)])
    value_sizes = np.array([v[1] - v[0] for v in value_ranges])
    if value_sizes.ndim < post_values.ndim:
      value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))

    # connectivity matrix
    i_value_list = np.zeros(shape=(net_size, len(shape), 1))
    for i in range(net_size):
      list_index = i
      # values for node i
      i_coordinate = tuple()
      for s in shape[:-1]:
        i, pos = divmod(i, s)
        i_coordinate += (pos,)
      i_coordinate += (i,)
      i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
      if i_value.ndim < post_values.ndim:
        i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
      i_value_list[list_index] = i_value

    if self.periodic_boundary:
      conn_mat = gaussian_decay_dist_cal2(i_value_list, post_values, value_sizes, self.sigma)
    else:
      conn_mat = gaussian_decay_dist_cal1(i_value_list, post_values, self.sigma)

    if self.normalize:
      conn_mat /= conn_mat.max()
    if not self.include_self:
      bm.fill_diagonal(conn_mat, 0.)

    # connectivity weights
    conn_mat *= self.max_w
    conn_mat = bm.where(conn_mat < self.min_w, 0., conn_mat)

    return bm.asarray(conn_mat, dtype=dtype)

  def __repr__(self):
    name = self.__class__.__name__
    bank = ' ' * len(name)
    return (f'{name}(sigma={self.sigma}, max_w={self.max_w}, min_w={self.min_w}, \n'
            f'{bank}periodic_boundary={self.periodic_boundary}, '
            f'include_self={self.include_self}, '
            f'normalize={self.normalize})')
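

# Example sketch (hypothetical helper, not called by the library): typical usage of
# ``GaussianDecay``. The 10 x 10 grid, sigma and max_w values, and the helper name
# ``_example_gaussian_decay`` are arbitrary assumptions for illustration.
def _example_gaussian_decay():
  init = GaussianDecay(sigma=4., max_w=0.1, periodic_boundary=True)
  # 10 x 10 grid of neurons -> a (100, 100) all-to-all weight matrix
  weights = init((10, 10))
  return weights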


@jit
@partial(vmap, in_axes=(0, None, None, None, None, None, None, None))
def _dog_decay_pd(voxel_ids, values, post_values, value_sizes, max_w_p, sigma_p, max_w_n, sigma_n):
  i_value = []
  for i in range(len(voxel_ids)):
    p_id = voxel_ids[i]  # position id
    i_value.append(values[i][p_id])
  i_value = bm.array(i_value)
  if i_value.ndim < post_values.ndim:
    i_value = bm.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
  # distances
  dists = bm.abs(i_value - post_values)
  dists = bm.where(dists > value_sizes / 2, value_sizes - dists, dists)
  dists_exp_p = max_w_p * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_p) ** 2 / 2)
  dists_exp_n = max_w_n * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_n) ** 2 / 2)
  return dists_exp_p - dists_exp_n


@jit
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def _dog_decay(voxel_ids, values, post_values, max_w_p, sigma_p, max_w_n, sigma_n):
  i_value = []
  for i in range(len(voxel_ids)):
    p_id = voxel_ids[i]  # position id
    i_value.append(values[i][p_id])
  i_value = bm.array(i_value)
  if i_value.ndim < post_values.ndim:
    i_value = bm.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
  # distances
  dists = bm.abs(i_value - post_values)
  dists_exp_p = max_w_p * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_p) ** 2 / 2)
  dists_exp_n = max_w_n * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_n) ** 2 / 2)
  return dists_exp_p - dists_exp_n


class DOGDecay(_IntraLayerInitializer):
  r"""Builds a Difference-Of-Gaussian (DOG) connectivity pattern within a
  population of neurons.

  Mathematically, for the given pair of neurons :math:`(i, j)`, the weight
  between them is computed as

  .. math::

     w(i, j) = w_{max}^+ \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_+^2})
               - w_{max}^- \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_-^2})

  Weights whose magnitude is smaller than ``min_w`` (by default
  :math:`0.005 * \min(w_{max}^+, w_{max}^-)`) are not created, and
  self-connections can be excluded through the parameter ``include_self``.

  Parameters
  ----------
  sigmas : tuple
    Widths of the positive and negative Gaussian functions.
  max_ws : tuple
    The weight amplitudes of the positive and negative Gaussian functions.
  min_w : float, None
    The minimum weight value below which synapses are not created
    (default: :math:`0.005 * \min(max\_ws)`).
  include_self : bool
    Whether to create the connections at the same position (self-connections).
  normalize : bool
    Whether to normalize the connection probability.
  encoding_values : optional, list, tuple, int, float
    The value ranges to encode for neurons at each axis.

    - If `values` is not provided, the neuron only encodes positional
      information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is the
      index in the high-dimensional space.
    - If `values` is a single tuple/list of int/float, neurons at each dimension
      will encode the same range of values. For example, with ``values=(0, np.pi)``,
      neurons at each dimension will encode a continuous value space ``[0, np.pi]``.
    - If `values` is a tuple/list of list/tuple, the value space will be
      different for each dimension. For example,
      ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``.
  periodic_boundary : bool
    Whether the neurons encode the value space with a periodic boundary.
  """

  def __init__(self, sigmas, max_ws, min_w=None, encoding_values=None,
               periodic_boundary=False, normalize=True, include_self=True):
    super(DOGDecay, self).__init__()
    self.sigma_p, self.sigma_n = sigmas
    self.max_w_p, self.max_w_n = max_ws
    self.min_w = 0.005 * min(self.max_w_p, self.max_w_n) if min_w is None else min_w
    self.normalize = normalize
    self.include_self = include_self
    self.encoding_values = encoding_values
    self.periodic_boundary = periodic_boundary

  def __call__(self, shape, dtype=None):
    """Build the weights.

    Parameters
    ----------
    shape : tuple of int, list of int, int
      The network shape. Note, this is not the weight shape.
    """
    shape = to_size(shape)

    # value ranges to encode
    if self.encoding_values is None:
      value_ranges = tuple([(0, s) for s in shape])
    elif isinstance(self.encoding_values, (tuple, list)):
      if len(self.encoding_values) == 0:
        raise ValueError
      elif isinstance(self.encoding_values[0], (int, float)):
        assert len(self.encoding_values) == 2
        assert self.encoding_values[0] < self.encoding_values[1]
        value_ranges = tuple([self.encoding_values for _ in shape])
      elif isinstance(self.encoding_values[0], (tuple, list)):
        if len(self.encoding_values) != len(shape):
          raise ValueError(f'The network size has {len(shape)} dimensions, while '
                           f'the encoded values provided only has {len(self.encoding_values)}-D. \n'
                           f'Error in {str(self)}.')
        for v in self.encoding_values:
          assert isinstance(v[0], (int, float))
          assert len(v) == 2
        value_ranges = tuple(self.encoding_values)
      else:
        raise ValueError(f'Unsupported encoding values: {self.encoding_values}')
    else:
      raise ValueError(f'Unsupported encoding values: {self.encoding_values}')

    # values
    values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, shape)]
    post_values = np.stack([v.flatten() for v in np.meshgrid(*values)])
    value_sizes = np.array([v[1] - v[0] for v in value_ranges])
    if value_sizes.ndim < post_values.ndim:
      value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))

    voxel_ids = np.meshgrid(*[np.arange(s) for s in shape])
    if np.ndim(voxel_ids[0]) > 1:
      voxel_ids = tuple(np.moveaxis(m, 0, 1).flatten() for m in voxel_ids)

    # connectivity matrix
    if self.periodic_boundary:
      conn_weights = _dog_decay_pd(voxel_ids, values, post_values, value_sizes,
                                   self.max_w_p, self.sigma_p, self.max_w_n, self.sigma_n)
    else:
      conn_weights = _dog_decay(voxel_ids, values, post_values,
                                self.max_w_p, self.sigma_p, self.max_w_n, self.sigma_n)

    if not self.include_self:
      conn_weights = bm.asarray(conn_weights)
      bm.fill_diagonal(conn_weights, 0.)

    # connectivity weights
    conn_weights = bm.where(np.abs(conn_weights) < self.min_w, 0., conn_weights)

    return bm.asarray(conn_weights, dtype=dtype)

  def __repr__(self):
    name = self.__class__.__name__
    bank = ' ' * len(name)
    return (f'{name}(sigmas={(self.sigma_p, self.sigma_n)}, '
            f'max_ws={(self.max_w_p, self.max_w_n)}, min_w={self.min_w}, \n'
            f'{bank}periodic_boundary={self.periodic_boundary}, '
            f'include_self={self.include_self}, '
            f'normalize={self.normalize})')
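

# Example sketch (hypothetical helper, not called by the library): typical usage of
# ``DOGDecay``. The 8 x 8 grid, the sigma/weight tuples, and the helper name
# ``_example_dog_decay`` are arbitrary assumptions for illustration.
def _example_dog_decay():
  init = DOGDecay(sigmas=(1., 3.), max_ws=(1.0, 0.4), periodic_boundary=True)
  # 8 x 8 grid of neurons -> a (64, 64) center-surround weight matrix
  weights = init((8, 8))
  return weights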