# Source code for brainpy.dnn.pooling

# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Tuple, Sequence, Optional, Callable, List, Any

import jax
import jax.numpy as jnp
import numpy as np

from brainpy import math as bm, check
from brainpy.dnn.base import Layer

# Public names exported by this module (order preserved).
__all__ = [
    # Generic N-dimensional pooling layers.
    'MaxPool', 'MinPool', 'AvgPool',
    # Fixed-dimension average pooling.
    'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
    # Fixed-dimension max pooling.
    'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
    # Adaptive average pooling.
    'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
    # Adaptive max pooling.
    'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
]


class Pool(Layer):
    """Pooling functions are implemented using the ReduceWindow XLA op.

    Parameters::

    init_value: scalar
      The initial value of the reduction, passed to ``jax.lax.reduce_window``.
    computation: callable
      The binary reduction applied over each window, passed to ``jax.lax.reduce_window``.
    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride.
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.

    """

    def __init__(
        self,
        init_value,
        computation,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]],
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        super(Pool, self).__init__(mode=mode, name=name)

        self.init_value = init_value
        self.computation = computation
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.channel_axis = channel_axis
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        else:
            # Raise explicitly (rather than ``assert``) so validation survives ``python -O``.
            if not all(isinstance(x, (tuple, list)) for x in padding):
                raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
            if not all(len(x) == 2 for x in padding):
                raise ValueError(f"each entry in padding {padding} must be length 2")

    def update(self, x):
        """Reduce every window of ``x`` with ``computation`` and return the result."""
        x = bm.as_jax(x)
        window_shape = self._infer_shape(x.ndim, self.kernel_size)
        stride = self._infer_shape(x.ndim, self.stride)
        if isinstance(self.padding, str):
            padding = self.padding
        else:
            padding = self._infer_shape(x.ndim, self.padding, element=(0, 0), element_type=(tuple, list))
        return jax.lax.reduce_window(x,
                                     init_value=self.init_value,
                                     computation=self.computation,
                                     window_dimensions=window_shape,
                                     window_strides=stride,
                                     padding=padding)

    def _infer_shape(self,
                     x_dim: int,
                     size: Union[Any, Sequence[Any]],
                     element: Any = 1,
                     element_type: Union[type, Sequence[type]] = int):
        """Infer shape for pooling window or stride.

        Parameters::

        x_dim: int
          The number of dimensions of the input.
        size: Any, sequence of Any
          A single value applied to every pooled axis, or a sequence of per-axis
          values (each an instance of ``element_type``).
        element: Any
          The value used on axes that are not pooled (the batch axis in
          ``BatchingMode`` and the channel axis).
        element_type: type, tuple of type
          Type of a per-axis entry; used to tell a per-axis sequence from a scalar.

        Returns::

        A tuple with ``x_dim`` entries.
        """

        # channel axis: any axis in [-x_dim, x_dim) is valid; normalise to non-negative.
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis

        if isinstance(size, (tuple, list)) and isinstance(size[0], element_type):
            size = tuple(size)
            if len(size) > x_dim:
                raise ValueError(f'Invalid size {size}. Its dimension is bigger than its input.')
            if len(size) == x_dim:
                return size
            if isinstance(self.mode, bm.BatchingMode):
                size = (element,) + size
                # Without a channel axis, batch axis + spatial sizes already
                # describe every input dimension.
                if channel_axis is None and len(size) == x_dim:
                    return size
            if len(size) + 1 != x_dim:
                raise ValueError(f'size {size} is invalid. Please provide more elements.')
            if channel_axis is None:
                raise ValueError('"channel_axis" should be provided.')
            return size[:channel_axis] + (element,) + size[channel_axis:]

        # Scalar spec: broadcast to every axis except the batch and channel axes.
        if isinstance(self.mode, bm.BatchingMode):
            return (element,) + tuple((size if d != channel_axis else element) for d in range(1, x_dim))
        return tuple((size if d != channel_axis else element) for d in range(0, x_dim))


