# -*- coding: utf-8 -*-
import abc
from typing import Union, List, Tuple
import jax.numpy as jnp
import numpy as onp
from brainpy import tools, math as bm
from brainpy.errors import ConnectorError
import textwrap
# Public names exported by ``from <this module> import *``.
__all__ = [
    # names of the requestable connection (synaptic) structures
    'CONN_MAT',
    'PRE_IDS', 'POST_IDS',
    'PRE2POST', 'POST2PRE',
    'PRE2SYN', 'POST2SYN',
    'SUPPORTED_SYN_STRUCTURE',
    # default dtypes of the connection data, with their setter / getter
    'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE', 'get_idx_type',
    # connector base classes
    'Connector', 'TwoEndConnector', 'OneEndConnector',
    # conversions between dense matrix, COO, CSR and CSC representations
    'mat2coo', 'mat2csc', 'mat2csr',
    'csr2csc', 'csr2mat', 'csr2coo',
    'coo2csr', 'coo2csc', 'coo2mat',
    'coo2mat_num', 'mat2mat_num',
    # visualization helper
    'visualizeMat',
]
# String keys naming the synaptic structures a connector can be asked for.
CONN_MAT = 'conn_mat'  # dense connection matrix
PRE_IDS = 'pre_ids'  # pre-synaptic index of each synapse
POST_IDS = 'post_ids'  # post-synaptic index of each synapse
PRE2POST = 'pre2post'  # (indices, indptr) pair; filled with the same data as ``CSR``
POST2PRE = 'post2pre'  # (indices, indptr) pair; filled with the same data as ``CSC``
PRE2SYN = 'pre2syn'  # (synapse ids, indptr) pair indexed by pre-synaptic neuron
POST2SYN = 'post2syn'  # (synapse ids, indptr) pair indexed by post-synaptic neuron
# NOTE(review): the two slice keys pass validation, but no code visible here
# ever produces them - confirm whether they are still meant to be supported.
PRE_SLICE = 'pre_slice'
POST_SLICE = 'post_slice'
COO = 'coo'  # (pre_ids, post_ids) pair
CSR = 'csr'  # (indices, indptr) pair, rows are pre-synaptic neurons
CSC = 'csc'  # (indices, indptr) pair, rows are post-synaptic neurons
# Every key accepted when validating a request for connection data.
SUPPORTED_SYN_STRUCTURE = [CONN_MAT,
                           PRE_IDS, POST_IDS,
                           PRE2POST, POST2PRE,
                           PRE2SYN, POST2SYN,
                           PRE_SLICE, POST_SLICE,
                           COO, CSR, CSC]
# Default dtypes of the dense matrix and of index arrays; both can be replaced
# at runtime through ``set_default_dtype``.
MAT_DTYPE = jnp.bool_
IDX_DTYPE = jnp.int32
def get_idx_type():
    """Return the dtype currently used for connection index arrays.

    The module-level ``IDX_DTYPE`` is looked up at call time, so a value
    installed later through ``set_default_dtype`` is seen by later calls.
    """
    current_idx_dtype = IDX_DTYPE
    return current_idx_dtype
[docs]
def set_default_dtype(mat_dtype=None, idx_dtype=None):
    """Set the default dtypes of the connection data.

    With this function you can change the default dtype of the connection
    matrix and of the connection indices. An argument left as ``None``
    keeps the corresponding current default.

    For example:

    >>> import numpy as np
    >>> import brainpy as bp
    >>>
    >>> conn = bp.conn.GridFour()(4, 4)
    >>> conn.require('conn_mat')
    Array([[False,  True, False, False],
           [ True, False,  True, False],
           [False,  True, False,  True],
           [False, False,  True, False]], dtype=bool)
    >>> bp.connect.set_default_dtype(mat_dtype=np.float32)
    >>> conn = bp.conn.GridFour()(4, 4)
    >>> conn.require('conn_mat')
    Array([[0., 1., 0., 0.],
           [1., 0., 1., 0.],
           [0., 1., 0., 1.],
           [0., 0., 1., 0.]], dtype=float32)

    Parameters
    ----------
    mat_dtype : type
        The default dtype for connection matrix.
    idx_dtype : type
        The default dtype for connection index.
    """
    global MAT_DTYPE, IDX_DTYPE
    # Each default is only overwritten when a replacement was actually given.
    if mat_dtype is not None:
        MAT_DTYPE = mat_dtype
    if idx_dtype is not None:
        IDX_DTYPE = idx_dtype
