# Source code for braintools.surrogate._base

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import jax
import jax.numpy as jnp
from brainstate._compatible_import import Primitive
from brainstate.util import PrettyObject
from jax.interpreters import batching, ad, mlir

__all__ = ['Surrogate']


def _heaviside_abstract(x, dx):
    return [x]


def _heaviside_imp(x, dx):
    z = jnp.asarray(x >= 0, dtype=x.dtype)
    return [z]


def _heaviside_batching(args, axes):
    """Batching (``vmap``) rule for the Heaviside surrogate primitive.

    Brings ``x`` and ``dx`` to a common layout, with the batch dimension at
    the same position in both, and then re-binds the primitive on the aligned
    operands.

    Whichever operand is unbatched is materialised along the other operand's
    batch axis. This matters for differentiation: the JVP rule multiplies
    ``dx`` element-wise with the tangent of ``x``, so the two must share a
    shape. Leaving an unbatched ``dx`` untouched would make that product rely
    on trailing-axis broadcasting, which is only correct when the batch axis
    of ``x`` is axis 0.

    Parameters
    ----------
    args : sequence
        The ``(x, dx)`` operands.
    axes : sequence
        ``(x_axis, dx_axis)``; ``None`` marks an operand that is not batched.
        When only ``x`` is batched, ``dx`` must be broadcastable to the
        per-example shape of ``x``.

    Returns
    -------
    tuple
        ``(outputs, (out_axis,))`` where ``outputs`` is the sequence returned
        by ``heaviside_p.bind`` and ``out_axis`` is the batch axis of its
        single output (``None`` when neither operand is batched).
    """
    x, dx = args
    x_axis, dx_axis = axes

    if x_axis is None and dx_axis is None:
        out_axis = None
    elif x_axis is None:
        # Only dx is batched: replicate x along dx's batch axis.
        out_axis = dx_axis
        x = jnp.repeat(jnp.expand_dims(x, axis=dx_axis), axis=dx_axis, repeats=dx.shape[dx_axis])
    elif dx_axis is None:
        # Only x is batched: give dx the same batch axis so that a later
        # ``dx * tx`` in the JVP rule is a true element-wise product.
        out_axis = x_axis
        batched_shape = tuple(x.shape)
        per_example_shape = batched_shape[:x_axis] + batched_shape[x_axis + 1:]
        dx = jnp.broadcast_to(dx, per_example_shape)
        dx = jnp.broadcast_to(jnp.expand_dims(dx, axis=x_axis), batched_shape)
    else:
        # Both are batched: follow x's layout, moving dx's batch axis if needed.
        out_axis = x_axis
        if dx_axis != x_axis:
            dx = jnp.moveaxis(dx, dx_axis, x_axis)

    # heaviside_p has multiple_results = True, so bind returns a sequence of
    # outputs and the rule must return one batch axis per output.
    result = heaviside_p.bind(x, dx)
    return result, (out_axis,)


def _heaviside_jvp(primals, tangents):
    """Forward-mode (JVP) rule for the Heaviside surrogate primitive.

    The primal output is the plain step function of ``x``. The output tangent
    is ``dx * tx + tdx``: the tangent of ``x`` scaled by the surrogate
    gradient ``dx``, plus the tangent of ``dx`` passed through unchanged.
    Symbolic ``ad.Zero`` tangents are never multiplied; they are handed to
    ``ad.add_tangents`` as they are.

    Parameters
    ----------
    primals : sequence
        ``(x, dx)``.
    tangents : sequence
        ``(tx, tdx)``; either entry may be an ``ad.Zero``.

    Returns
    -------
    tuple
        ``(primal_outs, tangent_outs)``, each a one-element list.
    """
    x, dx = primals
    x_dot, dx_dot = tangents

    # The primal goes through the implementation function rather than
    # ``heaviside_p.bind``; the original rationale is to avoid recursion.
    outputs = _heaviside_imp(x, dx)

    # Surrogate-gradient term: skip the product when the tangent of x is a
    # symbolic zero so that it stays symbolic.
    scaled_x_dot = x_dot if type(x_dot) is ad.Zero else dx * x_dot

    # The tangent of dx contributes as-is; add_tangents copes with either
    # operand being ad.Zero.
    out_dot = ad.add_tangents(scaled_x_dot, dx_dot)
    return outputs, [out_dot]


def _heaviside_transpose(ct, x, dx):
    """
    Transpose rule for reverse-mode autodiff.

    This computes cotangents for the tangents (tx, tdx) given the output cotangent.

    From JVP: output_tangent = dx * tx + tdx
    Transpose: cotangent_tx = dx * ct_out, cotangent_tdx = ct_out
    """
    # ct is a tuple/list containing the cotangent for each output
    ct_out = ct[0]

    # Cotangent for tx (from dx * tx term)
    # In JAX transpose, dx is a residual and might be UndefinedPrimal if it's symbolic
    if type(dx) is ad.UndefinedPrimal:
        # Can't use dx if it's undefined - return zero
        cotangent_tx = ad.Zero(dx.aval)
    else:
        cotangent_tx = dx * ct_out

    # Cotangent for tdx (from tdx term)
    cotangent_tdx = ct_out

    return (cotangent_tx, cotangent_tdx)


