braintools.surrogate.Surrogate
- class braintools.surrogate.Surrogate
The base surrogate gradient function.
This abstract base class defines the interface for surrogate gradient functions used in training spiking neural networks. Surrogate gradients replace the non-differentiable spike function with smooth approximations during backpropagation.
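The following standalone JAX sketch shows why a surrogate is needed. It does not use braintools, and the function name `heaviside_spike` is hypothetical. It differentiates an exact step function directly: the forward values are the expected binary spikes, but automatic differentiation returns zero at every input, so no learning signal would reach earlier layers.

```python
import jax
import jax.numpy as jnp

def heaviside_spike(x):
    # Exact spike nonlinearity (hypothetical helper): 1.0 where x >= 0, else 0.0
    return jnp.where(x >= 0.0, 1.0, 0.0)

x = jnp.array([-1.0, 0.0, 1.0])
print(heaviside_spike(x))  # [0. 1. 1.]

# x only enters through the comparison, so autodiff sees no dependence on x
# and the gradient is zero everywhere
grad = jax.grad(lambda v: jnp.sum(heaviside_spike(v)))(x)
print(grad)  # [0. 0. 0.]
```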
To customize a surrogate gradient function, inherit from this class and implement the surrogate_grad method (and, optionally, the surrogate_fun method).
- surrogate_fun(x)
Defines the smooth surrogate function used for visualization and analysis.
- surrogate_grad(x)
Defines the gradient of the surrogate function used during backpropagation.
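The relationship between the two methods can be illustrated with plain JAX functions instead of a Surrogate subclass, so the sketch runs without braintools. A sigmoid is assumed here purely for illustration, and `ALPHA` is a hypothetical steepness parameter. The final check confirms that the hand-written gradient agrees with automatic differentiation of the smooth function.

```python
import jax
import jax.numpy as jnp

ALPHA = 4.0  # hypothetical steepness parameter

def surrogate_fun(x):
    # Smooth stand-in for the step function: a steep sigmoid
    return jax.nn.sigmoid(ALPHA * x)

def surrogate_grad(x):
    # Closed-form derivative of the sigmoid: alpha * s * (1 - s)
    s = jax.nn.sigmoid(ALPHA * x)
    return ALPHA * s * (1.0 - s)

xs = jnp.linspace(-2.0, 2.0, 9)
# Differentiate surrogate_fun elementwise and compare with the closed form
autodiff = jax.vmap(jax.grad(surrogate_fun))(xs)
print(jnp.allclose(autodiff, surrogate_grad(xs), atol=1e-6))  # True
```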
Examples
    >>> import braintools
    >>> import brainstate
    >>> import jax.numpy as jnp
    >>>
    >>> # Create a custom surrogate gradient function
    >>> class MySurrogate(braintools.surrogate.Surrogate):
    ...     def __init__(self, alpha=1.):
    ...         super().__init__()
    ...         self.alpha = alpha
    ...
    ...     def surrogate_fun(self, x):
    ...         # Define the smooth approximation function
    ...         return jnp.tanh(x * self.alpha) * 0.5 + 0.5
    ...
    ...     def surrogate_grad(self, x):
    ...         # Define its gradient for backpropagation
    ...         return self.alpha * 0.5 * (1 - jnp.tanh(x * self.alpha) ** 2)
    >>>
    >>> # Use the custom surrogate
    >>> my_surrogate = MySurrogate(alpha=2.0)
    >>> x = jnp.array([-1.0, 0.0, 1.0])
    >>> spikes = my_surrogate(x)  # Forward: step function
    >>> print(spikes)  # [0., 1., 1.]
Notes
The forward pass always returns the output of a Heaviside step function (0 or 1), while the backward pass uses the custom surrogate gradient defined in the surrogate_grad method. This straight-through estimator approach enables gradient-based training of spiking neural networks.
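A minimal sketch of this forward/backward split, written directly with `jax.custom_jvp` rather than with braintools' own primitive (whose internals are not reproduced here). It reuses the tanh-based gradient and `alpha=2.0` from the example above; `spike` and `ALPHA` are hypothetical names.

```python
import jax
import jax.numpy as jnp

ALPHA = 2.0  # same steepness as MySurrogate(alpha=2.0) in the example

def surrogate_grad(x):
    # Derivative of 0.5 * tanh(alpha * x) + 0.5, as in the example's surrogate_grad
    return ALPHA * 0.5 * (1.0 - jnp.tanh(ALPHA * x) ** 2)

@jax.custom_jvp
def spike(x):
    # Forward pass: exact Heaviside step (1.0 where x >= 0, else 0.0)
    return jnp.where(x >= 0.0, 1.0, 0.0).astype(x.dtype)

@spike.defjvp
def spike_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Keep the binary forward value, but replace the step's zero derivative
    # with the surrogate gradient
    return spike(x), surrogate_grad(x) * x_dot

x = jnp.array([-1.0, 0.0, 1.0])
print(spike(x))  # [0. 1. 1.]
g = jax.grad(lambda v: jnp.sum(spike(v)))(x)
print(jnp.allclose(g, surrogate_grad(x)))  # True
```

Because the JVP rule multiplies the tangent by `surrogate_grad(x)`, `jax.grad` returns exactly the surrogate gradient at each input, while the forward values stay binary.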
Implementing surrogate_fun is optional; it is used only for visualization and analysis. Gradient-only surrogates (e.g. ReluGrad, GaussianGrad, MultiGaussianGrad, InvSquareGrad, SlayerGrad) do not define it, so calling surrogate_fun on them raises NotImplementedError. When both methods are defined, surrogate_grad is the exact derivative of surrogate_fun.

The input x is expected to be a dimensionless array (typically the membrane potential minus the threshold, with both expressed in matching units). Passing a unitful brainunit.Quantity is not supported by the underlying primitive.

Differentiating the output with respect to the input x yields the surrogate gradient, as intended. Differentiating with respect to a surrogate parameter (e.g. alpha) instead returns the derivative of the surrogate-gradient function with respect to that parameter (a side effect of the custom JVP rule), not the mathematically true value of 0 for the Heaviside output. Surrogate parameters therefore cannot be trained as ordinary ParamState weights through this output.
- __init__()
Methods
- __init__()
- surrogate_fun(x): The surrogate function.
- surrogate_grad(x): The gradient function of the surrogate function.