[docs]
class Connector(abc.ABC):
    """Base Synaptic Connector Class.

    The class declares no members of its own; it is the common ancestor
    of the connector classes.
    """
[docs]
class TwoEndConnector(Connector):
    """Synaptic connector to build connections between two neuron groups.

    If users want to customize their `Connector`, there are two ways:

    1. Implementing ``build_conn(self)`` function, which returns one of
       the connection data ``csr`` (CSR sparse data, a tuple of <post_ids, indptr>),
       ``coo`` (COO sparse data, a tuple of <pre_ids, post_ids>), or ``mat``
       (a binary connection matrix). For instance,

       .. code-block:: python

          import brainpy as bp
          class MyConnector(bp.conn.TwoEndConnector):
            def build_conn(self):
              return dict(csr=..., mat=..., coo=...)

    2. Implementing functions ``build_mat()``, ``build_csr()``, and
       ``build_coo()``. Users can provide all three functions, or one of them.

       .. code-block:: python

          import brainpy as bp
          class MyConnector(bp.conn.TwoEndConnector):
            def build_mat(self, ):
              return conn_matrix
            def build_csr(self, ):
              return post_ids, indptr
            def build_coo(self, ):
              return pre_ids, post_ids

    """

    def __init__(
        self,
        pre: Union[int, Tuple[int, ...]] = None,
        post: Union[int, Tuple[int, ...]] = None,
    ):
        """Create the connector, optionally fixing the two group sizes.

        Parameters
        ----------
        pre : int, tuple of int, optional
            The size of the pre-synaptic group.
        post : int, tuple of int, optional
            The size of the post-synaptic group.
        """
        self.pre_size = None
        self.post_size = None
        self.pre_num = None
        self.post_num = None
        if pre is not None:
            self.pre_size = self._as_size_tuple(pre)
            self.pre_num = tools.size2num(self.pre_size)
        if post is not None:
            self.post_size = self._as_size_tuple(post)
            self.post_num = tools.size2num(self.post_size)

    @staticmethod
    def _as_size_tuple(size):
        """Turn an int or an iterable of ints into a tuple."""
        return (size,) if isinstance(size, int) else tuple(size)

    @staticmethod
    def _as_idx_pair(pair):
        """Cast the first two items of ``pair`` to JAX arrays of the index dtype."""
        return (bm.as_jax(pair[0], dtype=get_idx_type()),
                bm.as_jax(pair[1], dtype=get_idx_type()))

    def __repr__(self):
        return self.__class__.__name__

    def __call__(self, pre_size, post_size):
        """Create the concrete connections between two end objects.

        Parameters
        ----------
        pre_size : int, tuple of int, list of int
            The size of the pre-synaptic group.
        post_size : int, tuple of int, list of int
            The size of the post-synaptic group.

        Returns
        -------
        conn : TwoEndConnector
            Return the self.
        """
        # Normalize both sizes before touching any attribute, so a bad
        # ``post_size`` does not leave the connector half updated.
        pre_size = self._as_size_tuple(pre_size)
        post_size = self._as_size_tuple(post_size)
        self.pre_size, self.post_size = pre_size, post_size
        self.pre_num = tools.size2num(self.pre_size)
        self.post_num = tools.size2num(self.post_size)
        return self

    def _reset_conn(self, pre_size, post_size):
        """Reset connection attributes.

        Parameters
        ----------
        pre_size : int, tuple of int, list of int
            The size of the pre-synaptic group.
        post_size : int, tuple of int, list of int
            The size of the post-synaptic group.
        """
        self.__call__(pre_size, post_size)

    @property
    def is_version2_style(self):
        """Whether at least one of ``build_coo/csr/mat`` is user provided.

        A builder counts as not provided when it carries a truthy
        ``not_customized`` attribute.
        """
        builders = (self.build_coo, self.build_csr, self.build_mat)
        return not all(getattr(b, 'not_customized', False) for b in builders)

    def _check(self, structures: Union[Tuple, List, str]):
        """Validate the requested structure names.

        Raises
        ------
        ConnectorError
            If nothing is requested, or a name is not in
            ``SUPPORTED_SYN_STRUCTURE``.
        """
        if isinstance(structures, str):
            structures = [structures]
        if structures is None or len(structures) == 0:
            raise ConnectorError('No synaptic structure is received.')
        for n in structures:
            if n not in SUPPORTED_SYN_STRUCTURE:
                raise ConnectorError(f'Unknown synapse structure "{n}". '
                                     f'Only {SUPPORTED_SYN_STRUCTURE} is supported.')

    def _return_by_mat(self, structures, mat, all_data: dict):
        """Fill ``all_data`` with the requested structures, given a dense matrix."""
        assert mat.ndim == 2
        if (CONN_MAT in structures) and (CONN_MAT not in all_data):
            all_data[CONN_MAT] = bm.as_jax(mat, dtype=MAT_DTYPE)
        # Everything except the matrix itself is derived from the COO form.
        if any(s != CONN_MAT for s in structures):
            ij = mat2coo(mat)
            self._return_by_coo(structures, coo=ij, all_data=all_data)

    def _return_by_csr(self, structures, csr: tuple, all_data: dict):
        """Fill ``all_data`` with the requested structures, given CSR data.

        Entries already present in ``all_data`` are left untouched.
        """
        indices, indptr = csr
        # Stay in NumPy when the input is NumPy; otherwise use ``brainpy.math``.
        xp = onp if isinstance(indices, onp.ndarray) else bm
        assert self.pre_num == indptr.size - 1

        def missing(name):
            return (name in structures) and (name not in all_data)

        if missing(CONN_MAT):
            conn_mat = csr2mat((indices, indptr), self.pre_num, self.post_num)
            all_data[CONN_MAT] = bm.as_jax(conn_mat, dtype=MAT_DTYPE)

        # PRE_IDS and COO need the same expanded row ids: build them once.
        if missing(PRE_IDS) or missing(COO):
            pre_ids = xp.repeat(xp.arange(self.pre_num), xp.diff(indptr))
            pre_ids = bm.as_jax(pre_ids, dtype=get_idx_type())
            if missing(PRE_IDS):
                all_data[PRE_IDS] = pre_ids
            if missing(COO):
                all_data[COO] = (pre_ids, bm.as_jax(indices, dtype=get_idx_type()))

        if missing(POST_IDS):
            all_data[POST_IDS] = bm.as_jax(indices, dtype=get_idx_type())

        if missing(PRE2POST):
            all_data[PRE2POST] = self._as_idx_pair((indices, indptr))

        if missing(CSR):
            all_data[CSR] = self._as_idx_pair((indices, indptr))

        # POST2PRE and CSC are the same data: convert once.
        if missing(POST2PRE) or missing(CSC):
            csc = self._as_idx_pair(csr2csc((indices, indptr), self.post_num))
            if missing(POST2PRE):
                all_data[POST2PRE] = csc
            if missing(CSC):
                all_data[CSC] = csc

        if missing(PRE2SYN):
            syn_seq = xp.arange(indices.size, dtype=get_idx_type())
            all_data[PRE2SYN] = (bm.as_jax(syn_seq, dtype=get_idx_type()),
                                 bm.as_jax(indptr, dtype=get_idx_type()))

        if missing(POST2SYN):
            syn_seq = xp.arange(indices.size, dtype=get_idx_type())
            # Carry the synapse ids through the CSR -> CSC reordering.
            _, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq)
            all_data[POST2SYN] = (bm.as_jax(syn_seqc, dtype=get_idx_type()),
                                  bm.as_jax(indptrc, dtype=get_idx_type()))

    def _return_by_coo(self, structures, coo: tuple, all_data: dict):
        """Fill ``all_data`` with the requested structures, given COO data.

        Entries already present in ``all_data`` are left untouched.
        """
        pre_ids, post_ids = coo

        def missing(name):
            return (name in structures) and (name not in all_data)

        if missing(CONN_MAT):
            all_data[CONN_MAT] = bm.as_jax(coo2mat(coo, self.pre_num, self.post_num),
                                           dtype=MAT_DTYPE)

        if missing(PRE_IDS):
            all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type())

        if missing(POST_IDS):
            all_data[POST_IDS] = bm.as_jax(post_ids, dtype=get_idx_type())

        if missing(COO):
            all_data[COO] = self._as_idx_pair((pre_ids, post_ids))

        # CSC and POST2PRE are the same data: convert once.
        if missing(CSC) or missing(POST2PRE):
            csc = self._as_idx_pair(coo2csc(coo, self.post_num))
            if missing(CSC):
                all_data[CSC] = csc
            if missing(POST2PRE):
                all_data[POST2PRE] = csc

        # Anything not handled above is derived from the CSR form.
        handled_here = (CONN_MAT, PRE_IDS, POST_IDS, COO, CSC, POST2PRE)
        if any(s not in handled_here for s in structures):
            csr = coo2csr(coo, self.pre_num)
            self._return_by_csr(structures, csr=csr, all_data=all_data)

    def _make_returns(self, structures, conn_data):
        """Make the desired synaptic structures and return them.

        Parameters
        ----------
        structures : str, tuple of str, list of str
            The names of the requested structures.
        conn_data : dict, tuple
            Either a dict with any of the keys ``'csr'``, ``'mat'``, ``'coo'``
            (or ``'ij'``, used when ``'coo'`` is missing or ``None``), or a
            tuple ``(kind, data)`` with ``kind`` one of these names.

        Returns
        -------
        data
            The single requested structure, or a tuple of them in the
            requested order.

        Raises
        ------
        ConnectorError
            If ``conn_data`` has an unknown type or kind, or carries no data.
        """
        csr = None
        mat = None
        coo = None
        if isinstance(conn_data, dict):
            csr = conn_data.get('csr', None)
            mat = conn_data.get('mat', None)
            # Explicit ``None`` test: the truth value of array data is ambiguous.
            coo = conn_data.get('coo', None)
            if coo is None:
                coo = conn_data.get('ij', None)
        elif isinstance(conn_data, tuple):
            if conn_data[0] == 'csr':
                csr = conn_data[1]
            elif conn_data[0] == 'mat':
                mat = conn_data[1]
            elif conn_data[0] in ['coo', 'ij']:
                coo = conn_data[1]
            else:
                raise ConnectorError(f'Must provide one of "csr", "mat" or "coo". Got "{conn_data[0]}" instead.')
        else:
            raise ConnectorError('Unknown type')

        # checking
        if (csr is None) and (mat is None) and (coo is None):
            raise ConnectorError('Must provide one of "csr", "mat" or "coo".')
        structures = (structures,) if isinstance(structures, str) else structures
        assert isinstance(structures, (tuple, list))

        all_data = dict()

        # "csr" structure
        if csr is not None:
            if (PRE2POST in structures) and (PRE2POST not in all_data):
                all_data[PRE2POST] = self._as_idx_pair(csr)
            self._return_by_csr(structures, csr=csr, all_data=all_data)

        # "mat" structure
        if mat is not None:
            assert mat.ndim == 2
            if (CONN_MAT in structures) and (CONN_MAT not in all_data):
                all_data[CONN_MAT] = bm.as_jax(mat, dtype=MAT_DTYPE)
            self._return_by_mat(structures, mat=mat, all_data=all_data)

        # "coo" structure
        if coo is not None:
            # These guards must look at ``all_data``; testing ``structures``
            # twice made them always false.
            if (PRE_IDS in structures) and (PRE_IDS not in all_data):
                all_data[PRE_IDS] = bm.as_jax(coo[0], dtype=get_idx_type())
            if (POST_IDS in structures) and (POST_IDS not in all_data):
                all_data[POST_IDS] = bm.as_jax(coo[1], dtype=get_idx_type())
            self._return_by_coo(structures, coo=coo, all_data=all_data)

        # NOTE(review): 'pre_slice' / 'post_slice' pass ``_check`` but are not
        # produced by the code above, so requesting them ends in a KeyError
        # here - confirm the intended behavior.
        if len(structures) == 1:
            return all_data[structures[0]]
        else:
            return tuple([all_data[n] for n in structures])

    def require(self, *structures):
        """Require all the connection data needed.

        Parameters
        ----------
        *structures
            The names of the requested structures. They may be preceded by
            ``pre_size`` and ``post_size`` (any leading non-string arguments),
            which are applied through ``self(pre_size, post_size)`` first.

        Returns
        -------
        data
            The single requested structure, or a tuple of them in the
            requested order. An empty tuple when called without arguments.

        Raises
        ------
        ConnectorError
            If the group sizes are not defined, or a structure is not valid.

        Examples
        --------

        >>> import brainpy as bp
        >>> conn = bp.connect.FixedProb(0.1)
        >>> mat = conn.require(10, 20, 'conn_mat')
        >>> mat.shape
        (10, 20)
        """
        if len(structures) > 0:
            pre_size = None
            post_size = None
            if not isinstance(structures[0], str):
                pre_size = structures[0]
                structures = structures[1:]
            if len(structures) > 0:
                if not isinstance(structures[0], str):
                    post_size = structures[0]
                    structures = structures[1:]
            if pre_size is not None:
                self.__call__(pre_size, post_size)
        else:
            return tuple()

        if self.pre_num is None or self.post_num is None:
            raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
                                 f'Please use "self.require(pre_size, post_size, DATA1, DATA2, ...)" ')

        # A builder is user provided when it lacks the ``not_customized`` mark.
        _has_coo_imp = not hasattr(self.build_coo, 'not_customized')
        _has_csr_imp = not hasattr(self.build_csr, 'not_customized')
        _has_mat_imp = not hasattr(self.build_mat, 'not_customized')

        self._check(structures)

        if not (_has_coo_imp or _has_csr_imp or _has_mat_imp):
            # Old style: everything comes from ``build_conn()``.
            return self._make_returns(structures, self.build_conn())

        # Fast paths: the request maps directly onto the available builders,
        # so no conversion between formats is needed.
        if len(structures) == 1:
            if PRE2POST in structures and _has_csr_imp:
                return self._as_idx_pair(self.build_csr())
            elif CSR in structures and _has_csr_imp:
                return self._as_idx_pair(self.build_csr())
            elif CONN_MAT in structures and _has_mat_imp:
                return bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
            elif PRE_IDS in structures and _has_coo_imp:
                return bm.as_jax(self.build_coo()[0], dtype=get_idx_type())
            elif POST_IDS in structures and _has_coo_imp:
                return bm.as_jax(self.build_coo()[1], dtype=get_idx_type())
            elif COO in structures and _has_coo_imp:
                return self._as_idx_pair(self.build_coo())

        elif len(structures) == 2:
            wants_csr = CSR in structures or PRE2POST in structures
            if PRE_IDS in structures and POST_IDS in structures and _has_coo_imp:
                pre_ids, post_ids = self._as_idx_pair(self.build_coo())
                if structures[0] == PRE_IDS:
                    return pre_ids, post_ids
                else:
                    return post_ids, pre_ids

            if wants_csr and _has_csr_imp and COO in structures and _has_coo_imp:
                csr = self._as_idx_pair(self.build_csr())
                coo = self._as_idx_pair(self.build_coo())
                if structures[0] == COO:
                    return coo, csr
                else:
                    return csr, coo

            if wants_csr and _has_csr_imp and CONN_MAT in structures and _has_mat_imp:
                csr = self._as_idx_pair(self.build_csr())
                mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
                if structures[0] == CONN_MAT:
                    return mat, csr
                else:
                    return csr, mat

            if COO in structures and _has_coo_imp and CONN_MAT in structures and _has_mat_imp:
                coo = self._as_idx_pair(self.build_coo())
                mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
                if structures[0] == COO:
                    return coo, mat
                else:
                    return mat, coo

        # General path: build from one source (COO, then CSR, then matrix)
        # and derive every requested structure from it.
        conn_data = dict(csr=None, ij=None, mat=None)
        if _has_coo_imp:
            conn_data['coo'] = self.build_coo()
        elif _has_csr_imp:
            conn_data['csr'] = self.build_csr()
        else:
            # At least one builder exists, so it must be ``build_mat``.
            conn_data['mat'] = self.build_mat()
        return self._make_returns(structures, conn_data)

    def requires(self, *structures):
        """Require all the connection data needed."""
        return self.require(*structures)

    @tools.not_customized
    def build_conn(self):
        """build connections with certain data type.

        If users want to customize their connections, please provide one
        of the following functions:

        - ``build_mat()``: build a matrix binary connection matrix.
        - ``build_csr()``: build a csr sparse connection data.
        - ``build_coo()``: build a coo sparse connection data.
        - ``build_conn()``: deprecated.

        Returns
        -------
        conn: tuple, dict
            A tuple with two elements: connection type (str) and connection data.
            For example: ``return 'csr', (ind, indptr)``
            Or a dict with three elements: csr, mat and coo. For example:
            ``return dict(csr=(ind, indptr), mat=None, coo=None)``
        """
        pass

    @tools.not_customized
    def build_mat(self):
        """Build a binary matrix connection data.

        If users want to customize their connections, please provide one
        of the following functions:

        - ``build_mat()``: build a matrix binary connection matrix.
        - ``build_csr()``: build a csr sparse connection data.
        - ``build_coo()``: build a coo sparse connection data.
        - ``build_conn()``: deprecated.

        Returns
        -------
        conn: Array
            A binary matrix with the shape ``(num_pre, num_post)``.
        """
        pass

    @tools.not_customized
    def build_csr(self):
        """Build a csr sparse connection data.

        Returns
        -------
        conn: tuple
            A tuple denoting the ``(indices, indptr)``.
        """
        pass

    @tools.not_customized
    def build_coo(self):
        """Build a coo sparse connection data.

        Returns
        -------
        conn: tuple
            A tuple denoting the ``(pre_ids, post_ids)``.
        """
        pass
