brainpy.math.surrogate.Surrogate

class brainpy.math.surrogate.Surrogate

The base surrogate gradient function.

To define a custom surrogate gradient function, subclass this class and implement the surrogate_fun and surrogate_grad methods.

Examples

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> import jax.numpy as jnp
>>> class MySurrogate(bm.surrogate.Surrogate):
...   def __init__(self, alpha=1.):
...     super().__init__()
...     self.alpha = alpha
...
...   def surrogate_fun(self, x):
...     return jnp.sin(x) * self.alpha
...
...   def surrogate_grad(self, x):
...     return jnp.cos(x) * self.alpha

__init__()

Methods

__init__()

surrogate_fun(x)

The surrogate function.

surrogate_grad(x)

The gradient function of the surrogate function.
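
Since surrogate_grad is described as the gradient of surrogate_fun, a quick consistency check between the two can catch mistakes in a custom subclass. The following is a minimal sketch in plain Python, not part of brainpy: SigmoidSurrogate and max_grad_error are hypothetical names, and the class deliberately does not inherit from Surrogate so that the check runs without brainpy installed. It compares the analytic gradient against a central finite difference.

```python
import math


class SigmoidSurrogate:
    """Hypothetical stand-in that mirrors the two-method interface (not a brainpy class)."""

    def __init__(self, alpha=4.0):
        self.alpha = alpha

    def surrogate_fun(self, x):
        # Smooth replacement for the step function: sigmoid(alpha * x).
        return 1.0 / (1.0 + math.exp(-self.alpha * x))

    def surrogate_grad(self, x):
        # Analytic derivative of surrogate_fun with respect to x.
        s = self.surrogate_fun(x)
        return self.alpha * s * (1.0 - s)


def max_grad_error(surrogate, points, eps=1e-5):
    """Return the largest gap between surrogate_grad and a central finite difference."""
    worst = 0.0
    for x in points:
        numeric = (
            surrogate.surrogate_fun(x + eps) - surrogate.surrogate_fun(x - eps)
        ) / (2.0 * eps)
        worst = max(worst, abs(numeric - surrogate.surrogate_grad(x)))
    return worst


sg = SigmoidSurrogate(alpha=4.0)
error = max_grad_error(sg, [-1.0, -0.25, 0.0, 0.25, 1.0])
print(f"max |numeric - analytic| = {error:.2e}")
```

The same idea should carry over to a real subclass by evaluating its surrogate_fun and surrogate_grad on a few sample inputs; a large discrepancy suggests the two methods are out of sync.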