class MaxPool(Pool):
    """Pools the input by taking the maximum over a window.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        # -inf is the identity element of ``max``.
        super().__init__(
            init_value=-jnp.inf,
            computation=jax.lax.max,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class MinPool(Pool):
    """Pools the input by taking the minimum over a window.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        # +inf is the identity element of ``min``.
        super().__init__(
            init_value=jnp.inf,
            computation=jax.lax.min,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class AvgPool(Pool):
    """Pools the input by taking the average over a window.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        # Window sums: start at 0 and accumulate with ``add``.
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )

    def update(self, x):
        """Average ``x`` over each window.

        With ``'VALID'`` padding every window is full, so the window sum is divided
        by the window volume. Otherwise the sum is divided by the per-window count
        of input entries, obtained by pooling an all-ones array the same way.
        """
        arr = bm.as_jax(x)
        n_dim = arr.ndim
        dims = self._infer_shape(n_dim, self.kernel_size)
        steps = self._infer_shape(n_dim, self.stride)
        if isinstance(self.padding, str):
            pads = self.padding
        else:
            pads = self._infer_shape(n_dim, self.padding, element=(0, 0), element_type=(tuple, list))

        def _window_sum(operand):
            return jax.lax.reduce_window(operand,
                                         init_value=self.init_value,
                                         computation=self.computation,
                                         window_dimensions=dims,
                                         window_strides=steps,
                                         padding=pads)

        totals = _window_sum(arr)
        if pads == "VALID":
            # Constant divisor; skip the second reduce_window.
            return totals / np.prod(dims)
        # Assumes any two arrays of the same shape are padded identically.
        counts = _window_sum(jnp.ones_like(arr))
        assert totals.shape == counts.shape
        return totals / counts
class _MaxPoolNd(Layer):
    """Base class of the fixed-dimension pooling layers built on ``jax.lax.reduce_window``.

    Pooling is applied to the last ``pool_dim`` axes that are not ``channel_axis``.

    Parameters::

    init_value: scalar
      Initial value of the window reduction.
    computation: callable
      Binary reduction applied over each window.
    pool_dim: int
      Number of pooled (spatial) dimensions.
    kernel_size: int, sequence of int
      Window size; an int is repeated ``pool_dim`` times.
    stride: int, sequence of int, optional
      Inter-window stride; ``None`` uses ``kernel_size``.
    padding: str, int, sequence of int, sequence of tuple
      ``'SAME'``, ``'VALID'``, an int (same padding on both sides of every pooled
      dimension), ``pool_dim`` ints (symmetric per-dimension padding), or
      ``(low, high)`` pairs (a single pair is repeated for every pooled dimension).
    channel_axis: int, optional
      Axis skipped by pooling. If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        init_value,
        computation,
        pool_dim: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name, mode=mode)
        self.init_value = init_value
        self.computation = computation
        self.pool_dim = pool_dim

        # kernel_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * pool_dim
        elif isinstance(kernel_size, Sequence):
            check.is_sequence(kernel_size, elem_type=int)
            if len(kernel_size) != pool_dim:
                raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}')
        else:
            raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
        self.kernel_size = kernel_size

        # stride (``None`` means stride == kernel_size)
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = (stride,) * pool_dim
        elif isinstance(stride, Sequence):
            check.is_sequence(stride, elem_type=int)
            if len(stride) != pool_dim:
                # Report the length of ``stride`` (not ``kernel_size``).
                raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(stride)}')
        else:
            raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
        self.stride = stride

        # padding
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        elif isinstance(padding, int):
            padding = [(padding, padding) for _ in range(pool_dim)]
        elif isinstance(padding, (list, tuple)):
            # ``padding and ...`` lets an empty sequence reach the length check below
            # instead of failing with IndexError.
            if padding and isinstance(padding[0], int):
                if len(padding) == pool_dim:
                    padding = [(x, x) for x in padding]
                else:
                    raise ValueError(f'If padding is a sequence of ints, it '
                                     f'should has the length of {pool_dim}.')
            else:
                if not all([isinstance(x, (tuple, list)) for x in padding]):
                    raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
                if not all([len(x) == 2 for x in padding]):
                    raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
                if len(padding) == 1:
                    padding = tuple(padding) * pool_dim
                # Raise explicitly (rather than ``assert``) so it survives ``python -O``.
                if len(padding) != pool_dim:
                    raise ValueError(f'padding should has the length of {pool_dim}. {padding}')
        else:
            raise ValueError(f"Invalid padding {padding}. It must be 'SAME', 'VALID', an int, "
                             f"a sequence of ints, or a sequence of (low, high) pairs.")
        self.padding = padding

        # channel_axis
        self.channel_axis = check.is_integer(channel_axis, allow_none=True)

    def update(self, x):
        """Reduce every window of ``x`` with ``computation`` and return the result."""
        x = bm.as_jax(x)
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.')
        window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))
        return jax.lax.reduce_window(x,
                                     init_value=self.init_value,
                                     computation=self.computation,
                                     window_dimensions=window_shape,
                                     window_strides=stride,
                                     padding=padding)

    def _infer_shape(self, x_dim, inputs, element):
        """Expand per-pooled-axis ``inputs`` to a list of ``x_dim`` entries.

        The last ``pool_dim`` non-channel axes receive ``inputs``; every other axis
        receives ``element``.
        """
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # Any axis in [-x_dim, x_dim) is valid; normalise to non-negative.
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis
        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