[docs]
class OneEndConnector(TwoEndConnector):
    """Synaptic connector to build synapse connections within a population of neurons."""

    def __init__(self, *args, **kwargs):
        super(OneEndConnector, self).__init__(*args, **kwargs)

    def __call__(self, pre_size, post_size=None):
        """Set the size of the population the connections are built within.

        Parameters
        ----------
        pre_size : int, tuple of int, list of int
            The size of the population.
        post_size : int, tuple of int, list of int, optional
            If given, it must describe the same shape as ``pre_size``.

        Returns
        -------
        conn : OneEndConnector
            Return the self.

        Raises
        ------
        ConnectorError
            If ``pre_size`` and ``post_size`` describe different shapes.
        """
        if post_size is None:
            post_size = pre_size

        # Normalize first, so equivalent spellings of one shape (``10`` and
        # ``(10,)``, a list and a tuple) compare equal.
        pre_shape = (pre_size,) if isinstance(pre_size, int) else tuple(pre_size)
        post_shape = (post_size,) if isinstance(post_size, int) else tuple(post_size)

        # Explicit check instead of ``assert``, which is removed under ``-O``.
        if pre_shape != post_shape:
            raise ConnectorError(
                f'The shape of pre-synaptic group should be the same with the post group. '
                f'But we got {pre_size} != {post_size}.')

        self.pre_size, self.post_size = pre_shape, post_shape
        self.pre_num = tools.size2num(self.pre_size)
        self.post_num = tools.size2num(self.post_size)
        return self

    def _reset_conn(self, pre_size, post_size=None):
        """Re-run ``__init__`` without arguments, then apply the new sizes."""
        self.__init__()
        self.__call__(pre_size, post_size)
