Source code for brainpy._src.mixin

# -*- coding: utf-8 -*-

import numbers
import sys
import warnings
from dataclasses import dataclass
from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any
from typing import (_SpecialForm, _type_check, _remove_dups_flatten)

import jax

from brainpy import math as bm, tools
from brainpy._src.math.object_transform.naming import get_unique_name
from brainpy.types import ArrayType

if sys.version_info.minor > 8:
  from typing import (_UnionGenericAlias)
else:
  from typing import (_GenericAlias, _tp_cache)

# Module-level placeholders for objects that are imported lazily on first use
# by ``_get_dynsys()`` and ``_get_delay_tool()`` below. Each stays ``None``
# until the corresponding helper has run.
# NOTE(review): presumably deferred to avoid circular imports with
# ``brainpy._src.dynsys`` / ``brainpy._src.delay`` -- confirm.
DynamicalSystem = None
delay_identifier, init_delay_by_return = None, None

# Names exported by ``from <this module> import *``.
# NOTE(review): ``ReturnInfo`` and ``SupportReturnInfo`` are defined in this
# module but are not listed here -- confirm whether that is intentional.
__all__ = [
  'MixIn',
  'ParamDesc',
  'ParamDescriber',
  'DelayRegister',
  'AlignPost',
  'Container',
  'TreeNode',
  'BindCondData',
  'JointType',
  'SupportSTDP',
  'SupportAutoDelay',
  'SupportInputProj',
  'SupportOnline',
  'SupportOffline',
]


def _get_delay_tool():
  """Return ``(delay_identifier, init_delay_by_return)``, importing on demand.

  Both objects live in ``brainpy._src.delay``. Each one is imported the first
  time this function finds its module-level placeholder still set to ``None``
  and is then kept in that module-level global for later calls.
  """
  global delay_identifier, init_delay_by_return
  if delay_identifier is None:
    from brainpy._src.delay import delay_identifier
  if init_delay_by_return is None:
    from brainpy._src.delay import init_delay_by_return
  return delay_identifier, init_delay_by_return


def _get_dynsys():
  """Return the ``DynamicalSystem`` class, importing it on first use.

  The class is stored in the module-level ``DynamicalSystem`` global, so the
  import from ``brainpy._src.dynsys`` only happens while that global is
  still ``None``.
  """
  global DynamicalSystem
  if DynamicalSystem is not None:
    return DynamicalSystem
  from brainpy._src.dynsys import DynamicalSystem
  return DynamicalSystem


class MixIn:
  """Base class of every mixin in this module.

  A :py:class:`~.MixIn` only contributes behavior: it defines methods but
  no initialization function.
  """
class ParamDesc(MixIn):
  """:py:class:`~.MixIn` that lets a class describe its own construction.

  Subclasses gain the classmethod ``desc``, which wraps the class and the
  given arguments in a :py:class:`~.ParamDescriber` instead of building the
  object right away.

  ``not_desc_params`` may list keyword names that
  :py:class:`~.ParamDescriber` leaves out when it builds its identifier
  string; ``None`` (the default) excludes nothing.
  """

  not_desc_params: Optional[Sequence[str]] = None

  @classmethod
  def desc(cls, *args, **kwargs) -> 'ParamDescriber':
    """Return a :py:class:`~.ParamDescriber` holding ``cls`` and the arguments."""
    describer = ParamDescriber(cls, *args, **kwargs)
    return describer
