# -*- coding: utf-8 -*-
from typing import Union, Tuple, List
import jax
import jax.numpy as jnp
import numpy as np
from brainpy.errors import ConnectorError
from .base import *
# Public API of this module: the connector classes and, for most of them,
# a ready-made default instance (e.g. ``one2one``, ``grid_four``).
__all__ = [
  'One2One', 'one2one',
  'All2All', 'all2all',
  'GridFour', 'grid_four',
  'GridEight', 'grid_eight',
  'GridN',
]
class One2One(TwoEndConnector):
  """Connect two neuron groups one by one.

  The i-th pre-synaptic neuron is connected to the i-th post-synaptic
  neuron, so the two neuron groups must have the same size.
  """

  def __init__(self, *args, **kwargs):
    super(One2One, self).__init__(*args, **kwargs)

  def _check_same_size(self):
    """Raise ``ConnectorError`` if the pre- and post-synaptic groups differ in size."""
    if self.pre_num != self.post_num:
      raise ConnectorError(f'One2One connection must be defined in two groups with the '
                           f'same size, but {self.pre_num} != {self.post_num}.')

  def __call__(self, pre_size, post_size):
    super(One2One, self).__call__(pre_size, post_size)
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``, which would silently skip this validation.
    self._check_same_size()
    return self

  def build_coo(self):
    """Return ``(pre_ids, post_ids)`` with ``pre_ids[i] == post_ids[i] == i``.

    Raises
    ------
    ConnectorError
      If the two groups have different sizes.
    """
    self._check_same_size()
    return (np.arange(self.pre_num, dtype=get_idx_type()),
            np.arange(self.post_num, dtype=get_idx_type()))

  def build_csr(self):
    """Return ``(indices, indptr)``: every row holds exactly one entry, its own index.

    Raises
    ------
    ConnectorError
      If the two groups have different sizes.
    """
    self._check_same_size()
    ind = np.arange(self.pre_num, dtype=get_idx_type())
    indptr = np.arange(self.pre_num + 1, dtype=get_idx_type())
    return ind, indptr

  def build_mat(self):
    """Return a ``(pre_num, post_num)`` matrix that is set only on its diagonal.

    Raises
    ------
    ConnectorError
      If the two groups have different sizes.
    """
    self._check_same_size()
    mat = np.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE)
    np.fill_diagonal(mat, True)
    return mat


one2one = One2One()
class All2All(TwoEndConnector):
  """Connect every pre-synaptic neuron to every post-synaptic neuron.

  The connection therefore holds ``num_pre x num_post`` synapses, minus the
  main diagonal when ``include_self`` is ``False``.

  Parameters
  ----------
  include_self : bool
    Whether to keep the (i, i) connections. When ``False`` the main
    diagonal of the connection matrix is cleared.
  """

  def __init__(self, *args, include_self: bool = True, **kwargs):
    self.include_self = include_self
    super(All2All, self).__init__(*args, **kwargs)

  def __repr__(self):
    cls_name = type(self).__name__
    return f'{cls_name}(include_self={self.include_self})'

  def build_mat(self):
    """Return a dense ``(pre_num, post_num)`` matrix filled with ones."""
    shape = (self.pre_num, self.post_num)
    conn = np.ones(shape, dtype=MAT_DTYPE)
    if self.include_self:
      return conn
    # drop the (i, i) entries
    np.fill_diagonal(conn, False)
    return conn


all2all = All2All(include_self=True)
def get_size_length(sizes: Union[Tuple, List]):
  """Return the row-major (C-order) element strides of a grid of shape ``sizes``.

  For ``sizes = (d0, d1, ..., dk)`` the result is ``[d1*...*dk, ..., dk, 1]``,
  so ``sum(coord * result)`` is the flat index of grid coordinate ``coord``.

  Parameters
  ----------
  sizes : tuple or list of int
    Shape of the grid.

  Returns
  -------
  np.ndarray
    Integer array holding one stride per axis (empty for an empty shape).

  Raises
  ------
  TypeError
    If ``sizes`` is not a tuple or a list.
  """
  if not isinstance(sizes, (tuple, list)):
    raise TypeError(f'"sizes" must be a tuple or a list, but got {type(sizes).__name__}.')
  # Work backwards from the last axis: each stride is the product of the
  # lengths of all the axes that follow it.
  lengths = [1] * len(sizes)
  for axis in range(len(sizes) - 2, -1, -1):
    lengths[axis] = lengths[axis + 1] * sizes[axis + 1]
  # explicit integer dtype so that an empty shape does not yield float64
  return np.asarray(lengths, dtype=int)