[docs]
def mat2csr(dense):
    """Convert a dense matrix to CSR data ``(indices, indptr)``.

    Entries greater than zero are treated as connections. NumPy input is
    searched with NumPy; any other input is first converted to a JAX array.
    """
    if isinstance(dense, onp.ndarray):
        nonzero = onp.where(dense > 0)
    else:
        nonzero = jnp.where(bm.as_jax(dense) > 0)
    row_ids, col_ids = nonzero
    num_rows = dense.shape[0]
    return coo2csr((row_ids, col_ids), num_rows)
[docs]
def mat2coo(dense):
    """Convert a dense matrix to COO data ``(pre_ids, post_ids)``.

    Entries greater than zero are treated as connections; both index
    arrays are cast to the current index dtype.
    """
    if isinstance(dense, onp.ndarray):
        nonzero = onp.where(dense > 0)
    else:
        nonzero = jnp.where(bm.as_jax(dense) > 0)
    row_ids, col_ids = nonzero
    return (row_ids.astype(dtype=get_idx_type()),
            col_ids.astype(dtype=get_idx_type()))
[docs]
def mat2csc(dense):
    """Convert a dense matrix to CSC data ``(indices, indptr)``.

    The CSC form is obtained by running the COO -> CSR conversion with the
    roles of rows and columns swapped.
    """
    if isinstance(dense, onp.ndarray):
        nonzero = onp.where(dense > 0)
    else:
        nonzero = jnp.where(bm.as_jax(dense) > 0)
    row_ids, col_ids = nonzero
    num_cols = dense.shape[1]
    return coo2csr((col_ids, row_ids), num_cols)