class ParamDescriber(object):
  """Deferred constructor: a class together with some of its arguments.

  Calling the describer instantiates ``cls`` with the stored arguments
  followed by the arguments given at call time. The describer also carries an
  ``identifier`` string of the form ``Name(args, kwargs)`` built from the
  stored arguments.

  ``ParamDescriber[SomeClass]`` produces a describer without arguments;
  ``isinstance(x, ParamDescriber[SomeClass])`` is then true when ``x`` is a
  describer whose class is a subclass of ``SomeClass``.
  """

  def __init__(self, cls: type, *desc_tuple, **desc_dict):
    self.cls = cls

    # stored construction arguments
    self.args = desc_tuple
    self.kwargs = desc_dict

    # choose the display name and the keyword arguments shown in the identifier
    if isinstance(cls, _JointGenericAlias):
      name = str(cls)
      shown = dict(desc_dict.items())
    else:
      assert isinstance(cls, type)
      name = cls.__name__
      if issubclass(cls, ParamDesc) and (cls.not_desc_params is not None):
        shown = {key: val for key, val in desc_dict.items()
                 if key not in cls.not_desc_params}
      else:
        shown = dict(desc_dict.items())

    # a Variable is represented by its ``id`` rather than by its content
    shown = {key: (id(val) if isinstance(val, bm.Variable) else val)
             for key, val in shown.items()}

    repr_args = tools.repr_dict(shown)
    if desc_tuple:
      positional = ', '.join([repr(arg) for arg in desc_tuple])
      repr_args = f"{positional}, {repr_args}"
    self._identifier = f'{name}({repr_args})'

  def __call__(self, *args, **kwargs):
    """Instantiate ``cls``: stored arguments first, then the call-time ones."""
    return self.cls(*self.args, *args, **self.kwargs, **kwargs)

  def init(self, *args, **kwargs):
    """Alias of calling the describer."""
    return self.__call__(*args, **kwargs)

  def __instancecheck__(self, instance):
    # Consulted for ``isinstance(instance, <describer>)``.
    return (isinstance(instance, ParamDescriber)
            and issubclass(instance.cls, self.cls))

  @classmethod
  def __class_getitem__(cls, item: type):
    return ParamDescriber(item)

  @property
  def identifier(self):
    """The identifier string; it can be overwritten through the setter."""
    return self._identifier

  @identifier.setter
  def identifier(self, value):
    self._identifier = value
class AlignPost(MixIn):
  """Align-post :py:class:`~.MixIn`.

  Declares the ``add_current()`` hook for adding external currents. The
  version defined here always raises ``NotImplementedError``; subclasses
  are expected to override it.
  """

  def add_current(self, *args, **kwargs):
    raise NotImplementedError()
@dataclass
class ReturnInfo:
  """Description of a returned value: its size plus how to obtain its data.

  ``data`` is either a callable that receives the final shape tuple
  (``bm.zeros`` by default) or an already existing array.
  """

  size: Sequence[int]
  axis_names: Optional[Sequence[str]] = None
  batch_or_mode: Optional[Union[int, bm.Mode]] = None
  data: Union[Callable, bm.Array, jax.Array] = bm.zeros

  def get_data(self):
    """Return the data described by this object.

    When ``data`` is callable it is called with a shape tuple. That tuple is
    ``size``, prefixed with a batch dimension if ``batch_or_mode`` is an
    ``int`` (the int itself) or a ``bm.BatchingMode`` (its ``batch_size``).
    When ``data`` is a ``bm.Array`` or ``jax.Array`` it is returned as is.

    Raises:
      ValueError: if ``data`` is neither callable nor an array.
    """
    source = self.data
    if isinstance(source, Callable):
      mode = self.batch_or_mode
      if isinstance(mode, int):
        shape = (mode,) + tuple(self.size)
      elif isinstance(mode, bm.NonBatchingMode):
        shape = tuple(self.size)
      elif isinstance(mode, bm.BatchingMode):
        shape = (mode.batch_size,) + tuple(self.size)
      else:
        shape = tuple(self.size)
      return source(shape)
    if isinstance(source, (bm.Array, jax.Array)):
      return source
    raise ValueError