class MaxPool1d(_MaxPoolNd):
    """Applies a 1D max pooling over an input signal composed of several input planes.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
      An integer, or a sequence of integers, representing the inter-window stride
      (default: ``None``, which uses ``kernel_size``, i.e. non-overlapping windows).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class MaxPool2d(_MaxPoolNd):
    """Applies a 2D max pooling over an input signal composed of several input planes.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
      An integer, or a sequence of integers, representing the inter-window stride
      (default: ``None``, which uses ``kernel_size``, i.e. non-overlapping windows).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class MaxPool3d(_MaxPoolNd):
    """Applies a 3D max pooling over an input signal composed of several input planes.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
      An integer, or a sequence of integers, representing the inter-window stride
      (default: ``None``, which uses ``kernel_size``, i.e. non-overlapping windows).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=3,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class _AvgPoolNd(_MaxPoolNd):
    """Shared ``update`` for the fixed-dimension average pooling layers.

    Window sums are produced with ``init_value``/``computation`` (``0.`` and
    ``jax.lax.add`` in the concrete subclasses) and then divided by the number of
    contributing entries.
    """

    def update(self, x):
        arr = bm.as_jax(x)
        ndim = arr.ndim
        required = self.pool_dim if self.channel_axis is None else self.pool_dim + 1
        if ndim < required:
            raise ValueError(f'Excepted input with >= {required} dimensions, but got {ndim}.')

        window = self._infer_shape(ndim, self.kernel_size, 1)
        steps = self._infer_shape(ndim, self.stride, 1)
        if isinstance(self.padding, str):
            pads = self.padding
        else:
            pads = self._infer_shape(ndim, self.padding, element=(0, 0))

        def _window_sum(operand):
            return jax.lax.reduce_window(operand,
                                         init_value=self.init_value,
                                         computation=self.computation,
                                         window_dimensions=window,
                                         window_strides=steps,
                                         padding=pads)

        totals = _window_sum(arr)
        if pads == "VALID":
            # Every window is full: divide by the constant window volume.
            return totals / np.prod(window)
        # Count real entries per window by pooling ones with the same padding.
        # Assumes any two arrays of the same shape are padded identically.
        counts = _window_sum(jnp.ones_like(arr))
        assert totals.shape == counts.shape
        return totals / counts
class AvgPool1d(_AvgPoolNd):
    """Applies a 1D average pooling over an input signal composed of several input planes.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class AvgPool2d(_AvgPoolNd):
    """Applies a 2D average pooling over an input signal composed of several input planes.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class AvgPool3d(_AvgPoolNd):
    """Applies a 3D average pooling over an input signal composed of several input planes.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an int, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
def _adaptive_pool1d(x, target_size: int, operation: Callable):
    """Adaptive pool 1D.

    The input is split into ``target_size`` contiguous chunks and each chunk is
    reduced with ``operation``. When ``dim`` is not divisible by ``target_size``,
    the first ``dim % target_size`` chunks hold one extra element.

    Args:
      x: The input. Should be a JAX array of shape `(dim,)`.
      target_size: The shape of the output after the pooling operation `(target_size,)`.
      operation: The pooling operation to be performed on the input array.

    Returns:
      A JAX array of shape `(target_size, )`.

    Raises:
      ValueError: If ``target_size`` is not positive or is larger than ``dim``
        (this scheme cannot up-sample).
    """
    x = bm.as_jax(x)
    size = jnp.size(x)
    if target_size <= 0 or size < target_size:
        # Fail early with a readable message; otherwise this surfaces as a
        # ZeroDivisionError or an opaque reshape error below.
        raise ValueError(f'Cannot adaptively pool an input of size {size} to the '
                         f'target size {target_size}. The target size must be in [1, {size}].')
    num_head_arrays = size % target_size
    num_block = size // target_size
    if num_head_arrays != 0:
        head_end_index = num_head_arrays * (num_block + 1)
        heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
        tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
        outs = jnp.concatenate([heads, tails])
    else:
        outs = jax.vmap(operation)(x.reshape(-1, num_block))
    return outs


