BatchNorm2d
class brainpy.dnn.BatchNorm2d(num_features, axis=(0, 1, 2), epsilon=1e-05, momentum=0.99, affine=True, bias_initializer=ZeroInit, scale_initializer=OneInit(value=1.0), axis_name=None, axis_index_groups=None, mode=None, name=None)
2-D batch normalization [1].
The input data should have shape (b, h, w, c), where b is the batch dimension, h is the height dimension, w is the width dimension, and c is the channel dimension.
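With input laid out as (b, h, w, c) and the default axis=(0, 1, 2), each channel is normalized using statistics computed over the batch and spatial axes. The following is a minimal NumPy sketch of that batch-normalization formula, not BrainPy's implementation. It assumes training-style behavior (batch statistics only, no running averages), uses the biased variance, and the helper name batch_norm_nhwc is hypothetical.

```python
import numpy as np


def batch_norm_nhwc(x, gamma, beta, axis=(0, 1, 2), epsilon=1e-5):
    """Hypothetical helper: normalize x of shape (b, h, w, c) per channel.

    Sketch only: uses the current batch's statistics and the biased variance.
    """
    mean = x.mean(axis=axis, keepdims=True)  # E[x], shape (1, 1, 1, c)
    var = x.var(axis=axis, keepdims=True)    # Var[x] (ddof=0), shape (1, 1, 1, c)
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return x_hat * gamma + beta              # gamma, beta of shape (c,) broadcast over the last axis


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 4, 3))  # (b, h, w, c)
gamma = np.ones(3)   # mirrors the default scale_initializer=OneInit(value=1.0)
beta = np.zeros(3)   # mirrors the default bias_initializer=ZeroInit
y = batch_norm_nhwc(x, gamma, beta)
print(y.mean(axis=(0, 1, 2)))  # close to 0 for every channel
print(y.std(axis=(0, 1, 2)))   # close to 1 for every channel
```

Because the channel axis (the last one) is excluded from axis, every channel gets its own mean and variance, which is why the feature axis must not appear in axis.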
\[y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta\]

where \(\gamma\) and \(\beta\) are the learnable scale and shift applied when affine=True.

Note

This momentum argument is different from the one used in optimizer classes and from the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.

Parameters:
- num_features (int) – C from an expected input of size (B, H, W, C).
- axis (Union[int, Sequence[int]]) – Axes over which the data are normalized. The feature (channel) axis should be excluded.
- epsilon (float) – A value added to the denominator for numerical stability. Default: 1e-5.
- momentum (float) – The value used for the running_mean and running_var computation. Default: 0.99.
- affine (bool) – If True, this module has learnable affine parameters. Default: True.
- bias_initializer (Union[Initializer, TypeVar(ArrayType, Array, Variable, TrainVar, Array, ndarray), Callable]) – An initializer generating the initial shift (\(\beta\)).
- scale_initializer (Union[Initializer, TypeVar(ArrayType, Array, Variable, TrainVar, Array, ndarray), Callable]) – An initializer generating the initial scale (\(\gamma\)).
- axis_name (Union[str, Sequence[str], None]) – If not None, a string (or sequence of strings) naming the axis (or axes) over which this module is run within a JAX map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are computed across all replicas on the named axes.
- axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
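The momentum parameter controls an exponential moving average of the batch statistics, following the update rule \(\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t\). The NumPy sketch below applies that rule for a few steps. The helper name update_running_stats and the starting values (running mean 0, running variance 1) are assumptions made for illustration, not taken from BrainPy.

```python
import numpy as np


def update_running_stats(running_mean, running_var, batch_mean, batch_var, momentum=0.99):
    """Hypothetical helper: new = momentum * old + (1 - momentum) * observed."""
    new_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * running_var + (1.0 - momentum) * batch_var
    return new_mean, new_var


# Assumed starting estimates for 3 channels: mean 0, variance 1.
running_mean, running_var = np.zeros(3), np.ones(3)
# Pretend every batch has the same observed per-channel statistics.
batch_mean, batch_var = np.full(3, 2.0), np.full(3, 5.0)

for _ in range(3):
    running_mean, running_var = update_running_stats(
        running_mean, running_var, batch_mean, batch_var)

print(running_mean)  # each entry is 2 * (1 - 0.99**3), roughly 0.0594
print(running_var)   # each entry is 5 - 4 * 0.99**3, roughly 1.1188
```

With momentum=0.99 the old estimate dominates, so the running statistics move slowly toward the observed values. For comparison, PyTorch's BatchNorm layers weight the new observation by momentum (default 0.1), so momentum=0.99 here corresponds to momentum=0.01 in PyTorch's convention.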
References