# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Dict, Optional
import jax
import numpy as np
from brainpy import math as bm
from brainpy.connect import TwoEndConnector, MatConn, IJConn
from brainpy.dynsys import Projection, DynamicalSystem
from brainpy.types import ArrayType
# Names exported by ``from <this module> import *``.
__all__ = ['SynConn']
class SynConn(Projection):
    """Base class to model two-end synaptic connections.

    Parameters
    ----------
    pre : DynamicalSystem
        Pre-synaptic neuron group.
    post : DynamicalSystem
        Post-synaptic neuron group.
    conn : TwoEndConnector, ArrayType, dict, str, optional
        The connection method between pre- and post-synaptic groups:

        - ``TwoEndConnector``: called as ``conn(pre.size, post.size)`` and the
          result is stored.
        - array (``bm.Array``, ``np.ndarray`` or ``jax.Array``): must have
          shape ``(pre.num, post.num)``; wrapped in ``MatConn``.
        - ``dict``: must contain ``'i'`` and ``'j'`` keys; wrapped in
          ``IJConn``.
        - ``str``: stored unchanged.
        - ``None``: ``self.conn`` is set to ``None``.
    name : str, optional
        The name of the dynamic system.
    mode : bm.Mode, optional
        The computing mode, forwarded to the ``Projection`` constructor.

    Raises
    ------
    TypeError
        If ``pre`` or ``post`` is not a ``DynamicalSystem``.
    ValueError
        If an array ``conn`` has the wrong shape, a dict ``conn`` lacks the
        ``'i'``/``'j'`` keys, or ``conn`` is of any other unsupported type.
    """

    def __init__(
        self,
        pre: DynamicalSystem,
        post: DynamicalSystem,
        conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType], str, None] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # pre or post neuron group
        # ------------------------
        if not isinstance(pre, DynamicalSystem):
            raise TypeError('"pre" must be an instance of DynamicalSystem.')
        if not isinstance(post, DynamicalSystem):
            raise TypeError('"post" must be an instance of DynamicalSystem.')
        self.pre = pre
        self.post = post

        # connectivity
        # ------------
        if isinstance(conn, TwoEndConnector):
            # Connectors are lazy; bind them to the concrete group sizes here.
            self.conn = conn(pre.size, post.size)
        elif isinstance(conn, (bm.Array, np.ndarray, jax.Array)):
            if (pre.num, post.num) != conn.shape:
                raise ValueError(f'"conn" is provided as a matrix, and it is expected '
                                 f'to be an array with shape of (pre.num, post.num) = '
                                 f'{(pre.num, post.num)}, however we got {conn.shape}')
            self.conn = MatConn(conn_mat=conn)
        elif isinstance(conn, dict):
            if not ('i' in conn and 'j' in conn):
                raise ValueError(f'"conn" is provided as a dict, and it is expected to '
                                 f'be a dictionary with "i" and "j" specification, '
                                 f'however we got {conn}')
            self.conn = IJConn(i=conn['i'], j=conn['j'])
        elif isinstance(conn, str):
            # Strings are kept verbatim; interpretation is left to subclasses.
            self.conn = conn
        elif conn is None:
            self.conn = None
        else:
            raise ValueError(f'Unknown "conn" type: {conn}')

    def __repr__(self):
        names = self.__class__.__name__
        # Align continuation lines under the opening parenthesis.
        indent = " " * len(names)
        return (f'{names}(name={self.name}, mode={self.mode}, \n'
                f'{indent} pre={self.pre}, \n'
                f'{indent} post={self.post})')

    def _check_group_attrs(self, role: str, attrs: tuple):
        """Check that ``self.<role>`` exists and exposes every name in ``attrs``.

        Parameters
        ----------
        role : str
            Either ``'pre'`` or ``'post'``.
        attrs : tuple of str
            Attribute names the group must have.

        Raises
        ------
        ValueError
            If ``__init__`` has not set ``self.<role>`` yet, or if the group is
            missing one of the attributes.
        TypeError
            If any entry of ``attrs`` is not a string.
        """
        if not hasattr(self, role):
            raise ValueError('Please call __init__ function first.')
        group = getattr(self, role)
        for attr in attrs:
            if not isinstance(attr, str):
                raise TypeError(f'Must be string. But got {attr}.')
            if not hasattr(group, attr):
                raise ValueError(f'{self} need "{role}" neuron group has attribute "{attr}".')

    def check_pre_attrs(self, *attrs):
        """Check whether pre group satisfies the requirement.

        Raises ``ValueError`` if ``self.pre`` is unset or lacks any of
        ``attrs``, and ``TypeError`` if an entry of ``attrs`` is not a string.
        """
        self._check_group_attrs('pre', attrs)

    def check_post_attrs(self, *attrs):
        """Check whether post group satisfies the requirement.

        Raises ``ValueError`` if ``self.post`` is unset or lacks any of
        ``attrs``, and ``TypeError`` if an entry of ``attrs`` is not a string.
        """
        self._check_group_attrs('post', attrs)

    def update(self, *args, **kwargs):
        """The function to specify the updating rule.

        Subclasses must override this method. Assume any dynamical system
        depends on the shared variables (``sha``), like the time variable
        ``t``, the step precision ``dt``, and the time step ``i``.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError('Must implement "update" function by subclass self.')