def _generate_vmap(fun: Callable, map_axes: List[int]):
    """Wrap ``fun`` in one ``jax.vmap`` per axis in ``map_axes``.

    Only the first positional argument is mapped (``in_axes=(axis, None, None)``).
    Axes are wrapped in ascending order, so the outermost ``vmap`` handles the
    largest axis. Removing that axis leaves the indices of the smaller axes
    unchanged for the inner ``vmap`` calls.
    """
    map_axes = sorted(map_axes)
    for axis in map_axes:
        fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
    return fun


class AdaptivePool(Layer):
    """General N dimensional adaptive down-sampling to a target shape.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    num_spatial_dims: int
      The number of spatial dimensions.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    operation: Callable
      The down-sampling operation.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(
        self,
        target_shape: Union[int, Sequence[int]],
        num_spatial_dims: int,
        operation: Callable,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        self.channel_axis = channel_axis
        self.operation = operation
        if isinstance(target_shape, int):
            self.target_shape = (target_shape,) * num_spatial_dims
        elif isinstance(target_shape, Sequence) and (len(target_shape) == num_spatial_dims):
            self.target_shape = target_shape
        else:
            raise ValueError("`target_size` must either be an int or tuple of length "
                             f"{num_spatial_dims} containing ints.")

    def update(self, x):
        """Input-output mapping.

        Parameters::

        x: Array
          Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
          or `(..., dim_1, dim_2)`.
        """
        x = bm.as_jax(x)

        # channel axis: any axis in [-x.ndim, x.ndim) is valid; normalise to non-negative.
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x.ndim <= channel_axis < x.ndim:
                raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
            if channel_axis < 0:
                channel_axis = x.ndim + channel_axis

        # input dimension
        if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
            raise ValueError(f"Invalid input dimension. Except >={len(self.target_shape)} "
                             f"dimensions (channel_axis={self.channel_axis}). "
                             f"But got {x.ndim} dimensions.")

        # pooling dimensions: the trailing non-channel axes
        pool_dims = list(range(x.ndim))
        if channel_axis is not None:
            pool_dims.pop(channel_axis)

        # Pool one axis at a time, vmapping the 1D routine over every other axis.
        for i, di in enumerate(pool_dims[-len(self.target_shape):]):
            other_axes = [j for j in range(x.ndim) if j != di]
            op = _generate_vmap(_adaptive_pool1d, other_axes)
            x = op(x, self.target_shape[i], self.operation)
        return x
class AdaptiveAvgPool1d(AdaptivePool):
    """Adaptive one-dimensional average down-sampling.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape=target_shape,
                         num_spatial_dims=1,
                         operation=jnp.mean,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class AdaptiveAvgPool2d(AdaptivePool):
    """Adaptive two-dimensional average down-sampling.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape=target_shape,
                         num_spatial_dims=2,
                         operation=jnp.mean,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class AdaptiveAvgPool3d(AdaptivePool):
    """Adaptive three-dimensional average down-sampling.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape=target_shape,
                         num_spatial_dims=3,
                         operation=jnp.mean,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class AdaptiveMaxPool1d(AdaptivePool):
    """Adaptive one-dimensional maximum down-sampling.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape=target_shape,
                         num_spatial_dims=1,
                         operation=jnp.max,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class AdaptiveMaxPool2d(AdaptivePool):
    """Adaptive two-dimensional maximum down-sampling.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape=target_shape,
                         num_spatial_dims=2,
                         operation=jnp.max,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class AdaptiveMaxPool3d(AdaptivePool):
    """Adaptive three-dimensional maximum down-sampling.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape=target_shape,
                         num_spatial_dims=3,
                         operation=jnp.max,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)