# Create the primitive and attach its rules. With ``multiple_results = True``
# every rule works with a list of outputs and ``bind`` returns a sequence, even
# though this primitive produces a single output.
heaviside_p = Primitive('heaviside_surrogate_gradient')
heaviside_p.multiple_results = True
heaviside_p.def_abstract_eval(_heaviside_abstract)
heaviside_p.def_impl(_heaviside_imp)
batching.primitive_batchers[heaviside_p] = _heaviside_batching
ad.primitive_jvps[heaviside_p] = _heaviside_jvp
# No transpose rule is registered for the primitive itself: the JVP rule builds
# its tangent from ordinary array operations, and reverse mode is left to JAX
# to derive from that JVP. The hand-written rule is kept for reference only.
# ad.primitive_transposes[heaviside_p] = _heaviside_transpose
# The compiled lowering is generated from the same Python implementation that
# is registered as the eager implementation above.
mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))


class Surrogate(PrettyObject):
    r"""The base surrogate gradient function.

    This abstract base class defines the interface for surrogate gradient functions
    used in training spiking neural networks. Surrogate gradients replace the
    non-differentiable spike function with smooth approximations during backpropagation.

    To customize a surrogate gradient function, inherit this class and implement
    the `surrogate_fun` and `surrogate_grad` methods.

    Methods
    -------
    surrogate_fun(x)
        Defines the smooth surrogate function used for visualization and analysis.
    surrogate_grad(x)
        Defines the gradient of the surrogate function used during backpropagation.

    Examples
    --------
    .. code-block:: python

        >>> import braintools
        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a custom surrogate gradient function
        >>> class MySurrogate(braintools.surrogate.Surrogate):
        ...     def __init__(self, alpha=1.):
        ...         super().__init__()
        ...         self.alpha = alpha
        ...
        ...     def surrogate_fun(self, x):
        ...         # Define the smooth approximation function
        ...         return jnp.tanh(x * self.alpha) * 0.5 + 0.5
        ...
        ...     def surrogate_grad(self, x):
        ...         # Define its gradient for backpropagation
        ...         return self.alpha * 0.5 * (1 - jnp.tanh(x * self.alpha) ** 2)
        >>>
        >>> # Use the custom surrogate
        >>> my_surrogate = MySurrogate(alpha=2.0)
        >>> x = jnp.array([-1.0, 0.0, 1.0])
        >>> spikes = my_surrogate(x)  # Forward: step function
        >>> print(spikes)  # [0., 1., 1.]

    Notes
    -----
    The forward pass always returns a Heaviside step function (0 or 1),
    while the backward pass uses the custom surrogate gradient defined in
    `surrogate_grad` method. This straight-through estimator approach enables
    gradient-based training of spiking neural networks.

    Implementing ``surrogate_fun`` is **optional** -- it is used only for
    visualization/analysis. Gradient-only surrogates (e.g. ``ReluGrad``,
    ``GaussianGrad``, ``MultiGaussianGrad``, ``InvSquareGrad``, ``SlayerGrad``)
    do not define it, so calling ``surrogate_fun`` on them raises
    ``NotImplementedError``. When both are defined, ``surrogate_grad`` is the
    exact derivative of ``surrogate_fun``.

    The input ``x`` is expected to be a **dimensionless** array (typically the
    membrane potential minus the threshold, in matching units). Passing a
    unitful ``brainunit.Quantity`` is not supported by the underlying primitive.

    Differentiating the output with respect to the **input** ``x`` yields the
    surrogate gradient, as intended. Differentiating with respect to a
    **surrogate parameter** (e.g. ``alpha``) returns the derivative of the
    surrogate-gradient function w.r.t. that parameter (a side effect of the
    custom JVP rule), *not* the mathematically-true ``0`` of the Heaviside
    output -- so surrogate parameters cannot be trained as ordinary
    ``ParamState`` weights through this output.

    """
    __module__ = 'braintools.surrogate'

    def __call__(self, x):
        """Apply the spike function to ``x``.

        The surrogate gradient is evaluated on ``jax.lax.stop_gradient(x)``,
        so ``x`` influences derivatives only through the primitive's own
        differentiation rule and not through ``surrogate_grad``.

        Parameters
        ----------
        x
            Input array.

        Returns
        -------
        The first (and only) output of the Heaviside surrogate primitive.
        """
        dx = self.surrogate_grad(jax.lax.stop_gradient(x))
        return heaviside_p.bind(x, dx)[0]

    def surrogate_fun(self, x) -> jax.Array:
        """The surrogate function.

        Subclasses override this to provide the smooth approximation of the
        step function.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def surrogate_grad(self, x) -> jax.Array:
        """The gradient function of the surrogate function.

        Subclasses must override this; ``__call__`` uses its result as the
        surrogate gradient handed to the primitive.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError