Source code for brainpy._src.dnn.normalization

# -*- coding: utf-8 -*-

from typing import Union, Optional, Sequence, Callable

from jax import lax, numpy as jnp

from brainpy._src.context import share
from brainpy import math as bm, check
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.types import ArrayType
from brainpy._src.dnn.base import Layer

__all__ = [
  'BatchNorm1d',
  'BatchNorm2d',
  'BatchNorm3d',
  'BatchNorm1D',
  'BatchNorm2D',
  'BatchNorm3D',

  'LayerNorm',
  'GroupNorm',
  'InstanceNorm',
]


def _square(x):
  """Computes the elementwise square of the absolute value |x|^2."""
  if jnp.iscomplexobj(x):
    return lax.square(lax.real(x)) + lax.square(lax.imag(x))
  else:
    return lax.square(x)


class BatchNorm(Layer):
  r"""Batch Normalization layer [1]_.

  This layer aims to reduce the internal covariate shift of data. It
  normalizes a batch of data by fixing the mean and variance of inputs
  on each feature (channel). Most commonly, the first axis of the data
  is the batch, and the last is the channel. However, users can specify
  the axes to be normalized.

  .. math::
     y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta

  .. note::
      This :attr:`momentum` argument is different from the one used in optimizer
      classes and from the conventional notion of momentum. Mathematically, the
      update rule for running statistics here is
      :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
      where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
      new observed value.

  Parameters
  ----------
  num_features: int
    ``C`` from an expected input of size ``(..., C)``.
  axis: int, tuple, list
    Axes where the data will be normalized. The feature (channel) axis should be excluded.
  momentum: float
    The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
  epsilon: float
    A value added to the denominator for numerical stability. Default: 1e-5
  affine: bool
    A boolean value that when set to ``True``, this module has
    learnable affine parameters. Default: ``True``
  bias_initializer: Initializer, ArrayType, Callable
    An initializer generating the original translation matrix
  scale_initializer: Initializer, ArrayType, Callable
    An initializer generating the original scaling matrix
  axis_name: optional, str, sequence of str
    If not ``None``, it should be a string (or sequence of
    strings) representing the axis name(s) over which this module is being
    run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
    argument means that batch statistics are calculated across all replicas
    on the named axes.
  axis_index_groups: optional, sequence
    Specifies how devices are grouped. Valid
    only within ``jax.pmap`` collectives.
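
  Examples
  --------
  A minimal construction sketch through the concrete subclass ``BatchNorm1d``
  (the feature size, the explicit ``TrainingMode``, and setting the shared
  ``fit`` flag by hand are illustrative assumptions; a trainer normally manages
  the mode and the flag):

  >>> import brainpy as bp
  >>> import brainpy.math as bm
  >>> net = bp.layers.BatchNorm1d(num_features=10, mode=bm.TrainingMode())
  >>> net.running_mean.shape                # one running statistic per feature (channel)
  (10,)
  >>> bp.share.save(fit=True)               # normally set by a trainer
  >>> inputs = bm.random.randn(32, 5, 10)   # (batch, length, channel)
  >>> outputs = net(inputs)                 # normalized over the (batch, length) axes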

  References
  ----------
  .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.

  """
  supported_modes = (bm.BatchingMode, bm.TrainingMode)

  def __init__(
      self,
      num_features: int,
      axis: Union[int, Sequence[int]],
      epsilon: float = 1e-5,
      momentum: float = 0.99,
      affine: bool = True,
      bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(),
      scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(),
      axis_name: Optional[Union[str, Sequence[str]]] = None,
      axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None,
  ):
    super(BatchNorm, self).__init__(name=name, mode=mode)
    # check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name)

    # parameters
    self.num_features = num_features
    self.epsilon = epsilon
    self.momentum = momentum
    self.affine = affine
    self.bias_initializer = bias_initializer
    self.scale_initializer = scale_initializer
    self.axis = (axis,) if jnp.isscalar(axis) else axis
    self.axis_name = axis_name
    self.axis_index_groups = axis_index_groups

    # variables
    self.running_mean = bm.Variable(jnp.zeros(self.num_features))
    self.running_var = bm.Variable(jnp.ones(self.num_features))
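    # Optional learnable affine parameters: a per-feature scale (gamma) and shift (beta).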
    if self.affine:
      assert isinstance(self.mode, bm.TrainingMode)
      self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_features))
      self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_features))

  def _check_input_dim(self, x):
    raise NotImplementedError

  def update(self, x):
    self._check_input_dim(x)

    x = bm.as_jax(x)

    if share.load('fit'):
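      # Training phase: compute batch statistics along the reduced axes and update the running averages.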
      mean = jnp.mean(x, self.axis)
      mean_of_square = jnp.mean(_square(x), self.axis)
      if self.axis_name is not None:
        mean, mean_of_square = jnp.split(
          lax.pmean(jnp.concatenate([mean, mean_of_square]),
                    axis_name=self.axis_name,
                    axis_index_groups=self.axis_index_groups),
          2
        )
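      # Var[x] = E[x^2] - (E[x])^2, clamped at zero to guard against round-off.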
      var = jnp.maximum(0., mean_of_square - _square(mean))
      self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
      self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
    else:
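      # Evaluation phase: normalize with the stored running statistics.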
      mean = self.running_mean.value
      var = self.running_var.value
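    # Reshape the statistics so they broadcast against the input:
    # size 1 on the reduced axes, the true size on the remaining (feature) axes.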
    stats_shape = [(1 if i in self.axis else x.shape[i]) for i in range(x.ndim)]
    mean = mean.reshape(stats_shape)
    var = var.reshape(stats_shape)

    y = x - mean
    mul = lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype))
    if self.affine:
      mul *= self.scale
    y *= mul
    if self.affine:
      y += self.bias
    return y


class BatchNorm1d(BatchNorm):
  r"""1-D batch normalization [1]_.

  The data should be of `(b, l, c)`, where `b` is the batch dimension,
  `l` is the length dimension, and `c` is the channel dimension.

  .. math::
     y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta

  .. note::
      This :attr:`momentum` argument is different from the one used in optimizer
      classes and from the conventional notion of momentum. Mathematically, the
      update rule for running statistics here is
      :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
      where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
      new observed value.

  Parameters
  ----------
  num_features: int
    ``C`` from an expected input of size ``(B, L, C)``.
  axis: int, tuple, list
    Axes where the data will be normalized. The feature (channel) axis should be excluded.
  epsilon: float
    A value added to the denominator for numerical stability. Default: 1e-5
  momentum: float
    The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
  affine: bool
    A boolean value that when set to ``True``, this module has
    learnable affine parameters. Default: ``True``
  bias_initializer: Initializer, ArrayType, Callable
    An initializer generating the original translation matrix
  scale_initializer: Initializer, ArrayType, Callable
    An initializer generating the original scaling matrix
  axis_name: optional, str, sequence of str
    If not ``None``, it should be a string (or sequence of
    strings) representing the axis name(s) over which this module is being
    run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
    argument means that batch statistics are calculated across all replicas
    on the named axes.
  axis_index_groups: optional, sequence
    Specifies how devices are grouped. Valid
    only within ``jax.pmap`` collectives.

  References
  ----------
  .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.

  """

  def __init__(
      self,
      num_features: int,
      axis: Union[int, Sequence[int]] = (0, 1),
      epsilon: float = 1e-5,
      momentum: float = 0.99,
      affine: bool = True,
      bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(),
      scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(),
      axis_name: Optional[Union[str, Sequence[str]]] = None,
      axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None,
  ):
    super(BatchNorm1d, self).__init__(num_features=num_features,
                                      axis=axis,
                                      epsilon=epsilon,
                                      momentum=momentum,
                                      affine=affine,
                                      bias_initializer=bias_initializer,
                                      scale_initializer=scale_initializer,
                                      axis_name=axis_name,
                                      axis_index_groups=axis_index_groups,
                                      mode=mode,
                                      name=name)

  def _check_input_dim(self, x):
    if x.ndim != 3:
      raise ValueError(f"expected 3D input (got {x.ndim}D input)")
    assert x.shape[-1] == self.num_features


class BatchNorm2d(BatchNorm):
  r"""2-D batch normalization [1]_.

  The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
  `h` is the height dimension, `w` is the width dimension, and `c` is the
  channel dimension.

  .. math::
     y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta

  .. note::
      This :attr:`momentum` argument is different from the one used in optimizer
      classes and from the conventional notion of momentum. Mathematically, the
      update rule for running statistics here is
      :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
      where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
      new observed value.

  Parameters
  ----------
  num_features: int
    ``C`` from an expected input of size ``(B, H, W, C)``.
  axis: int, tuple, list
    Axes where the data will be normalized. The feature (channel) axis should be excluded.
  epsilon: float
    A value added to the denominator for numerical stability. Default: 1e-5
  momentum: float
    The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
  affine: bool
    A boolean value that when set to ``True``, this module has
    learnable affine parameters. Default: ``True``
  bias_initializer: Initializer, ArrayType, Callable
    An initializer generating the original translation matrix
  scale_initializer: Initializer, ArrayType, Callable
    An initializer generating the original scaling matrix
  axis_name: optional, str, sequence of str
    If not ``None``, it should be a string (or sequence of
    strings) representing the axis name(s) over which this module is being
    run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
    argument means that batch statistics are calculated across all replicas
    on the named axes.
  axis_index_groups: optional, sequence
    Specifies how devices are grouped. Valid
    only within ``jax.pmap`` collectives.

  References
  ----------
  .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.

  """

  def __init__(
      self,
      num_features: int,
      axis: Union[int, Sequence[int]] = (0, 1, 2),
      epsilon: float = 1e-5,
      momentum: float = 0.99,
      affine: bool = True,
      bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(),
      scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(),
      axis_name: Optional[Union[str, Sequence[str]]] = None,
      axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None,
  ):
    super(BatchNorm2d, self).__init__(num_features=num_features,
                                      axis=axis,
                                      epsilon=epsilon,
                                      momentum=momentum,
                                      affine=affine,
                                      bias_initializer=bias_initializer,
                                      scale_initializer=scale_initializer,
                                      axis_name=axis_name,
                                      axis_index_groups=axis_index_groups,
                                      mode=mode,
                                      name=name)

  def _check_input_dim(self, x):
    if x.ndim != 4:
      raise ValueError(f"expected 4D input (got {x.ndim}D input)")
    assert x.shape[-1] == self.num_features


class BatchNorm3d(BatchNorm):
  r"""3-D batch normalization [1]_.

  The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
  `h` is the height dimension, `w` is the width dimension, `d` is the depth
  dimension, and `c` is the channel dimension.

  .. math::
     y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta

  .. note::
      This :attr:`momentum` argument is different from the one used in optimizer
      classes and from the conventional notion of momentum. Mathematically, the
      update rule for running statistics here is
      :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
      where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
      new observed value.

  Parameters
  ----------
  num_features: int
    ``C`` from an expected input of size ``(B, H, W, D, C)``.
  axis: int, tuple, list
    Axes where the data will be normalized. The feature (channel) axis should be excluded.
  epsilon: float
    A value added to the denominator for numerical stability. Default: 1e-5
  momentum: float
    The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
  affine: bool
    A boolean value that when set to ``True``, this module has
    learnable affine parameters. Default: ``True``
  bias_initializer: Initializer, ArrayType, Callable
    An initializer generating the original translation matrix
  scale_initializer: Initializer, ArrayType, Callable
    An initializer generating the original scaling matrix
  axis_name: optional, str, sequence of str
    If not ``None``, it should be a string (or sequence of
    strings) representing the axis name(s) over which this module is being
    run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
    argument means that batch statistics are calculated across all replicas
    on the named axes.
  axis_index_groups: optional, sequence
    Specifies how devices are grouped. Valid
    only within ``jax.pmap`` collectives.

  References
  ----------
  .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.

  """

  def __init__(
      self,
      num_features: int,
      axis: Union[int, Sequence[int]] = (0, 1, 2, 3),
      epsilon: float = 1e-5,
      momentum: float = 0.99,
      affine: bool = True,
      bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(),
      scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(),
      axis_name: Optional[Union[str, Sequence[str]]] = None,
      axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None,
  ):
    super(BatchNorm3d, self).__init__(num_features=num_features,
                                      axis=axis,
                                      epsilon=epsilon,
                                      momentum=momentum,
                                      affine=affine,
                                      bias_initializer=bias_initializer,
                                      scale_initializer=scale_initializer,
                                      axis_name=axis_name,
                                      axis_index_groups=axis_index_groups,
                                      mode=mode,
                                      name=name)

  def _check_input_dim(self, x):
    if x.ndim != 5:
      raise ValueError(f"expected 5D input (got {x.ndim}D input)")
    assert x.shape[-1] == self.num_features


class LayerNorm(Layer):
  r"""Layer normalization (https://arxiv.org/abs/1607.06450).

  .. math::
     y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

  This layer normalizes data on each example, independently of the batch. More
  specifically, it normalizes data of shape (b, d1, d2, ..., c) on the axes of
  the data dimensions and the channel (d1, d2, ..., c). Unlike batch
  normalization, the scale and bias are assigned to each position (an elementwise
  operation) instead of to the whole channel. If users want to assign a single
  scale and bias to a whole example or a whole channel, please use
  GroupNorm / InstanceNorm.

  Parameters
  ----------
  normalized_shape: int, sequence of int
    The input shape from an expected input of size

    .. math::
      [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
        \times \ldots \times \text{normalized\_shape}[-1]]

    If a single integer is used, it is treated as a singleton list, and this module will
    normalize over the last dimension which is expected to be of that specific size.
  epsilon: float
    A value added to the denominator for numerical stability. Default: 1e-5
  bias_initializer: Initializer, ArrayType, Callable
    An initializer generating the original translation matrix
  scale_initializer: Initializer, ArrayType, Callable
    An initializer generating the original scaling matrix
  elementwise_affine: bool
    A boolean value that when set to ``True``, this module has learnable
    per-element affine parameters initialized to ones (for weights) and
    zeros (for biases). Default: ``True``.

  Examples
  --------
  >>> import brainpy as bp
  >>> import brainpy.math as bm
  >>>
  >>> # NLP Example
  >>> batch, sentence_length, embedding_dim = 20, 5, 10
  >>> embedding = bm.random.randn(batch, sentence_length, embedding_dim)
  >>> layer_norm = bp.layers.LayerNorm(embedding_dim)
  >>> # Activate module
  >>> layer_norm(embedding)
  >>>
  >>> # Image Example
  >>> N, C, H, W = 20, 5, 10, 10
  >>> input = bm.random.randn(N, H, W, C)
  >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
  >>> layer_norm = bp.layers.LayerNorm([H, W, C])
  >>> output = layer_norm(input)
  """

  def __init__(
      self,
      normalized_shape: Union[int, Sequence[int]],
      epsilon: float = 1e-5,
      bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(),
      scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(),
      elementwise_affine: bool = True,
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None
  ):
    super(LayerNorm, self).__init__(name=name, mode=mode)

    self.epsilon = epsilon
    self.bias_initializer = bias_initializer
    self.scale_initializer = scale_initializer
    if isinstance(normalized_shape, int):
      normalized_shape = (normalized_shape,)
    self.normalized_shape = tuple(normalized_shape)
    assert all([isinstance(s, int) for s in normalized_shape]), 'Must be a sequence of integers.'
    self.elementwise_affine = elementwise_affine
    if self.elementwise_affine:
      assert isinstance(self.mode, bm.TrainingMode)
      self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape))
      self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape))

  def update(self, x):
    if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
      raise ValueError(f'Expected the input shape to be '
                       f'(..., {", ".join(str(s) for s in self.normalized_shape)}), '
                       f'but we got {x.shape}')
    axis = tuple(range(0, x.ndim - len(self.normalized_shape)))
    mean = jnp.mean(bm.as_jax(x), axis=axis, keepdims=True)
    variance = jnp.var(bm.as_jax(x), axis=axis, keepdims=True)
    inv = lax.rsqrt(variance + lax.convert_element_type(self.epsilon, x.dtype))
    out = (x - mean) * inv
    if self.elementwise_affine:
      out = self.scale * out + self.bias
    return out


class GroupNorm(Layer):
  r"""Group normalization layer.

  .. math::
     y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

  This layer divides channels into groups and normalizes the features within
  each group. Its computation is also independent of the batch size. The number
  of channels must be divisible by the number of groups. The shape of the data
  should be (b, d1, d2, ..., c), where `b` denotes the batch size and `c`
  denotes the feature (channel) size.

  Parameters
  ----------
  num_groups: int
    The number of groups. It should be a factor of the number of channels.
  num_channels: int
    The number of channels expected in input.
  epsilon: float
    A value added to the denominator for numerical stability. Default: 1e-5
  affine: bool
    A boolean value that when set to ``True``, this module has learnable
    per-channel affine parameters initialized to ones (for weights) and
    zeros (for biases). Default: ``True``.
  bias_initializer: Initializer, ArrayType, Callable
    An initializer generating the original translation matrix
  scale_initializer: Initializer, ArrayType, Callable
    An initializer generating the original scaling matrix

  Examples
  --------
  >>> import brainpy as bp
  >>> import brainpy.math as bm
  >>> input = bm.random.randn(20, 10, 10, 6)
  >>> # Separate 6 channels into 3 groups
  >>> m = bp.layers.GroupNorm(3, 6)
  >>> # Separate 6 channels into 6 groups (equivalent to InstanceNorm)
  >>> m = bp.layers.GroupNorm(6, 6)
  >>> # Put all 6 channels into a single group (equivalent to LayerNorm)
  >>> m = bp.layers.GroupNorm(1, 6)
  >>> # Activating the module
  >>> output = m(input)
  """

  def __init__(
      self,
      num_groups: int,
      num_channels: int,
      epsilon: float = 1e-5,
      affine: bool = True,
      bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(),
      scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(),
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None,
  ):
    super(GroupNorm, self).__init__(name=name, mode=mode)

    if num_channels % num_groups != 0:
      raise ValueError('num_channels must be divisible by num_groups')
    self.num_groups = num_groups
    self.num_channels = num_channels
    self.epsilon = epsilon
    self.affine = affine
    self.bias_initializer = bias_initializer
    self.scale_initializer = scale_initializer
    if self.affine:
      assert isinstance(self.mode, bm.TrainingMode)
      self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_channels))
      self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_channels))

  def update(self, x):
    assert x.shape[-1] == self.num_channels
    origin_shape, origin_dim = x.shape, x.ndim
    group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups)
    x = bm.as_jax(x.reshape(group_shape))
    reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,)
    mean = jnp.mean(x, reduction_axes, keepdims=True)
    var = jnp.var(x, reduction_axes, keepdims=True)
    x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype))
    x = x.reshape(origin_shape)
    if self.affine:
      x = x * lax.broadcast_to_rank(self.scale.value, origin_dim)
      x = x + lax.broadcast_to_rank(self.bias.value, origin_dim)
    return x


class InstanceNorm(GroupNorm):
  r"""Instance normalization layer.

  This layer normalizes the data within each feature. It can be regarded as
  a group normalization layer in which the group size equals 1.

  Parameters
  ----------
  num_channels: int
    The number of channels expected in input.
  epsilon: float
    A value added to the denominator for numerical stability. Default: 1e-5
  affine: bool
    A boolean value that when set to ``True``, this module has learnable
    per-channel affine parameters initialized to ones (for weights) and
    zeros (for biases). Default: ``True``.
  bias_initializer: Initializer, ArrayType, Callable
    An initializer generating the original translation matrix
  scale_initializer: Initializer, ArrayType, Callable
    An initializer generating the original scaling matrix
  """

  def __init__(
      self,
      num_channels: int,
      epsilon: float = 1e-5,
      affine: bool = True,
      bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(),
      scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(),
      mode: Optional[bm.Mode] = None,
      name: Optional[str] = None,
  ):
    super(InstanceNorm, self).__init__(num_channels=num_channels,
                                       num_groups=num_channels,
                                       epsilon=epsilon,
                                       affine=affine,
                                       bias_initializer=bias_initializer,
                                       scale_initializer=scale_initializer,
                                       mode=mode,
                                       name=name)


BatchNorm1D = BatchNorm1d
BatchNorm2D = BatchNorm2d
BatchNorm3D = BatchNorm3d
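

if __name__ == '__main__':
  # A minimal smoke-test sketch (not part of the original module). It shows how the
  # layers defined above can be applied to a 4-D batch of data. The shapes, the
  # explicit ``TrainingMode``, and setting the shared ``fit`` flag by hand are
  # illustrative assumptions; in practice a trainer manages the mode and the flag.
  share.save(fit=True)                   # mark the call as a training-phase forward pass
  xs = bm.random.randn(8, 16, 16, 3)     # (batch, height, width, channel)

  bn = BatchNorm2d(num_features=3, mode=bm.TrainingMode())
  ln = LayerNorm([16, 16, 3], mode=bm.TrainingMode())
  gn = GroupNorm(num_groups=1, num_channels=3, mode=bm.TrainingMode())
  inst = InstanceNorm(num_channels=3, mode=bm.TrainingMode())

  print(bn(xs).shape)    # (8, 16, 16, 3)
  print(ln(xs).shape)    # (8, 16, 16, 3)
  print(gn(xs).shape)    # (8, 16, 16, 3)
  print(inst(xs).shape)  # (8, 16, 16, 3)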