Unflatten

class brainpy.dnn.Unflatten(dim, sizes, mode=None, name=None)

Unflattens dimension dim of a tensor, expanding it into a desired shape. Intended for use with Sequential.

  • dim specifies the dimension of the input tensor to be unflattened, given as an int.

  • sizes is the new shape of the unflattened dimension of the tensor, given as a tuple or list of ints.
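Conceptually, unflattening is a reshape that replaces one axis with several axes whose sizes multiply to the original size. The sketch below expresses this with NumPy's reshape. The helper name unflatten_np is hypothetical and not part of BrainPy; it illustrates the semantics only and makes no claim about how BrainPy implements the layer.

```python
import numpy as np


def unflatten_np(x, dim, sizes):
    """Hypothetical helper: split axis `dim` of `x` into the axes listed in `sizes`."""
    x = np.asarray(x)
    # Map a negative axis (e.g. -1 for the last axis) to its positive index.
    axis = dim + x.ndim if dim < 0 else dim
    if not 0 <= axis < x.ndim:
        raise IndexError(f"dim {dim} is out of range for a {x.ndim}-D array")
    # Keep the axes before and after `axis`; replace axis `axis` with `sizes`.
    new_shape = x.shape[:axis] + tuple(sizes) + x.shape[axis + 1:]
    # With positive sizes, reshape raises ValueError unless prod(sizes) == x.shape[axis].
    return x.reshape(new_shape)


x = np.arange(100).reshape(2, 50)
y = unflatten_np(x, 1, (2, 5, 5))
print(y.shape)  # (2, 2, 5, 5)
# Flattening the new axes back recovers the original array.
print(np.array_equal(y.reshape(2, 50), x))  # True
```

Because only the shape changes, the element order is preserved, which is why reshaping back gives the original array.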

Shape:
  • Input: \((*, S_{\text{dim}}, *)\), where \(S_{\text{dim}}\) is the size at dimension dim and \(*\) means any number of dimensions including none.

  • Output: \((*, U_1, \ldots, U_n, *)\), where \(U = (U_1, \ldots, U_n)\) is sizes and \(\prod_{i=1}^n U_i = S_{\text{dim}}\).
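The shape rule can be checked with plain integer arithmetic. The sketch below uses a hypothetical helper, unflatten_shape (not a BrainPy function), that computes the output shape from the input shape and rejects sizes for which \(\prod_{i=1}^n U_i \neq S_{\text{dim}}\).

```python
from math import prod
from typing import Sequence, Tuple


def unflatten_shape(in_shape: Sequence[int], dim: int,
                    sizes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape when axis `dim` of `in_shape` is split into `sizes`."""
    ndim = len(in_shape)
    # Negative dims count from the end, as in the usual Python indexing convention.
    axis = dim + ndim if dim < 0 else dim
    if not 0 <= axis < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-D input")
    # Enforce prod(U_i) == S_dim.
    if prod(sizes) != in_shape[axis]:
        raise ValueError(
            f"sizes {tuple(sizes)} multiply to {prod(sizes)}, expected {in_shape[axis]}")
    # (*, S_dim, *) -> (*, U_1, ..., U_n, *)
    return tuple(in_shape[:axis]) + tuple(sizes) + tuple(in_shape[axis + 1:])


print(unflatten_shape((2, 50), 1, (2, 5, 5)))  # (2, 2, 5, 5)
```

For an input of shape (2, 50) with dim=1, \(S_{\text{dim}} = 50 = 2 \cdot 5 \cdot 5\), so sizes (2, 5, 5) is valid, while sizes (3, 5) would raise ValueError because \(3 \cdot 5 = 15 \neq 50\).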

Parameters:
  • dim (int) – Dimension to be unflattened.

  • sizes (Sequence[int]) – New shape of the unflattened dimension.
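The parameter descriptions state that sizes should be a tuple or list of ints. A minimal validation sketch under that assumption follows. check_sizes is a hypothetical helper, BrainPy's own checks may differ in wording or strictness, and rejecting non-positive entries is an extra assumption of this sketch.

```python
from typing import Sequence, Tuple


def check_sizes(sizes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: validate `sizes` and return it as a tuple of ints."""
    if not isinstance(sizes, (tuple, list)):
        raise TypeError(f"sizes must be a tuple or list, got {type(sizes).__name__}")
    for pos, elem in enumerate(sizes):
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(elem, bool) or not isinstance(elem, int):
            raise TypeError(
                f"sizes must contain ints, found {type(elem).__name__} at position {pos}")
        # Assumption of this sketch: every new axis must have a positive length.
        if elem <= 0:
            raise ValueError(f"sizes must be positive, found {elem} at position {pos}")
    return tuple(sizes)


print(check_sizes([2, 5, 5]))  # (2, 5, 5)
```

Normalizing to a tuple lets the result be concatenated directly with slices of an array's shape.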

Examples

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> x = bm.random.randn(2, 50)
>>> # With a tuple of ints
>>> m = bp.Sequential(
...     bp.dnn.Linear(50, 50),
...     bp.dnn.Unflatten(1, (2, 5, 5))
... )
>>> output = m(x)
>>> output.shape
(2, 2, 5, 5)
>>> # With a list of ints
>>> m = bp.Sequential(
...     bp.dnn.Linear(50, 50),
...     bp.dnn.Unflatten(1, [2, 5, 5])
... )
>>> output = m(x)
>>> output.shape
(2, 2, 5, 5)
update(x)

The function to specify the updating rule: unflattens dimension dim of the input x into the shape given by sizes and returns the result.
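To illustrate what update(x) computes inside a pipeline like the Sequential examples, here is a framework-free sketch. UnflattenSketch, run_pipeline and linear are hypothetical stand-ins for bp.dnn.Unflatten, bp.Sequential and bp.dnn.Linear(50, 50); they use NumPy instead of brainpy.math, omit the mode and name arguments, and make no claim about how BrainPy handles batching modes.

```python
import numpy as np


class UnflattenSketch:
    """Hypothetical stand-in for brainpy.dnn.Unflatten (NumPy only, no modes)."""

    def __init__(self, dim, sizes):
        self.dim = dim
        self.sizes = tuple(sizes)

    def update(self, x):
        # The updating rule: split axis `dim` of x into the axes in `self.sizes`.
        x = np.asarray(x)
        dim = self.dim + x.ndim if self.dim < 0 else self.dim
        return x.reshape(x.shape[:dim] + self.sizes + x.shape[dim + 1:])

    def __call__(self, x):
        # Make the layer callable so it can sit in a simple pipeline.
        return self.update(x)


def run_pipeline(layers, x):
    """Hypothetical stand-in for bp.Sequential: apply callables in order."""
    for layer in layers:
        x = layer(x)
    return x


rng = np.random.default_rng(0)
w = rng.standard_normal((50, 50))
b = np.zeros(50)


def linear(x):
    # Hypothetical dense layer x @ w + b, playing the role of bp.dnn.Linear(50, 50).
    return x @ w + b


out = run_pipeline([linear, UnflattenSketch(1, [2, 5, 5])], rng.standard_normal((2, 50)))
print(out.shape)  # (2, 2, 5, 5)
```

The dense layer keeps the shape (2, 50), and the unflatten step then splits axis 1 into (2, 5, 5), matching the shapes in the doctest examples.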