# -*- coding: utf-8 -*-
from typing import Union, Optional, Dict, Sequence, Callable
import brainpy.math as bm
from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.dyn.neurons.hh import HHTypedNeuron
from brainpy._src.mixin import Container, TreeNode, _JointGenericAlias
from brainpy.types import Shape
# Public names exported by this module.
__all__ = ['MixIons', 'mix_ions', 'Ion']
[docs]
class MixIons(IonChaDyn, Container, TreeNode):
    """Mixing Ions.

    A container that groups several ions and hosts channels whose
    ``master_type`` is a ``brainpy.mixin.JointType`` of ion types. For each
    hosted channel, the ion matching every type listed in
    ``master_type.__args__`` is looked up among ``ions`` and its packed
    information (see ``Ion.pack_info``) is passed to the channel.

    Args:
      ions: Instances of ions. This option defines the master types of all children objects.
      channels: Instance of channels.
    """

    master_type = HHTypedNeuron

    def __init__(self, *ions, **channels):
        # TODO: check "ions" should be independent from each other
        # ``ions`` is collected through ``*ions`` and is therefore always a tuple;
        # only its length and the type of its items need to be validated.
        assert len(ions) >= 2, f'{self.__class__.__name__} requires at least two ions. '
        assert all(isinstance(ion, Ion) for ion in ions), f'Must be a sequence of Ion. But got {ions}.'
        super().__init__(size=ions[0].size, keep_size=ions[0].keep_size, sharding=ions[0].sharding)

        # Attribute of "Container"
        self.children = bm.node_dict()
        self.ions: Sequence['Ion'] = tuple(ions)
        self._ion_classes = tuple(type(ion) for ion in self.ions)
        for name, channel in channels.items():
            # Expand as a real keyword argument so that the channel is registered
            # under its own name (``add_elem(k=v)`` would always use the key "k").
            self.add_elem(**{name: channel})

    def _checked_channels(self):
        """Return the direct ``IonChaDyn`` children after checking their hierarchy."""
        nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
        self.check_hierarchies(self._ion_classes, *nodes)
        return nodes

    def _pack_infos(self, node):
        """Pack the information of every ion listed in ``node.master_type.__args__``."""
        return tuple(self._get_imp(root).pack_info() for root in node.master_type.__args__)

    def update(self, V):
        """Update every hosted channel.

        Args:
          V: The membrane potential.
        """
        for node in self._checked_channels():
            node.update(V, *self._pack_infos(node))

    def current(self, V):
        """Generate ion channel current.

        Args:
          V: The membrane potential.

        Returns:
          Current. It is ``0.`` when no channel is hosted.
        """
        total = 0.
        for node in self._checked_channels():
            total = total + node.current(V, *self._pack_infos(node))
        return total

    def reset_state(self, V, batch_size=None):
        """Reset the state of every hosted channel.

        Args:
          V: The membrane potential.
          batch_size: The batch size. It is forwarded positionally, after the
            packed ion information, to each channel's ``reset_state``.
        """
        for node in self._checked_channels():
            node.reset_state(V, *self._pack_infos(node), batch_size)

    def check_hierarchy(self, roots, leaf):
        """Check that ``roots`` provide every ion type required by ``leaf``.

        Args:
          roots: The candidate master classes (here, the classes of the mixed ions).
          leaf: The child object whose ``master_type`` is checked.

        Raises:
          TypeError: If ``leaf.master_type`` is not a ``brainpy.mixin.JointType``,
            or if one of its types is not matched by any class in ``roots``.
        """
        # 'master_type' should be a brainpy.mixin.JointType
        self._check_master_type(leaf)
        for cls in leaf.master_type.__args__:
            if not any(issubclass(root, cls) for root in roots):
                raise TypeError(f'Type does not match. {leaf} requires a master with type '
                                f'of {leaf.master_type}, but the master type now is {roots}.')

    def add_elem(self, *elems, **elements):
        """Add new elements.

        Besides being stored as children of this container, each element is
        registered as an external current (keyed by the element's name) on
        every ion it depends on.

        Args:
          elems: children objects.
          elements: children objects, given by name.
        """
        self.check_hierarchies(self._ion_classes, *elems, **elements)
        self.children.update(self.format_elements(IonChaDyn, *elems, **elements))
        for elem in tuple(elems) + tuple(elements.values()):
            for ion_root in elem.master_type.__args__:
                ion = self._get_imp(ion_root)
                ion.add_external_current(elem.name, self._get_ion_fun(ion, elem))

    def _get_ion_fun(self, ion, node):
        """Build the external-current function of ``node`` as seen from ``ion``.

        The returned function packs the information of ``ion`` from the
        arguments it receives, packs the information of the other required ions
        from their own attributes, and returns ``node.current(V, *infos)``.
        """

        def fun(V, *args):
            infos = tuple(
                ion.pack_info(*args) if isinstance(ion, root) else self._get_imp(root).pack_info()
                for root in node.master_type.__args__
            )
            return node.current(V, *infos)

        return fun

    def _get_imp(self, cls):
        """Return the first ion in ``self.ions`` that is an instance of ``cls``.

        Raises:
          ValueError: If no ion is an instance of ``cls``.
        """
        for ion in self.ions:
            if isinstance(ion, cls):
                return ion
        raise ValueError(f'No instance of {cls} is found.')

    def _check_master_type(self, leaf):
        """Raise ``TypeError`` unless ``leaf.master_type`` is a ``brainpy.mixin.JointType``."""
        if not isinstance(leaf.master_type, _JointGenericAlias):
            raise TypeError(f'{self.__class__.__name__} requires leaf nodes that have the master_type of '
                            f'"brainpy.mixin.JointType". However, we got {leaf.master_type}')