[docs]
def csr2mat(csr, num_pre, num_post):
    """Convert CSR data ``(indices, indptr)`` to a dense matrix.

    The result has shape ``(num_pre, num_post)`` and the default matrix
    dtype. It is a NumPy array when ``indices`` is a NumPy array, and the
    ``.value`` of a ``brainpy.math`` array otherwise.
    """
    indices, indptr = csr
    num_rows = indptr.size - 1
    shape = (num_pre, num_post)

    if not isinstance(indices, onp.ndarray):
        dense = bm.zeros(shape, dtype=MAT_DTYPE)
        # one row id per stored entry
        row_ids = jnp.repeat(jnp.arange(num_rows), jnp.diff(indptr))
        dense[row_ids, indices] = True
        return dense.value

    dense = onp.zeros(shape, dtype=MAT_DTYPE)
    # one row id per stored entry
    row_ids = onp.repeat(onp.arange(num_rows), onp.diff(indptr))
    dense[row_ids, indices] = True
    return dense
[docs]
def csr2csc(csr, post_num, data=None):
    """Convert CSR data to CSC data by going through the COO form.

    ``data``, when given, is passed on to ``coo2csc`` together with the
    number of post-synaptic neurons.
    """
    coo = csr2coo(csr)
    return coo2csc(coo, post_num, data)
