BatchNorm1d#
- class brainpy.dnn.BatchNorm1d(num_features, axis=(0, 1), epsilon=1e-05, momentum=0.99, affine=True, bias_initializer=ZeroInit, scale_initializer=OneInit(value=1.0), axis_name=None, axis_index_groups=None, mode=None, name=None)[source]#
1-D batch normalization [1].
The data should be of shape (b, l, c), where b is the batch dimension, l is the layer dimension, and c is the channel dimension; a usage sketch is given below, after the parameter list.
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\operatorname{Var}[x] + \epsilon}} * \gamma + \beta\]

Note

This momentum argument is different from the one used in optimizer classes and from the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.
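To make the update rule concrete, here is a minimal numerical sketch in plain NumPy; the variable names (running_mean, batch_mean) are illustrative only, not BrainPy internals.

```python
import numpy as np

momentum = 0.99
running_mean = np.array(0.5)  # current estimate, \hat{x}
batch_mean = np.array(0.8)    # new observed value, x_t

# \hat{x}_new = momentum * \hat{x} + (1 - momentum) * x_t
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
print(running_mean)  # 0.503: the estimate moves only slightly toward the batch value
```

With momentum close to 1, the running statistics change slowly, which smooths out batch-to-batch noise at the cost of adapting more slowly to shifts in the data distribution.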
- Parameters:
  - num_features (int) – C from an expected input of size (B, L, C).
  - axis (int, tuple, list) – the axes along which the data will be normalized. The feature (channel) axis should be excluded. Default: (0, 1)
  - epsilon (float) – a value added to the denominator for numerical stability. Default: 1e-5
  - momentum (float) – the value used for the running_mean and running_var computation. Default: 0.99
  - affine (bool) – when set to True, this module has learnable affine parameters. Default: True
  - bias_initializer (Initializer, ArrayType, Callable) – an initializer generating the original translation matrix.
  - scale_initializer (Initializer, ArrayType, Callable) – an initializer generating the original scaling matrix.
  - axis_name (optional, str, sequence of str) – if not None, a string (or sequence of strings) representing the axis name(s) over which this module is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.
  - axis_index_groups (optional, sequence) – specifies how devices are grouped. Valid only within jax.pmap collectives.
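The sketch below shows the expected input layout in practice; it assumes the layer is directly callable on an array, as is typical for brainpy.dnn layers, and the shapes used are arbitrary.

```python
import brainpy as bp
import brainpy.math as bm

# Normalize over the batch and length axes (the default axis=(0, 1)),
# keeping the channel axis C as the feature dimension.
layer = bp.dnn.BatchNorm1d(num_features=16)  # C = 16
x = bm.random.randn(32, 10, 16)              # input of shape (B, L, C)
y = layer(x)                                 # output has the same shape as x
print(y.shape)  # (32, 10, 16)
```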
References

[1] Ioffe, S., & Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning (pp. 448–456).