BatchNorm3d

class brainpy.dnn.BatchNorm3d(num_features, axis=(0, 1, 2, 3), epsilon=1e-05, momentum=0.99, affine=True, bias_initializer=ZeroInit, scale_initializer=OneInit(value=1.0), axis_name=None, axis_index_groups=None, mode=None, name=None)

3-D batch normalization [1].

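The input should have shape (b, h, w, d, c), where b is the batch dimension, h the height, w the width, d the depth, and c the channel dimension. With the default axis=(0, 1, 2, 3), the statistics are computed over the batch and all three spatial dimensions, separately for each channel.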

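\[y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} \cdot \gamma+\beta\]

where \(\mathrm{E}[x]\) and \(\operatorname{Var}[x]\) are computed per channel over the normalized axes, and \(\gamma\) and \(\beta\) are the learnable scale and bias (used when affine=True).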
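As a rough illustration of the formula and the (b, h, w, d, c) layout, here is a plain NumPy sketch; it is not BrainPy code. It assumes the biased (ddof=0) variance and per-channel \(\gamma\), \(\beta\) of shape (c,), and the function name batch_norm_3d is hypothetical.

```python
import numpy as np

def batch_norm_3d(x, gamma, beta, axis=(0, 1, 2, 3), epsilon=1e-5):
    """Normalize a (b, h, w, d, c) array per channel, then scale by gamma and shift by beta."""
    mean = x.mean(axis=axis, keepdims=True)  # E[x]: one value per channel
    var = x.var(axis=axis, keepdims=True)    # Var[x]: biased (ddof=0) estimate, an assumption
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    # gamma and beta have shape (c,) and broadcast along the last (channel) axis
    return x_hat * gamma + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 5, 6, 7, 8))  # b=4, h=5, w=6, d=7, c=8
c = x.shape[-1]
y = batch_norm_3d(x, gamma=np.ones(c), beta=np.zeros(c))
print(y.mean(axis=(0, 1, 2, 3)))  # per-channel means, close to 0
print(y.var(axis=(0, 1, 2, 3)))   # per-channel variances, close to 1
```

Because the reduction runs over every axis except the last, each of the c channels ends up with its own mean and variance; \(\epsilon\) keeps the result slightly below unit variance.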

Note

This momentum argument differs from the one used in optimizer classes and from the conventional notion of momentum. Mathematically, the running statistics are updated as \(\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the newly observed value. A larger momentum therefore makes the running statistics change more slowly.
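The update rule in this note can be checked with a small NumPy sketch. The helper update_running_stats is hypothetical, and the starting values (zeros for the running mean, ones for the running variance) and the ddof=0 batch variance are assumptions, not confirmed BrainPy internals.

```python
import numpy as np

def update_running_stats(running_mean, running_var, x, momentum=0.99, axis=(0, 1, 2, 3)):
    """One update step: new = momentum * old + (1 - momentum) * observed."""
    batch_mean = x.mean(axis=axis)  # observed per-channel mean
    batch_var = x.var(axis=axis)    # observed per-channel variance (ddof=0, an assumption)
    new_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * running_var + (1.0 - momentum) * batch_var
    return new_mean, new_var

c = 3
# Assumed starting values: zeros for the running mean, ones for the running variance.
running_mean, running_var = np.zeros(c), np.ones(c)
x = np.full((2, 4, 4, 4, c), 5.0)  # constant batch: mean 5.0, variance 0.0 in every channel

running_mean, running_var = update_running_stats(running_mean, running_var, x)
print(running_mean, running_var)  # about 0.05 and 0.99 per channel after one step
```

With momentum=0.99, a running statistic fed a constant value closes only about 63% of the gap after roughly 100 steps, since \(1 - 0.99^{100} \approx 0.63\).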

Parameters:
  • num_features (int) – Number of channels, C, of an expected input of shape (B, H, W, D, C).

  • axis (int, tuple, list) – Axes over which the mean and variance are computed. The feature (channel) axis should be excluded. Default: (0, 1, 2, 3)

  • epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5

  • momentum (float) – Decay factor for the exponential moving averages running_mean and running_var (see the Note above). Default: 0.99

  • affine (bool) – If True, this module has learnable affine parameters (scale γ and bias β). Default: True

  • bias_initializer (Initializer, ArrayType, Callable) – Initializer for the per-channel bias (translation) parameter β.

  • scale_initializer (Initializer, ArrayType, Callable) – Initializer for the per-channel scale parameter γ.

  • axis_name (optional, str, sequence of str) – If not None, a string (or sequence of strings) naming the mapped axis (or axes) over which this module runs inside a JAX transformation such as jax.pmap or jax.vmap. Batch statistics are then computed across all replicas on the named axes.

  • axis_index_groups (optional, sequence) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
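To illustrate what computing batch statistics "across all replicas" means for axis_name, the NumPy sketch below simulates four replicas without using JAX. It assumes equal shard sizes and that replicas combine per-replica first and second moments by averaging, as an implementation based on jax.lax.pmean would; this is an assumption, not a description of BrainPy's internals.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical setup: 4 replicas, each holding an equal-sized shard of the batch.
shards = rng.normal(size=(4, 2, 3, 3, 3, 5))  # (replica, b, h, w, d, c)

# Each replica computes per-channel first and second moments over its own
# (b, h, w, d) axes, which are axes 1-4 here because axis 0 indexes the replica.
local_mean = shards.mean(axis=(1, 2, 3, 4))             # shape (4, c)
local_sq_mean = (shards ** 2).mean(axis=(1, 2, 3, 4))   # shape (4, c)

# Averaging over the replica axis (the role of jax.lax.pmean under
# jax.pmap/jax.vmap) gives the statistics of the whole global batch.
global_mean = local_mean.mean(axis=0)
global_var = local_sq_mean.mean(axis=0) - global_mean ** 2

full_batch = shards.reshape(8, 3, 3, 3, 5)  # concatenate shards along the batch axis
print(np.allclose(global_mean, full_batch.mean(axis=(0, 1, 2, 3))))  # True
print(np.allclose(global_var, full_batch.var(axis=(0, 1, 2, 3))))    # True
```

Averaging second moments rather than local variances matters here: the mean of the per-replica variances would ignore the spread between replica means and underestimate the global variance.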

References