[docs]
def mix_ions(*ions) -> MixIons:
    """Create mixed ions.

    Args:
      ions: Ion instances. At least two are needed, which is the minimum
        that ``MixIons`` accepts.

    Returns:
      Instance of MixIons.

    Raises:
      AssertionError: If an argument is not an ``Ion`` instance, or if fewer
        than two ions are given.
    """
    for ion in ions:
        assert isinstance(ion, Ion), f'Must be instance of {Ion.__name__}. But got {type(ion)}'
    # Fail here with an explicit message instead of an empty one.
    assert len(ions) >= 2, f'mix_ions requires at least two ions. But got {len(ions)}.'
    return MixIons(*ions)
[docs]
class Ion(IonChaDyn, Container, TreeNode):
    """The base class for the dynamics of one ion and the channels depending on it.

    Subclasses are expected to provide the concentration ``C`` and the
    reversal potential ``E``, which are passed to every child channel.

    Args:
      size: The size of the simulation target.
      keep_size: Forwarded to the parent constructor.
      method: The numerical integration method.
      name: The name of the object.
      mode: Forwarded to the parent constructor.
      channels: The channels which depend on this ion.
    """

    #: The type of the master object.
    master_type = HHTypedNeuron

    #: Reversal potential.
    E: Union[float, bm.Variable, bm.Array]

    #: Ion concentration.
    C: Union[float, bm.Variable, bm.Array]

    def __init__(
        self,
        size: Shape,
        keep_size: bool = False,
        method: str = 'exp_auto',
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
        **channels
    ):
        super().__init__(size, keep_size=keep_size, mode=mode, method=method, name=name)

        # Attribute of "Container"
        self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels))
        self.external: Dict[str, Callable] = dict()  # not found by `.nodes()` or `.vars()`

    def update(self, V):
        """Update every child channel with the ion's own ``C`` and ``E``.

        Args:
          V: The membrane potential.
        """
        for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
            node.update(V, self.C, self.E)

    def current(self, V, C=None, E=None, external: bool = False):
        """Generate ion channel current.

        Args:
          V: The membrane potential.
          C: The given ion concentration. Defaults to ``self.C``.
          E: The given reversal potential. Defaults to ``self.E``.
          external: Include the external current.

        Returns:
          Current.
        """
        C = self.C if (C is None) else C
        E = self.E if (E is None) else E
        nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
        self.check_hierarchies(type(self), *nodes)

        total = 0.
        for node in nodes:
            total = total + node.current(V, C, E)
        if external:
            # Functions registered through ``add_external_current``.
            for fun in self.external.values():
                total = total + fun(V, C, E)
        return total

    def reset_state(self, V, batch_size=None):
        """Reset the state of every child channel.

        Args:
          V: The membrane potential.
          batch_size: The batch size. It is forwarded positionally, after
            ``self.C`` and ``self.E``, to each channel's ``reset_state``.
        """
        nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
        self.check_hierarchies(type(self), *nodes)
        for node in nodes:
            node.reset_state(V, self.C, self.E, batch_size)

    def pack_info(self, C=None, E=None) -> Dict:
        """Pack the concentration and the reversal potential into a dict.

        Args:
          C: The ion concentration. Defaults to ``self.C``.
          E: The reversal potential. Defaults to ``self.E``.

        Returns:
          A dict with the keys ``'C'`` and ``'E'``.
        """
        if C is None:
            C = self.C
        if E is None:
            E = self.E
        return dict(C=C, E=E)

    def add_external_current(self, key: str, fun: Callable):
        """Register an external current function.

        Args:
          key: The unique name of the external current.
          fun: The function called as ``fun(V, C, E)`` in ``current`` when
            ``external=True``.

        Raises:
          ValueError: If ``key`` has already been registered.
        """
        if key in self.external:
            raise ValueError(f'The external current "{key}" has already been registered '
                             f'in {self.__class__.__name__}.')
        self.external[key] = fun