braintools.surrogate.Surrogate

class braintools.surrogate.Surrogate

Base class for surrogate gradient functions.

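This abstract base class defines the interface for surrogate gradient functions used to train spiking neural networks. The spike function (a Heaviside step) has zero derivative everywhere except at the threshold, where it is undefined, so it provides no learning signal. Surrogate gradients keep the binary spike in the forward pass but replace its derivative with the derivative of a smooth approximation during backpropagation.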
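To see why a surrogate is needed, it helps to differentiate the exact step function directly. The minimal sketch below uses plain JAX only; the helper name heaviside is hypothetical and is not part of braintools. Automatic differentiation returns an all-zero gradient, so an optimizer would receive no signal from the spikes.

```python
import jax
import jax.numpy as jnp


def heaviside(x):
    # Hypothetical helper: exact spike function, 1 where x >= 0, else 0.
    return jnp.where(x >= 0, 1.0, 0.0)


x = jnp.array([-1.0, 0.0, 1.0])
spikes = heaviside(x)
# Exact gradient of the summed spike count with respect to x.
exact_grad = jax.grad(lambda v: heaviside(v).sum())(x)
print(spikes)      # [0. 1. 1.]
print(exact_grad)  # [0. 0. 0.]
```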

To define a custom surrogate gradient function, subclass this class and implement the surrogate_grad method. Implementing surrogate_fun as well is optional; it is only needed for visualization and analysis.

surrogate_fun(x)

Defines the smooth surrogate function used for visualization and analysis.

surrogate_grad(x)

Defines the gradient of the surrogate function used during backpropagation.

Examples

>>> import braintools
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> # Create a custom surrogate gradient function
>>> class MySurrogate(braintools.surrogate.Surrogate):
...     def __init__(self, alpha=1.):
...         super().__init__()
...         self.alpha = alpha
...
...     def surrogate_fun(self, x):
...         # Smooth approximation of the step (optional, for visualization)
...         return jnp.tanh(x * self.alpha) * 0.5 + 0.5
...
...     def surrogate_grad(self, x):
...         # Exact derivative of surrogate_fun, used during backpropagation
...         return self.alpha * 0.5 * (1 - jnp.tanh(x * self.alpha) ** 2)
>>>
>>> # Use the custom surrogate
>>> my_surrogate = MySurrogate(alpha=2.0)
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> spikes = my_surrogate(x)  # Forward: Heaviside step
>>> print(spikes)  # [0. 1. 1.]
>>>
>>> # Backward: the gradient with respect to x is surrogate_grad(x)
>>> grads = jax.grad(lambda v: my_surrogate(v).sum())(x)
>>> print(grads)  # approx. [0.0707 1. 0.0707]

Notes

The forward pass always returns the Heaviside step of x (0 or 1), while the backward pass uses the custom surrogate gradient defined in the surrogate_grad method. This straight-through estimator approach enables gradient-based training of spiking neural networks.
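The forward/backward split can be sketched in plain JAX with jax.custom_jvp. This is an illustration of the straight-through technique, not braintools' actual primitive: the names spike, tanh_surrogate_grad and the fixed ALPHA are hypothetical, and the tanh-based slope mirrors MySurrogate from the example above.

```python
import jax
import jax.numpy as jnp

ALPHA = 2.0  # hypothetical fixed sharpness, mirroring MySurrogate(alpha=2.0)


def tanh_surrogate_grad(x):
    # Backward pass: derivative of 0.5 * tanh(ALPHA * x) + 0.5.
    return ALPHA * 0.5 * (1 - jnp.tanh(ALPHA * x) ** 2)


@jax.custom_jvp
def spike(x):
    # Forward pass: exact Heaviside step, 1 where x >= 0, else 0.
    return jnp.where(x >= 0, 1.0, 0.0).astype(x.dtype)


@spike.defjvp
def spike_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Keep the binary output, but propagate tangents through the surrogate slope.
    return spike(x), tanh_surrogate_grad(x) * x_dot


x = jnp.array([-1.0, 0.0, 1.0])
print(spike(x))  # [0. 1. 1.]
grads = jax.grad(lambda v: spike(v).sum())(x)
print(grads)     # equals tanh_surrogate_grad(x)
```

Because ALPHA is a Python constant in this sketch, JAX never differentiates with respect to the surrogate parameter; only the input x carries tangents.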

Implementing surrogate_fun is optional; it is used only for visualization and analysis. Gradient-only surrogates (e.g. ReluGrad, GaussianGrad, MultiGaussianGrad, InvSquareGrad, SlayerGrad) do not define it, so calling surrogate_fun on them raises NotImplementedError. When both methods are defined, surrogate_grad is the exact derivative of surrogate_fun.
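When writing a subclass that defines both methods, a quick way to confirm this relationship is to compare the hand-written surrogate_grad against JAX's automatic derivative of surrogate_fun. The sketch below rewrites the MySurrogate pair from the example as free functions (a hypothetical simplification, with ALPHA fixed at 2.0) and checks them on a grid of inputs.

```python
import jax
import jax.numpy as jnp

ALPHA = 2.0  # same sharpness as MySurrogate(alpha=2.0) in the example


def surrogate_fun(x):
    # Smooth step approximation from the MySurrogate example.
    return jnp.tanh(x * ALPHA) * 0.5 + 0.5


def surrogate_grad(x):
    # Hand-written derivative from the MySurrogate example.
    return ALPHA * 0.5 * (1 - jnp.tanh(x * ALPHA) ** 2)


xs = jnp.linspace(-3.0, 3.0, 61)
# jax.grad differentiates a scalar function; vmap applies it elementwise.
autodiff_grad = jax.vmap(jax.grad(surrogate_fun))(xs)
max_abs_err = float(jnp.max(jnp.abs(autodiff_grad - surrogate_grad(xs))))
print(max_abs_err)  # close to 0 (float32 rounding only)
```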

The input x is expected to be a dimensionless array, typically the membrane potential minus the threshold, with both expressed in the same unit and converted to plain numbers. Passing a unitful brainunit.Quantity is not supported by the underlying primitive.

Differentiating the output with respect to the input x yields the surrogate gradient, as intended. Differentiating with respect to a surrogate parameter (e.g. alpha), however, returns the derivative of the surrogate-gradient function with respect to that parameter, a side effect of the custom JVP rule, rather than the mathematically correct value of 0 for the Heaviside output. Surrogate parameters therefore cannot be trained as ordinary ParamState weights through this output.

__init__()

Methods

__init__()
surrogate_fun(x)      The surrogate function.
surrogate_grad(x)     The gradient of the surrogate function.