Source code for brainpy.mixin

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import warnings
from dataclasses import dataclass
from typing import Union, Dict, Callable, Sequence, Optional, Any

import brainstate
import jax

bm, delay_identifier, init_delay_by_return, DynamicalSystem = None, None, None, None

__all__ = [
    'MixIn',
    'ParamDesc',
    'ParamDescriber',
    'AlignPost',
    'Container',
    'TreeNode',
    'BindCondData',
    'JointType',
    'SupportSTDP',
    'SupportAutoDelay',
    'SupportInputProj',
    'SupportOnline',
    'SupportOffline',
]

MixIn = brainstate.mixin.Mixin
ParamDesc = brainstate.mixin.ParamDesc
ParamDescriber = brainstate.mixin.ParamDescriber
JointType = brainstate.mixin.JointTypes


def _get_bm():
    global bm
    if bm is None:
        from brainpy import math
        bm = math
    return bm


[docs] class AlignPost(brainstate.mixin.Mixin): """ Mixin for aligning post-synaptic inputs. This mixin provides an interface for components that need to receive and process post-synaptic inputs, such as synaptic connections or neural populations. The ``align_post_input_add`` method should be implemented to handle the accumulation of external currents or inputs. Notes ----- Classes that inherit from this mixin must implement the ``align_post_input_add`` method. Examples -------- Implementing a synapse with post-synaptic alignment: .. code-block:: python >>> import brainstate >>> import jax.numpy as jnp >>> >>> class Synapse(brainstate.mixin.AlignPost): ... def __init__(self, weight): ... self.weight = weight ... self.post_current = brainstate.State(0.0) ... ... def align_post_input_add(self, current): ... # Accumulate the weighted current into post-synaptic target ... self.post_current.value += current * self.weight >>> >>> # Usage >>> synapse = Synapse(weight=0.5) >>> synapse.align_post_input_add(10.0) >>> print(synapse.post_current.value) # Output: 5.0 Using with neural populations: .. code-block:: python >>> class NeuronGroup(brainstate.mixin.AlignPost): ... def __init__(self, size): ... self.size = size ... self.input_current = brainstate.State(jnp.zeros(size)) ... ... def align_post_input_add(self, current): ... # Add external current to neurons ... self.input_current.value = self.input_current.value + current >>> >>> neurons = NeuronGroup(100) >>> external_input = jnp.ones(100) * 0.5 >>> neurons.align_post_input_add(external_input) """
[docs] def align_post_input_add(self, *args, **kwargs): """ Add external inputs to the post-synaptic component. Parameters ---------- *args Positional arguments for the input. **kwargs Keyword arguments for the input. Raises ------ NotImplementedError If the method is not implemented by the subclass. """ raise NotImplementedError
[docs] class BindCondData(brainstate.mixin.Mixin): """ Mixin for binding temporary conductance data. This mixin provides an interface for temporarily storing conductance data, which is useful in synaptic models where conductance values need to be passed between computation steps without being part of the permanent state. Attributes ---------- _conductance : Any, optional Temporarily bound conductance data. Examples -------- Using conductance binding in a synapse: .. code-block:: python >>> import brainstate >>> import jax.numpy as jnp >>> >>> class ConductanceBasedSynapse(brainstate.mixin.BindCondData): ... def __init__(self): ... self._conductance = None ... ... def compute(self, pre_spike): ... if pre_spike: ... # Bind conductance data temporarily ... self.bind_cond(0.5) ... ... # Use conductance if available ... if self._conductance is not None: ... current = self._conductance * (0.0 - (-70.0)) ... # Clear after use ... self.unbind_cond() ... return current ... return 0.0 >>> >>> synapse = ConductanceBasedSynapse() >>> current = synapse.compute(pre_spike=True) Managing conductance in a network: .. code-block:: python >>> class SynapticConnection(brainstate.mixin.BindCondData): ... def __init__(self, g_max): ... self.g_max = g_max ... self._conductance = None ... ... def prepare_conductance(self, activation): ... # Bind conductance based on activation ... g = self.g_max * activation ... self.bind_cond(g) ... ... def apply_conductance(self, voltage): ... if self._conductance is not None: ... current = self._conductance * voltage ... self.unbind_cond() ... return current ... return 0.0 """ # Attribute to store temporary conductance data _conductance: Optional
[docs] def bind_cond(self, conductance): """ Bind conductance data temporarily. Parameters ---------- conductance : Any The conductance data to bind. """ self._conductance = conductance
[docs] def unbind_cond(self): """ Unbind (clear) the conductance data. """ self._conductance = None
def _get_delay_tool(): global delay_identifier, init_delay_by_return if init_delay_by_return is None: from brainpy.delay import init_delay_by_return if delay_identifier is None: from brainpy.delay import delay_identifier return delay_identifier, init_delay_by_return @dataclass class ReturnInfo: size: Sequence[int] axis_names: Optional[Sequence[str]] = None batch_or_mode: Optional[Union[int, brainstate.mixin.Mode]] = None data: Union[Callable, jax.Array] = jax.numpy.zeros def get_data(self): bm = _get_bm() if isinstance(self.data, Callable): if isinstance(self.batch_or_mode, int): size = (self.batch_or_mode,) + tuple(self.size) elif isinstance(self.batch_or_mode, bm.NonBatchingMode): size = tuple(self.size) elif isinstance(self.batch_or_mode, bm.BatchingMode): size = (self.batch_or_mode.batch_size,) + tuple(self.size) else: size = tuple(self.size) init = self.data(size) elif isinstance(self.data, (bm.Array, jax.Array)): init = self.data else: raise ValueError return init
[docs] class Container(MixIn): """Container :py:class:`~.MixIn` which wrap a group of objects. """ children: dict() def __getitem__(self, item): """Overwrite the slice access (`self['']`). """ if item in self.children: return self.children[item] else: raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}') def __getattr__(self, item): """Overwrite the dot access (`self.`). """ if item == 'children': return super().__getattribute__('children') else: children = super().__getattribute__('children') if item in children: return children[item] else: return super().__getattribute__(item) def __repr__(self): from brainpy import tools cls_name = self.__class__.__name__ indent = ' ' * len(cls_name) child_str = [tools.repr_context(repr(val), indent) for val in self.children.values()] string = ", \n".join(child_str) return f'{cls_name}({string})' def __get_elem_name(self, elem): bm = _get_bm() if isinstance(elem, bm.BrainPyObject): return elem.name else: from brainpy.math.object_transform.base import get_unique_name return get_unique_name('ContainerElem') def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict): res = dict() # add tuple-typed components for module in children_as_tuple: if isinstance(module, child_type): res[self.__get_elem_name(module)] = module elif isinstance(module, (list, tuple)): for m in module: if not isinstance(m, child_type): raise ValueError(f'Should be instance of {child_type.__name__}. ' f'But we got {type(m)}') res[self.__get_elem_name(m)] = m elif isinstance(module, dict): for k, v in module.items(): if not isinstance(v, child_type): raise ValueError(f'Should be instance of {child_type.__name__}. ' f'But we got {type(v)}') res[k] = v else: raise ValueError(f'Cannot parse sub-systems. They should be {child_type.__name__} ' f'or a list/tuple/dict of {child_type.__name__}.') # add dict-typed components for k, v in children_as_dict.items(): if not isinstance(v, child_type): raise ValueError(f'Should be instance of {child_type.__name__}. ' f'But we got {type(v)}') res[k] = v return res
[docs] def add_elem(self, *elems, **elements): """Add new elements. >>> obj = Container() >>> obj.add_elem(a=1.) Args: elements: children objects. """ self.children.update(self.format_elements(object, *elems, **elements))
[docs] class TreeNode(MixIn): """Tree node. """ master_type: type def check_hierarchies(self, root, *leaves, **named_leaves): global DynamicalSystem if DynamicalSystem is None: from brainpy.dynsys import DynamicalSystem for leaf in leaves: if isinstance(leaf, DynamicalSystem): self.check_hierarchy(root, leaf) elif isinstance(leaf, (list, tuple)): self.check_hierarchies(root, *leaf) elif isinstance(leaf, dict): self.check_hierarchies(root, **leaf) else: raise ValueError(f'Do not support {type(leaf)}.') for leaf in named_leaves.values(): if not isinstance(leaf, DynamicalSystem): raise ValueError(f'Do not support {type(leaf)}. Must be instance of {DynamicalSystem.__name__}') self.check_hierarchy(root, leaf) def check_hierarchy(self, root, leaf): if hasattr(leaf, 'master_type'): master_type = leaf.master_type else: raise ValueError('Child class should define "master_type" to ' 'specify the type of the root node. ' f'But we did not found it in {leaf}') if not issubclass(root, master_type): raise TypeError(f'Type does not match. {leaf} requires a master with type ' f'of {leaf.master_type}, but the master now is {root}.')
class SupportInputProj(MixIn): """The :py:class:`~.MixIn` that receives the input projections. Note that the subclass should define a ``cur_inputs`` attribute. Otherwise, the input function utilities cannot be used. """ current_inputs: dict delta_inputs: dict def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'): """Add an input function. Args: key: str. The dict key. fun: Callable. The function to generate inputs. label: str. The input label. category: str. The input category, should be ``current`` (the current) or ``delta`` (the delta synapse, indicating the delta function). """ if not callable(fun): raise TypeError('Must be a function.') key = self._input_label_repr(key, label) if category == 'current': if key in self.current_inputs: raise ValueError(f'Key "{key}" has been defined and used.') self.current_inputs[key] = fun elif category == 'delta': if key in self.delta_inputs: raise ValueError(f'Key "{key}" has been defined and used.') self.delta_inputs[key] = fun else: raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".') def get_inp_fun(self, key: str): """Get the input function. Args: key: str. The key. Returns: The input function which generates currents. """ if key in self.current_inputs: return self.current_inputs[key] elif key in self.delta_inputs: return self.delta_inputs[key] else: raise ValueError(f'Unknown key: {key}') def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): """Summarize all current inputs by the defined input functions ``.current_inputs``. Args: *args: The arguments for input functions. init: The initial input data. label: str. The input label. **kwargs: The arguments for input functions. Returns: The total currents. """ if label is None: for key, out in self.current_inputs.items(): init = init + out(*args, **kwargs) else: label_repr = self._input_label_start(label) for key, out in self.current_inputs.items(): if key.startswith(label_repr): init = init + out(*args, **kwargs) return init def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): """Summarize all delta inputs by the defined input functions ``.delta_inputs``. Args: *args: The arguments for input functions. init: The initial input data. label: str. The input label. **kwargs: The arguments for input functions. Returns: The total currents. """ if label is None: for key, out in self.delta_inputs.items(): init = init + out(*args, **kwargs) else: label_repr = self._input_label_start(label) for key, out in self.delta_inputs.items(): if key.startswith(label_repr): init = init + out(*args, **kwargs) return init @classmethod def _input_label_start(cls, label: str): # unify the input label repr. return f'{label} // ' @classmethod def _input_label_repr(cls, name: str, label: Optional[str] = None): # unify the input label repr. return name if label is None else (cls._input_label_start(label) + str(name)) # deprecated # # ---------- # @property def cur_inputs(self): return self.current_inputs def sum_inputs(self, *args, **kwargs): warnings.warn('Please use ".sum_current_inputs()" instead. ".sum_inputs()" will be removed.', UserWarning) return self.sum_current_inputs(*args, **kwargs) class SupportReturnInfo(MixIn): """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" def return_info(self): raise NotImplementedError('Must implement the "return_info()" function.') class SupportAutoDelay(SupportReturnInfo): pass class SupportOnline(MixIn): """:py:class:`~.MixIn` to support the online training methods. .. versionadded:: 2.4.5 """ online_fit_by: Optional # methods for online fitting def online_init(self, *args, **kwargs): raise NotImplementedError def online_fit(self, target, fit_record: Dict): raise NotImplementedError class SupportOffline(MixIn): """:py:class:`~.MixIn` to support the offline training methods. .. versionadded:: 2.4.5 """ offline_fit_by: Optional # methods for offline fitting def offline_init(self, *args, **kwargs): pass def offline_fit(self, target, fit_record: Dict): raise NotImplementedError class SupportSTDP(MixIn): """Support synaptic plasticity by modifying the weights. """ def stdp_update(self, *args, on_pre=None, onn_post=None, **kwargs): raise NotImplementedError