[docs]
def csr2coo(csr):
    """Convert CSR data ``(indices, indptr)`` to COO data ``(pre_ids, post_ids)``.

    The row ids are rebuilt by repeating every row number as many times as
    the row has entries; ``indices`` is returned unchanged as the post ids.
    NumPy is used when ``indices`` is a NumPy array, ``jax.numpy`` otherwise.
    """
    indices, indptr = csr
    xp = onp if isinstance(indices, onp.ndarray) else jnp
    row_numbers = xp.arange(indptr.size - 1)
    entries_per_row = xp.diff(indptr)
    return xp.repeat(row_numbers, entries_per_row), indices
[docs]
def coo2mat(ij, num_pre, num_post):
    """Convert COO data ``(pre_ids, post_ids)`` to a dense matrix.

    The result has shape ``(num_pre, num_post)`` and the default matrix
    dtype. It is a NumPy array when ``pre_ids`` is a NumPy array, and the
    ``.value`` of a ``brainpy.math`` array otherwise.
    """
    row_ids, col_ids = ij
    shape = (num_pre, num_post)

    if not isinstance(row_ids, onp.ndarray):
        dense = bm.zeros(shape, dtype=MAT_DTYPE)
        dense[row_ids, col_ids] = True
        return dense.value

    dense = onp.zeros(shape, dtype=MAT_DTYPE)
    dense[row_ids, col_ids] = True
    return dense