class GridConn(OneEndConnector):
  """Base class for grid-based connections inside a single population.

  Neurons sit on a regular grid of shape ``post_size``. Subclasses choose the
  neighbours through two hooks: ``_select_stride`` filters candidate offsets
  (used by ``build_coo``), and ``_select_dist`` filters per-axis distances
  (used by ``build_mat``).

  Parameters
  ----------
  strides : array-like
    Candidate per-axis offsets. The candidate neighbour offsets are the
    Cartesian product of ``strides`` over every grid axis.
  include_self : bool
    Whether to create the (i, i) connection.
  periodic_boundary : bool
    Whether the grid wraps around at its borders.
  """

  def __init__(
      self,
      strides,
      include_self: bool = False,
      periodic_boundary: bool = False,
      **kwargs
  ):
    super(GridConn, self).__init__(**kwargs)
    self.strides = strides
    self.include_self = include_self
    self.periodic_boundary = periodic_boundary

  def __repr__(self):
    return f'{self.__class__.__name__}(include_self={self.include_self}, periodic_boundary={self.periodic_boundary})'

  def _format(self):
    """Validate the population and list every grid coordinate.

    Returns
    -------
    sizes : jnp.ndarray
      Grid shape, one entry per axis.
    dim : int
      Number of grid axes.
    indices : jnp.ndarray
      ``(post_num, dim)`` array of neuron coordinates in ``indexing='ij'``
      (row-major) order.

    Raises
    ------
    ConnectorError
      If ``pre_num != post_num``.
    """
    dim = len(self.post_size)
    if self.pre_num != self.post_num:
      raise ConnectorError(f'{self.__class__.__name__} is used for connections within '
                           f'the same population. But we detect pre_num != post_num '
                           f'({self.pre_num} != {self.post_num}).')
    # point indices
    grids = jnp.meshgrid(*(jnp.arange(size) for size in self.post_size), indexing='ij')
    indices = jnp.asarray(grids).reshape(dim, self.post_num).T
    sizes = jnp.asarray(self.post_size)
    return sizes, dim, indices

  def _get_strides(self, dim):
    """Return the ``(num_offsets, dim)`` offsets that ``_select_stride`` keeps."""
    # every combination of ``self.strides`` over ``dim`` axes
    increments = np.asarray(np.meshgrid(*(self.strides for _ in range(dim)))).reshape(dim, -1).T
    select_ids = self._select_stride(increments)
    return jnp.asarray(increments[select_ids])

  def _select_stride(self, stride: np.ndarray) -> np.ndarray:
    """Return a boolean mask over the rows of ``stride`` (the candidate offsets)."""
    raise NotImplementedError

  def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray:
    """Return a boolean mask over the rows of ``dist`` (per-axis absolute distances)."""
    raise NotImplementedError

  def build_mat(self):
    """Return the dense ``(pre_num, post_num)`` connection matrix."""
    sizes, _, indices = self._format()

    @jax.vmap
    def f_connect(pre_id):
      # pre_id: coordinate of one neuron, shape (dim,)
      dist = jnp.abs(pre_id - indices)
      if self.periodic_boundary:
        # use the shorter way around each periodic axis
        dist = jnp.where(dist > sizes / 2, sizes - dist, dist)
      return self._select_dist(dist)

    return jnp.asarray(f_connect(indices), dtype=MAT_DTYPE)

  def build_coo(self):
    """Return ``(pre_ids, post_ids)`` as flat, row-major neuron indices."""
    sizes, dim, indices = self._format()
    offsets = self._get_strides(dim)

    @jax.vmap
    def f_connect(pre_id):
      # pre_id: coordinate of one neuron, shape (dim,)
      post_ids = pre_id + offsets
      if self.periodic_boundary:
        post_ids = post_ids % sizes
      else:
        # Coordinates past the upper border become -1. Negative ones are
        # already < 0. Both are filtered out below.
        post_ids = jnp.where(post_ids < sizes, post_ids, -1)
      size = len(post_ids)
      pre_ids = jnp.repeat(pre_id, size).reshape(dim, size).T
      return pre_ids, post_ids

    pres, posts = f_connect(indices)
    pres = pres.reshape(-1, dim)
    posts = posts.reshape(-1, dim)
    # keep only the pairs whose target lies inside the grid
    idx = jnp.nonzero(jnp.all(posts >= 0, axis=1))[0]
    pres = pres[idx]
    posts = posts[idx]
    if dim == 1:
      pres = pres.flatten()
      posts = posts.flatten()
    else:
      # turn grid coordinates into flat row-major neuron indices
      axis_strides = jnp.asarray(get_size_length(self.post_size))
      pres = jnp.sum(pres * axis_strides, axis=1)
      posts = jnp.sum(posts * axis_strides, axis=1)
    if self.periodic_boundary:
      # On small periodic grids, wrapping can send several offsets to the same
      # target (duplicate synapses) or back to the source neuron. Collapse
      # duplicates and drop self-pairs (unless allowed) so that the result
      # agrees with ``build_mat``.
      pairs = np.unique(np.stack([np.asarray(pres), np.asarray(posts)], axis=1), axis=0)
      if not self.include_self:
        pairs = pairs[pairs[:, 0] != pairs[:, 1]]
      pres, posts = pairs[:, 0], pairs[:, 1]
    return jnp.asarray(pres, dtype=get_idx_type()), jnp.asarray(posts, dtype=get_idx_type())