class Container(MixIn):
  """Container :py:class:`~.MixIn` that wraps a group of named objects.

  The objects live in the ``children`` mapping and can be reached either by
  item access (``self[name]``) or by attribute access (``self.name``).
  """

  children: bm.node_dict

  def __getitem__(self, item):
    """Return the child stored under ``item``; ``ValueError`` if unknown."""
    if item not in self.children:
      raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}')
    return self.children[item]

  def __getattr__(self, item):
    """Fallback for dot access: look ``item`` up among the children."""
    children = super().__getattribute__('children')
    if item == 'children':
      return children
    if item in children:
      return children[item]
    return super().__getattribute__(item)

  def __repr__(self):
    cls_name = self.__class__.__name__
    pad = ' ' * len(cls_name)
    parts = []
    for child in self.children.values():
      parts.append(tools.repr_context(repr(child), pad))
    joined = ", \n".join(parts)
    return f'{cls_name}({joined})'

  def __get_elem_name(self, elem):
    # BrainPy objects carry their own name; anything else gets a generated one.
    if not isinstance(elem, bm.BrainPyObject):
      return get_unique_name('ContainerElem')
    return elem.name

  def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict):
    """Collect the given children into one ``name -> child`` dict.

    Args:
      child_type: the type every child must be an instance of.
      *children_as_tuple: each entry is a child, a list/tuple of children,
        or a dict of named children.
      **children_as_dict: children given by name.

    Returns:
      A dict of all children. Unnamed children are keyed by the name that
      ``__get_elem_name`` produces for them.

    Raises:
      ValueError: if an entry cannot be parsed or a child is not an
        instance of ``child_type``.
    """

    def require_child(candidate):
      if not isinstance(candidate, child_type):
        raise ValueError(f'Should be instance of {child_type.__name__}. '
                         f'But we got {type(candidate)}')

    collected = dict()

    # positional components
    for entry in children_as_tuple:
      if isinstance(entry, child_type):
        collected[self.__get_elem_name(entry)] = entry
      elif isinstance(entry, (list, tuple)):
        for member in entry:
          require_child(member)
          collected[self.__get_elem_name(member)] = member
      elif isinstance(entry, dict):
        for key, member in entry.items():
          require_child(member)
          collected[key] = member
      else:
        raise ValueError(f'Cannot parse sub-systems. They should be {child_type.__name__} '
                         f'or a list/tuple/dict of {child_type.__name__}.')

    # keyword components
    for key, member in children_as_dict.items():
      require_child(member)
      collected[key] = member
    return collected

  def add_elem(self, *elems, **elements):
    """Add new elements to ``children``.

    >>> obj = Container()
    >>> obj.add_elem(a=1.)

    Args:
      elems: children objects given positionally.
      elements: children objects given by name.
    """
    new_children = self.format_elements(object, *elems, **elements)
    self.children.update(new_children)
class TreeNode(MixIn):
  """Tree node :py:class:`~.MixIn`.

  A leaf declares through ``master_type`` which type its root must be; the
  methods below verify that declaration against a given root class.
  """

  master_type: type

  def check_hierarchies(self, root, *leaves, **named_leaves):
    """Check every leaf against ``root``.

    Positional leaves may be ``DynamicalSystem`` instances or (nested)
    lists, tuples and dicts of them; named leaves must be
    ``DynamicalSystem`` instances.

    Raises:
      ValueError: if a leaf has an unsupported type.
    """
    global DynamicalSystem
    if DynamicalSystem is None:
      from brainpy._src.dynsys import DynamicalSystem

    for node in leaves:
      if isinstance(node, DynamicalSystem):
        self.check_hierarchy(root, node)
      elif isinstance(node, (list, tuple)):
        self.check_hierarchies(root, *node)
      elif isinstance(node, dict):
        self.check_hierarchies(root, **node)
      else:
        raise ValueError(f'Do not support {type(node)}.')

    for node in named_leaves.values():
      if isinstance(node, DynamicalSystem):
        self.check_hierarchy(root, node)
      else:
        raise ValueError(f'Do not support {type(node)}. Must be instance of {DynamicalSystem.__name__}')

  def check_hierarchy(self, root, leaf):
    """Check that ``root`` is a subclass of ``leaf.master_type``.

    Raises:
      ValueError: if ``leaf`` has no ``master_type`` attribute.
      TypeError: if ``root`` is not a subclass of ``leaf.master_type``.
    """
    if not hasattr(leaf, 'master_type'):
      raise ValueError('Child class should define "master_type" to '
                       'specify the type of the root node. '
                       f'But we did not found it in {leaf}')
    required = leaf.master_type
    if issubclass(root, required):
      return
    raise TypeError(f'Type does not match. {leaf} requires a master with type '
                    f'of {leaf.master_type}, but the master now is {root}.')
class DelayRegister(MixIn):
  """:py:class:`~.MixIn` for registering and reading delay variables.

  Delay objects are stored as "after updates" of ``self`` under a key made of
  the module's delay prefix plus the user-supplied identifier.
  """

  def register_delay(
      self,
      identifier: str,
      delay_step: Optional[Union[int, ArrayType, Callable]],
      delay_target: bm.Variable,
      initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None,
  ):
    """Register a delay entry on a delay variable.

    A delay object for ``delay_target`` is created and attached to ``self``
    the first time ``identifier`` is used; afterwards the existing one is
    reused.

    Args:
      identifier: str. The delay access name.
      delay_step: The delay time step.
      delay_target: The target variable for delay.
      initial_delay_data: The initializer for the delay data.

    Returns:
      The unique name under which the entry was registered.
    """
    prefix, make_delay = _get_delay_tool()
    dynsys_cls = _get_dynsys()
    assert isinstance(self, dynsys_cls), f'self must be an instance of {dynsys_cls.__name__}'
    delay_key = prefix + identifier
    if not self.has_aft_update(delay_key):
      delay_obj = make_delay(delay_target, initial_delay_data)
      self.add_aft_update(delay_key, delay_obj)
    entry_name = get_unique_name('delay')
    self.get_aft_update(delay_key).register_entry(entry_name, delay_step)
    return entry_name

  def get_delay_data(
      self,
      identifier: str,
      delay_pos: str,
      *indices: Union[int, slice, bm.Array, jax.Array],
  ):
    """Read delayed data from a registered delay variable.

    Parameters
    ----------
    identifier: str
      The delay variable name.
    delay_pos: str
      The name of the registered delay entry.
    indices: optional, int, slice, ArrayType
      The indices of the delay.

    Returns
    -------
    delay_data: ArrayType
      Whatever the delay object's ``at(delay_pos, *indices)`` returns.
    """
    prefix, _ = _get_delay_tool()
    delay_obj = self.get_aft_update(prefix + identifier)
    return delay_obj.at(delay_pos, *indices)

  def update_local_delays(self, nodes: Union[Sequence, Dict] = None):
    """Deprecated no-op; only emits a ``DeprecationWarning``.

    Parameters
    ----------
    nodes: sequence, dict
      Ignored.
    """
    message = '.update_local_delays() has been removed since brainpy>=2.4.6'
    warnings.warn(message, DeprecationWarning)

  def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
    """Deprecated no-op; only emits a ``DeprecationWarning``.

    Parameters
    ----------
    nodes: sequence, dict
      Ignored.
    """
    message = '.reset_local_delays() has been removed since brainpy>=2.4.6'
    warnings.warn(message, DeprecationWarning)

  def get_delay_var(self, name):
    """Return the delay object registered under ``name``."""
    prefix, _ = _get_delay_tool()
    delay_obj = self.get_aft_update(prefix + name)
    return delay_obj


class SupportInputProj(MixIn):
  """The :py:class:`~.MixIn` that receives the input projections.

  The methods below read and write the ``current_inputs`` and
  ``delta_inputs`` mappings, so an object using this mixin must provide both.
  """

  current_inputs: bm.node_dict
  delta_inputs: bm.node_dict

  def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'):
    """Add an input function.

    Args:
      key: str. The dict key.
      fun: Callable. The function to generate inputs.
      label: str. The input label.
      category: str. The input category, should be ``current`` (the current) or
         ``delta`` (the delta synapse, indicating the delta function).

    Raises:
      TypeError: if ``fun`` is not callable.
      ValueError: if the (labelled) key is already registered.
      NotImplementedError: if ``category`` is unknown.
    """
    if not callable(fun):
      raise TypeError('Must be a function.')

    key = self._input_label_repr(key, label)
    if category == 'current':
      registry = self.current_inputs
    elif category == 'delta':
      registry = self.delta_inputs
    else:
      raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".')
    if key in registry:
      raise ValueError(f'Key "{key}" has been defined and used.')
    registry[key] = fun

  def get_inp_fun(self, key: str):
    """Get the input function.

    ``current_inputs`` is searched first, then ``delta_inputs``.

    Args:
      key: str. The key.

    Returns:
      The input function registered under ``key``.

    Raises:
      ValueError: if ``key`` is in neither mapping.
    """
    if key in self.current_inputs:
      return self.current_inputs[key]
    if key in self.delta_inputs:
      return self.delta_inputs[key]
    raise ValueError(f'Unknown key: {key}')

  def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
    """Sum the outputs of the input functions in ``.current_inputs``.

    Args:
      *args: The arguments for input functions.
      init: The initial input data.
      label: str. If given, only functions registered with this label are used.
      **kwargs: The arguments for input functions.

    Returns:
      ``init`` plus the outputs of all selected input functions.
    """
    prefix = None if label is None else self._input_label_start(label)
    total = init
    for name, fn in self.current_inputs.items():
      if prefix is None or name.startswith(prefix):
        total = total + fn(*args, **kwargs)
    return total

  def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
    """Sum the outputs of the input functions in ``.delta_inputs``.

    Args:
      *args: The arguments for input functions.
      init: The initial input data.
      label: str. If given, only functions registered with this label are used.
      **kwargs: The arguments for input functions.

    Returns:
      ``init`` plus the outputs of all selected input functions.
    """
    prefix = None if label is None else self._input_label_start(label)
    total = init
    for name, fn in self.delta_inputs.items():
      if prefix is None or name.startswith(prefix):
        total = total + fn(*args, **kwargs)
    return total

  @classmethod
  def _input_label_start(cls, label: str):
    # Single place that defines the prefix of a labelled key.
    return f'{label} // '

  @classmethod
  def _input_label_repr(cls, name: str, label: Optional[str] = None):
    # Single place that defines the full key: the bare name, or prefix + name.
    if label is None:
      return name
    return cls._input_label_start(label) + str(name)

  # ---------- #
  # deprecated #
  # ---------- #

  @property
  def cur_inputs(self):
    return self.current_inputs

  def sum_inputs(self, *args, **kwargs):
    warnings.warn('Please use ".sum_current_inputs()" instead. ".sum_inputs()" will be removed.', UserWarning)
    return self.sum_current_inputs(*args, **kwargs)