[docs]
def coo2csr(coo, num_pre):
    """Convert COO data ``(pre_ids, post_ids)`` to CSR data ``(indices, indptr)``.

    Parameters
    ----------
    coo : tuple
        The ``(pre_ids, post_ids)`` pair.
    num_pre : int
        The number of pre-synaptic neurons, i.e. the number of CSR rows.

    Returns
    -------
    csr : tuple
        ``(indices, indptr)``, both cast to the current index dtype.
        ``indptr`` is produced through ``numpy.insert`` in both branches.
    """
    pre_ids, post_ids = coo
    if isinstance(pre_ids, onp.ndarray):
        # A stable sort keeps the original order of the synapses in a row;
        # NumPy's default sort kind does not guarantee that.
        sort_ids = onp.argsort(pre_ids, kind='stable')
        indices = onp.asarray(post_ids)[sort_ids]
        unique_pre_ids, pre_count = onp.unique(pre_ids, return_counts=True)
        # Rows without any synapse keep a count of zero.
        final_pre_count = onp.zeros(num_pre, dtype=jnp.uint32)
        final_pre_count[unique_pre_ids] = pre_count
    else:
        pre_ids = bm.as_jax(pre_ids)
        sort_ids = onp.argsort(pre_ids, kind='stable')
        indices = bm.as_jax(post_ids)[sort_ids]
        unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True)
        final_pre_count = bm.zeros(num_pre, dtype=jnp.uint32)
        final_pre_count[unique_pre_ids] = pre_count
        final_pre_count = bm.as_jax(final_pre_count)
    # Row pointers are the running total of the row sizes, with a leading 0.
    indptr = final_pre_count.cumsum()
    indptr = onp.insert(indptr, 0, 0)
    return indices.astype(get_idx_type()), indptr.astype(get_idx_type())