class GridFour(GridConn):
  """The nearest four neighbors connection method.

  A neuron is connected to the neurons whose grid coordinates differ from its
  own by exactly one step along a single axis. Away from the borders, that is
  ``2 * dim`` neighbours on a ``dim``-dimensional grid.

  Parameters
  ----------
  periodic_boundary : bool
    Whether the neurons encode the value space with the periodic boundary.

    .. versionadded:: 2.2.3.2

  include_self : bool
    Whether to create the connection at the same position.
  """

  def __init__(
      self,
      include_self: bool = False,
      periodic_boundary: bool = False,
      **kwargs
  ):
    # GridConn.__init__ stores include_self / periodic_boundary itself.
    super(GridFour, self).__init__(
      strides=np.asarray([-1, 0, 1]),
      include_self=include_self,
      periodic_boundary=periodic_boundary,
      **kwargs,
    )

  def _select_stride(self, stride: np.ndarray) -> np.ndarray:
    # L1 length of each candidate offset
    steps = np.abs(stride).sum(axis=1)
    if self.include_self:
      return steps <= 1
    return steps == 1

  def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray:
    length = jnp.linalg.norm(dist, axis=1)
    if self.include_self:
      return length <= 1
    return length == 1


grid_four = GridFour()
class GridN(GridConn):
  """The nearest (2*N+1) * (2*N+1) neighbors conn method.

  Parameters
  ----------
  N : int
    Extent of the connection scope along every axis. For example:

    When N=1,
        [x x x]
        [x I x]
        [x x x]

    When N=2,
        [x x x x x]
        [x x x x x]
        [x x I x x]
        [x x x x x]
        [x x x x x]

  include_self : bool
    Whether to create the (i, i) connection.
  periodic_boundary: bool
    Whether the neurons encode the value space with the periodic boundary.

    .. versionadded:: 2.2.3.2
  """

  def __init__(
      self,
      N: int = 1,
      include_self: bool = False,
      periodic_boundary: bool = False,
      **kwargs
  ):
    offsets = np.arange(-N, N + 1, 1)
    super(GridN, self).__init__(strides=offsets,
                                include_self=include_self,
                                periodic_boundary=periodic_boundary,
                                **kwargs)
    self.N = N

  def __repr__(self):
    fields = ', '.join([f'N={self.N}',
                        f'include_self={self.include_self}',
                        f'periodic_boundary={self.periodic_boundary}'])
    return f'{self.__class__.__name__}({fields})'

  def _select_stride(self, stride: np.ndarray) -> np.ndarray:
    if self.include_self:
      return np.ones(len(stride), dtype=bool)
    # keep every offset except the all-zero one
    return np.abs(stride).sum(axis=1) > 0

  def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray:
    within = jnp.all(dist <= self.N, axis=1)
    if self.include_self:
      return within
    is_self = jnp.all(dist == 0, axis=1)
    return within & ~is_self
class GridEight(GridN):
  """The nearest eight neighbors conn method.

  This is ``GridN`` with ``N=1``: every neuron within one step along every
  axis (the 3 x 3 block on a 2-D grid) is a neighbour.

  Parameters
  ----------
  include_self : bool
    Whether to create the (i, i) connection.
  periodic_boundary: bool
    Whether the neurons encode the value space with the periodic boundary.

    .. versionadded:: 2.2.3.2
  """

  def __init__(self,
               include_self: bool = False,
               periodic_boundary: bool = False,
               **kwargs):
    super().__init__(
      N=1,
      include_self=include_self,
      periodic_boundary=periodic_boundary,
      **kwargs,
    )


grid_eight = GridEight()