class SupportReturnInfo(MixIn):
  """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""

  def return_info(self) -> Union[bm.Variable, ReturnInfo]:
    """Must be overridden; the version defined here always raises."""
    raise NotImplementedError('Must implement the "return_info()" function.')


class SupportAutoDelay(SupportReturnInfo):
  """Same interface as :py:class:`~.SupportReturnInfo`; adds nothing."""


class SupportOnline(MixIn):
  """:py:class:`~.MixIn` to support the online training methods.

  Both hooks raise ``NotImplementedError`` unless overridden.

  .. versionadded:: 2.4.5
  """

  online_fit_by: Optional  # methods for online fitting

  def online_init(self, *args, **kwargs):
    raise NotImplementedError()

  def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):
    raise NotImplementedError()


class SupportOffline(MixIn):
  """:py:class:`~.MixIn` to support the offline training methods.

  ``offline_init`` does nothing by default, while ``offline_fit`` raises
  ``NotImplementedError`` unless overridden.

  .. versionadded:: 2.4.5
  """

  offline_fit_by: Optional  # methods for offline fitting

  def offline_init(self, *args, **kwargs):
    return None

  def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):
    raise NotImplementedError()
class BindCondData(MixIn):
  """:py:class:`~.MixIn` that temporarily binds conductance data.

  The bound value is kept in ``_conductance``; unbinding resets it to
  ``None``.
  """

  _conductance: Optional

  def bind_cond(self, conductance):
    """Store ``conductance`` on the object."""
    setattr(self, '_conductance', conductance)

  def unbind_cond(self):
    """Drop the stored conductance."""
    setattr(self, '_conductance', None)
class SupportSTDP(MixIn):
  """Support synaptic plasticity by modifying the weights.
  """

  def stdp_update(self, *args, on_pre=None, on_post=None, onn_post=None, **kwargs):
    """Hook for a weight update; must be overridden.

    Args:
      *args: Extra positional arguments for the concrete implementation.
      on_pre: Data describing the update triggered on the pre-synaptic side.
      on_post: Data describing the update triggered on the post-synaptic side.
      onn_post: Misspelled historical name of ``on_post``. It is still
        accepted so that existing calls keep working; use ``on_post``.
      **kwargs: Extra keyword arguments for the concrete implementation.

    Raises:
      NotImplementedError: always, in this base version.
    """
    raise NotImplementedError


T = TypeVar('T')


def get_type(types):
  """Build a metaclass whose instance check requires *all* of ``types``.

  For a class created with the returned metaclass, ``isinstance(obj, cls)``
  is true when ``type(obj)`` is a subclass of every entry of ``types``.
  """

  class NewType(type):
    def __instancecheck__(self, other):
      cls_of_other = other.__class__
      return all(issubclass(cls_of_other, cls) for cls in types)

  return NewType


class _MetaUnionType(type):
  """Metaclass whose instance/subclass checks require *all* bases to match."""

  def __new__(cls, name, bases, dct):
    # Normalize ``bases`` to a tuple of types before creating the class.
    if isinstance(bases, type):
      bases = (bases,)
    elif isinstance(bases, (list, tuple)):
      bases = tuple(bases)
      for base in bases:
        assert isinstance(base, type), f'Must be type. But got {base}'
    else:
      raise TypeError(f'Must be type. But got {bases}')
    return super().__new__(cls, name, bases, dct)

  def __instancecheck__(self, other):
    cls_of_other = other.__class__
    return all(issubclass(cls_of_other, cls) for cls in self.__bases__)

  def __subclasscheck__(self, subclass):
    return all(issubclass(subclass, cls) for cls in self.__bases__)


if sys.version_info.minor > 8:
  class _JointGenericAlias(_UnionGenericAlias, _root=True):
    """Alias produced by ``JointType[...]``; matches classes that satisfy *all* args."""

    def __subclasscheck__(self, subclass):
      return all(issubclass(subclass, cls) for cls in set(self.__args__))


  @_SpecialForm
  def JointType(self, parameters):
    """Joint type; JointType[X, Y] means both X and Y.

    To define a joint type, use e.g. JointType[int, str].

    Details:
    - The arguments must be types and there must be at least one.
    - None as an argument is a special case and is replaced by
      `type(None)`.
    - Joint types of joint types are flattened, e.g.::

        JointType[JointType[int, str], float] == JointType[int, str, float]

    - Joint types of a single argument vanish, e.g.::

        JointType[int] == int  # The constructor actually returns int

    - Redundant arguments are skipped, e.g.::

        JointType[int, str, int] == JointType[int, str]

    - When comparing joint types, the argument order is ignored, e.g.::

        JointType[int, str] == JointType[str, int]
    """
    if parameters == ():
      raise TypeError("Cannot take a Joint of no types.")
    if not isinstance(parameters, tuple):
      parameters = (parameters,)
    msg = "JointType[arg, ...]: each arg must be a type."
    parameters = tuple(_type_check(p, msg) for p in parameters)
    parameters = _remove_dups_flatten(parameters)
    if len(parameters) == 1:
      # a joint of a single type is that type itself
      return parameters[0]
    return _JointGenericAlias(self, parameters)

else:
  class _JointGenericAlias(_GenericAlias, _root=True):
    """Alias produced by ``JointType[...]``; matches classes that satisfy *all* args."""

    def __subclasscheck__(self, subclass):
      return all(issubclass(subclass, cls) for cls in set(self.__args__))


  class _SpecialForm2(_SpecialForm, _root=True):
    """``_SpecialForm`` whose subscription builds a joint alias for ``JointType``."""

    @_tp_cache
    def __getitem__(self, parameters):
      if self._name == 'JointType':
        if parameters == ():
          raise TypeError("Cannot take a Joint of no types.")
        if not isinstance(parameters, tuple):
          parameters = (parameters,)
        msg = "JointType[arg, ...]: each arg must be a type."
        parameters = tuple(_type_check(p, msg) for p in parameters)
        parameters = _remove_dups_flatten(parameters)
        if len(parameters) == 1:
          # a joint of a single type is that type itself
          return parameters[0]
        return _JointGenericAlias(self, parameters)
      else:
        return super().__getitem__(parameters)


  JointType = _SpecialForm2(
    'JointType',
    doc="""Joint type; JointType[X, Y] means both X and Y.

    To define a joint type, use e.g. JointType[int, str].

    Details:
    - The arguments must be types and there must be at least one.
    - None as an argument is a special case and is replaced by
      `type(None)`.
    - Joint types of joint types are flattened, e.g.::

        JointType[JointType[int, str], float] == JointType[int, str, float]

    - Joint types of a single argument vanish, e.g.::

        JointType[int] == int  # The constructor actually returns int

    - Redundant arguments are skipped, e.g.::

        JointType[int, str, int] == JointType[int, str]
    """
  )