[docs]
def coo2csc(coo, post_num, data=None):
    """Convert COO data ``(pre_ids, post_ids)`` to CSC data.

    Parameters
    ----------
    coo : tuple
        The ``(pre_ids, post_ids)`` pair.
    post_num : int
        The number of post-synaptic neurons, i.e. the number of CSC rows.
    data : array, optional
        Per-synapse values to reorder together with the indices.

    Returns
    -------
    csc : tuple
        ``(pre_ids_new, indptr_new)``, or
        ``(pre_ids_new, indptr_new, data_new)`` when ``data`` is given.
    """
    pre_ids, indices = coo
    if isinstance(indices, onp.ndarray):
        # A stable sort maintains the original order of the elements with the
        # same value; NumPy's default sort kind does not guarantee that.
        sort_ids = onp.argsort(indices, kind='stable')
        pre_ids_new = onp.asarray(pre_ids[sort_ids], dtype=get_idx_type())

        unique_post_ids, count = onp.unique(indices, return_counts=True)
        # Post-synaptic neurons without any synapse keep a count of zero.
        post_count = onp.zeros(post_num, dtype=get_idx_type())
        post_count[unique_post_ids] = count

        indptr_new = post_count.cumsum()
        indptr_new = onp.insert(indptr_new, 0, 0)
        indptr_new = onp.asarray(indptr_new, dtype=get_idx_type())
    else:
        pre_ids = bm.as_jax(pre_ids)
        indices = bm.as_jax(indices)

        # to maintain the original order of the elements with the same value
        sort_ids = jnp.argsort(indices)
        pre_ids_new = jnp.asarray(pre_ids[sort_ids], dtype=get_idx_type())

        unique_post_ids, count = jnp.unique(indices, return_counts=True)
        post_count = bm.zeros(post_num, dtype=get_idx_type())
        post_count[unique_post_ids] = count

        indptr_new = post_count.value.cumsum()
        indptr_new = jnp.insert(indptr_new, 0, 0)
        indptr_new = jnp.asarray(indptr_new, dtype=get_idx_type())

    if data is None:
        return pre_ids_new, indptr_new
    else:
        # Apply the same permutation to the per-synapse data.
        data_new = data[sort_ids]
        return pre_ids_new, indptr_new, data_new
def coo2mat_num(ij, num_pre, num_post, num, seed=0):
    """Convert COO data ``(pre_ids, post_ids)`` to a dense connection number matrix.

    Specific for FixedTotalNum.

    The binary matrix is built with ``coo2mat`` and handed to
    ``mat2mat_num``, which raises randomly chosen existing connections by
    one until the entries sum up to ``num``.

    Parameters
    ----------
    ij : tuple
        The ``(pre_ids, post_ids)`` pair.
    num_pre : int
        The number of pre-synaptic neurons.
    num_post : int
        The number of post-synaptic neurons.
    num : int
        The total number of connections the matrix should hold.
    seed : int
        The seed of the random generator picking the raised connections.

    Returns
    -------
    mat : jax.Array
        An ``int32`` matrix of shape ``(num_pre, num_post)``.
    """
    # ``coo2mat`` returns a NumPy array for NumPy input; the ``.at`` updates
    # used downstream only exist on JAX arrays.
    mat = jnp.asarray(coo2mat(ij, num_pre, num_post))
    return mat2mat_num(mat, num, seed)
def mat2mat_num(mat, num, seed=0):
    """Convert a boolean matrix to a dense connection number matrix.

    Specific for FixedTotalNum.

    ``num - count_nonzero(mat)`` distinct existing connections are picked
    at random, and the entry of each is raised by one.

    Parameters
    ----------
    mat : array
        The binary connection matrix.
    num : int
        The total number of connections the result should hold.
    seed : int
        The seed of the random generator picking the raised connections.

    Returns
    -------
    mat : jax.Array
        The ``int32`` connection number matrix.
    """
    rng = bm.random.RandomState(seed)
    # NumPy and brainpy arrays have no ``.at`` property: work on a JAX array.
    mat = jnp.asarray(bm.as_jax(mat))
    # get nonzero indices and number
    nonzero_idx = jnp.nonzero(mat)
    nonzero_num = jnp.count_nonzero(mat)
    # get multi connection number
    multi_conn_num = num - nonzero_num
    # alter the element type to int
    mat = mat.astype(jnp.int32)
    # randomly pick existing connections of ``mat`` and raise each by one
    index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False)
    index = bm.as_jax(index)
    # The picks are distinct (``replace=False``), so one vectorized add is
    # the same as raising the entries one at a time.
    return mat.at[nonzero_idx[0][index], nonzero_idx[1][index]].add(1)
def visualizeMat(mat, description='Untitled'):
    """Show a matrix as a heatmap. (Needs seaborn and matplotlib.)

    When either package cannot be imported, a hint is printed and nothing
    is drawn.

    Parameters
    ----------
    mat : jnp.ndarray
        The matrix to be visualized.
    description : str
        The title of the figure; it is wrapped at 60 characters.
    """
    try:
        import seaborn as sns
        import matplotlib.pyplot as plt
    except (ModuleNotFoundError, ImportError):
        print('Please install seaborn and matplotlib for this function')
        return

    sns.heatmap(mat, cmap='viridis')
    plt.title(textwrap.fill(description, width=60))
    plt.show()