Source code for brainpy.math.pre_syn_post

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import jax.numpy as jnp
from jax import vmap, jit, ops as jops

from brainpy._errors import MathError
from brainpy.math import event
from brainpy.math.interoperability import as_jax

__all__ = [
    # pre-to-post
    'pre2post_sum',
    'pre2post_prod',
    'pre2post_max',
    'pre2post_min',
    'pre2post_mean',

    # pre-to-post event operator
    'pre2post_event_sum',
    'pre2post_csr_event_sum',

    # pre-to-syn
    'pre2syn',

    # syn-to-post
    'syn2post_sum', 'syn2post',
    'syn2post_prod',
    'syn2post_max',
    'syn2post_min',
    'syn2post_mean',
    'syn2post_softmax',
]


def _raise_pre_ids_is_none(pre_ids):
    if pre_ids is None:
        raise MathError(f'pre2post synaptic computation needs "pre_ids" '
                        f'when providing heterogeneous "pre_values" '
                        f'(brainpy.math.ndim(pre_values) != 0).')


[docs] def pre2post_event_sum(events, pre2post, post_num: int, values=1.): """The pre-to-post event-driven synaptic summation with `CSR` synapse structure. When ``values`` is a scalar, this function is equivalent to .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) post_ids, idnptr = pre2post for i in range(pre_num): if events[i]: for j in range(idnptr[i], idnptr[i+1]): post_val[post_ids[j]] += values When ``values`` is a vector (with the length of ``len(post_ids)``), this function is equivalent to .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) post_ids, idnptr = pre2post for i in range(pre_num): if events[i]: for j in range(idnptr[i], idnptr[i+1]): post_val[post_ids[j]] += values[j] Parameters ---------- events : ArrayType The events, must be bool. pre2post : tuple of ArrayType, tuple of ArrayType A tuple contains the connection information of pre-to-post. post_num : int The number of post-synaptic group. values : float, ArrayType The value to make summation. Returns ------- out : ArrayType A tensor with the shape of ``post_num``. """ indices, idnptr = pre2post events = as_jax(events) indices = as_jax(indices) idnptr = as_jax(idnptr) values = as_jax(values) return event.csrmv(values, indices, idnptr, events, shape=(events.shape[0], post_num), transpose=True)
pre2post_csr_event_sum = pre2post_event_sum
[docs] def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic summation. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for i, j in zip(pre_ids, post_ids): post_val[j] += pre_values[pre_ids[i]] Parameters ---------- pre_values : float, ArrayType The pre-synaptic values. post_ids : ArrayType The connected post-synaptic neuron ids. post_num : int Output dimension. The number of post-synaptic neurons. pre_ids : optional, ArrayType The connected pre-synaptic neuron ids. Returns ------- post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) pre_values = as_jax(pre_values) post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].add(pre_values)
[docs] def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic production. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for i, j in zip(pre_ids, post_ids): post_val[j] *= pre_values[pre_ids[i]] Parameters ---------- pre_values : float, ArrayType The pre-synaptic values. pre_ids : ArrayType The connected pre-synaptic neuron ids. post_ids : ArrayType The connected post-synaptic neuron ids. post_num : int Output dimension. The number of post-synaptic neurons. Returns ------- post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) pre_values = as_jax(pre_values) post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].multiply(pre_values)
[docs] def pre2post_min(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic minimization. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for i, j in zip(pre_ids, post_ids): post_val[j] = np.minimum(post_val[j], pre_values[pre_ids[i]]) Parameters ---------- pre_values : float, ArrayType The pre-synaptic values. pre_ids : ArrayType The connected pre-synaptic neuron ids. post_ids : ArrayType The connected post-synaptic neuron ids. post_num : int Output dimension. The number of post-synaptic neurons. Returns ------- post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) pre_values = as_jax(pre_values) post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].min(pre_values)
[docs] def pre2post_max(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic maximization. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for i, j in zip(pre_ids, post_ids): post_val[j] = np.maximum(post_val[j], pre_values[pre_ids[i]]) Parameters ---------- pre_values : float, ArrayType The pre-synaptic values. pre_ids : ArrayType The connected pre-synaptic neuron ids. post_ids : ArrayType The connected post-synaptic neuron ids. post_num : int Output dimension. The number of post-synaptic neurons. Returns ------- post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) pre_values = as_jax(pre_values) post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].max(pre_values)
[docs] def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic mean computation. Parameters ---------- pre_values : float, ArrayType The pre-synaptic values. pre_ids : ArrayType The connected pre-synaptic neuron ids. post_ids : ArrayType The connected post-synaptic neuron ids. post_num : int Output dimension. The number of post-synaptic neurons. Returns ------- post_val : ArrayType The value with the size of post-synaptic neurons. Notes ----- When ``pre_values`` is a scalar, every connection carries the same constant value, so the per-post mean is simply that constant. In this case the function broadcasts the constant to every targeted post-synaptic neuron (untargeted neurons stay ``0``). Duplicate ``post_ids`` therefore do not require any averaging -- the mean of identical values equals the value itself. """ out = jnp.zeros(post_num) pre_values = as_jax(pre_values) post_ids = as_jax(post_ids) if jnp.ndim(pre_values) == 0: # Scalar branch: every synapse carries the same constant ``pre_values``, # so the mean over any group of post-synaptic targets is that constant. # Broadcast it to the targeted posts (duplicate ``post_ids`` are harmless # because the mean of identical values is the value itself). return out.at[post_ids].set(pre_values) else: _raise_pre_ids_is_none(pre_ids) pre_ids = as_jax(pre_ids) pre_values = pre2syn(pre_values, pre_ids) return syn2post_mean(pre_values, post_ids, post_num)
_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None))
[docs] def pre2syn(pre_values, pre_ids): """The pre-to-syn computation. Change the pre-synaptic data to the data with the dimension of synapses. This function is equivalent to: .. highlight:: python .. code-block:: python syn_val = np.zeros(len(pre_ids)) for syn_i, pre_i in enumerate(pre_ids): syn_val[i] = pre_values[pre_i] Parameters ---------- pre_values : float, ArrayType The pre-synaptic value. pre_ids : ArrayType The pre-synaptic neuron index. Returns ------- syn_val : ArrayType The synaptic value. """ pre_values = as_jax(pre_values) pre_ids = as_jax(pre_ids) if jnp.ndim(pre_values) == 0: return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values else: return _pre2syn(pre_ids, pre_values)
_jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3)) _jit_seg_prod = jit(jops.segment_prod, static_argnums=(2, 3)) _jit_seg_max = jit(jops.segment_max, static_argnums=(2, 3)) _jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3))
[docs] def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post summation computation. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for syn_i, post_i in enumerate(post_ids): post_val[post_i] += syn_values[syn_i] Parameters ---------- syn_values : ArrayType The synaptic values. post_ids : ArrayType The post-synaptic neuron ids. post_num : int The number of the post-synaptic neurons. Returns ------- post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
syn2post = syn2post_sum
[docs] def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post product computation. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for syn_i, post_i in enumerate(post_ids): post_val[post_i] *= syn_values[syn_i] Parameters ---------- syn_values : ArrayType The synaptic values. post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. post_num : int The number of the post-synaptic neurons. indices_are_sorted : whether ``post_ids`` is known to be sorted. Returns ------- post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted)
[docs] def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post maximum computation. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for syn_i, post_i in enumerate(post_ids): post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) Parameters ---------- syn_values : ArrayType The synaptic values. post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. post_num : int The number of the post-synaptic neurons. indices_are_sorted : whether ``post_ids`` is known to be sorted. Returns ------- post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)
[docs] def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post minimization computation. This function is equivalent to: .. highlight:: python .. code-block:: python post_val = np.zeros(post_num) for syn_i, post_i in enumerate(post_ids): post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) Parameters ---------- syn_values : ArrayType The synaptic values. post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. post_num : int The number of the post-synaptic neurons. indices_are_sorted : whether ``post_ids`` is known to be sorted. Returns ------- post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted)
[docs] def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post mean computation. Parameters ---------- syn_values : ArrayType The synaptic values. post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. post_num : int The number of the post-synaptic neurons. indices_are_sorted : whether ``post_ids`` is known to be sorted. Returns ------- post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) # Guard only the empty-group case (denominator == 0) instead of masking with # ``nan_to_num``, which would also silently hide genuine NaNs in ``syn_values``. return jnp.where(denominator > 0, nominator / denominator, 0.)
[docs] def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post softmax computation. Parameters ---------- syn_values : ArrayType The synaptic values. post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. post_num : int The number of the post-synaptic neurons. indices_are_sorted : whether ``post_ids`` is known to be sorted. Returns ------- post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) syn_values = syn_values - syn_maxs[post_ids] syn_values = jnp.exp(syn_values) normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) # ``normalizers[post_ids]`` is structurally >= 1 for every referenced post group # (each such group contains at least the current synapse, contributing # ``exp(0) == 1`` after the max-subtraction), so this division never hits a # genuine 0/0. The previous ``jnp.nan_to_num`` only served to silently hide # genuine NaNs produced upstream (e.g. NaNs already present in ``syn_values``), # so it is intentionally removed to let such NaNs propagate. return syn_values / normalizers[post_ids]