# Source code for brainpy._src.dynold.synapses.gap_junction
# -*- coding: utf-8 -*-
from typing import Union, Dict, Callable
import brainpy.math as bm
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.connect import TwoEndConnector
from brainpy._src.dynold.synapses import TwoEndConn
from brainpy._src.initialize import Initializer, parameter
from brainpy.types import ArrayType
__all__ = [
'GapJunction',
]
class GapJunction(TwoEndConn):
    """Gap junction (electrical synapse) between two neuron groups.

    The junction current is proportional to the membrane-potential
    difference, ``I = g * (V_pre - V_post)``, and is applied symmetrically:
    the same current, with opposite sign, is injected back into the
    presynaptic group.

    Parameters
    ----------
    pre : NeuDyn
        The presynaptic neuron group. Must expose a ``V`` variable.
    post : NeuDyn
        The postsynaptic neuron group. Must expose ``V`` and ``input``.
    conn : TwoEndConnector, ArrayType, dict
        The connectivity between the two neuron groups.
    comp_method : str
        The computation method: ``'dense'`` uses a full connection matrix,
        ``'sparse'`` uses pre/post index pairs.
    g_max : float, ArrayType, Initializer, Callable
        The maximal conductance of each junction.
    name : str, optional
        The name of this synaptic projection.

    Raises
    ------
    ValueError
        If ``comp_method`` is neither ``'dense'`` nor ``'sparse'``.
    """

    def __init__(
        self,
        pre: NeuDyn,
        post: NeuDyn,
        conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
        comp_method: str = 'dense',
        g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
        name: str = None,
    ):
        super(GapJunction, self).__init__(pre=pre, post=post, conn=conn, name=name)

        # checking: both groups must carry membrane potentials, and the
        # postsynaptic group must accept external input currents.
        self.check_pre_attrs('V')
        self.check_post_attrs('V', 'input')
        # assert isinstance(self.output, _NullSynOut)
        # assert isinstance(self.stp, _NullSynSTP)

        # connections
        self.comp_method = comp_method
        if comp_method == 'dense':
            # conn_mat: (pre.num, post.num) mask; weights share that shape.
            self.conn_mat = self.conn.require('conn_mat')
            self.weights = parameter(g_max, (pre.num, post.num), allow_none=False)
        elif comp_method == 'sparse':
            # One weight per connected (pre, post) index pair.
            self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
            self.weights = parameter(g_max, self.pre_ids.shape, allow_none=False)
        else:
            # Fail loudly with the offending value instead of a bare ValueError.
            raise ValueError(f'comp_method should be "dense" or "sparse", '
                             f'but we got {comp_method!r}')

    def update(self):
        """Inject the gap-junction currents into both neuron groups."""
        if self.comp_method == 'dense':
            # pre -> post: pairwise voltage differences, masked and weighted.
            diff = (self.pre.V.reshape((-1, 1)) - self.post.V) * self.conn_mat * self.weights
            self.post.input += bm.einsum('ij->j', diff)
            # post -> pre: the same current flows back with opposite sign.
            self.pre.input += bm.einsum('ij->i', -diff)
        else:
            # Sparse form: differences only over connected index pairs,
            # scattered back onto each group with syn2post_sum.
            diff = (self.pre.V[self.pre_ids] - self.post.V[self.post_ids]) * self.weights
            self.post.input += bm.syn2post_sum(diff, self.post_ids, self.post.num)
            self.pre.input += bm.syn2post_sum(-diff, self.pre_ids, self.pre.num)

    def reset_state(self, batch_size=None):
        """Gap junctions hold no internal state; nothing to reset."""
        pass