# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Tuple, Sequence, Optional, Callable, List, Any
import jax
import jax.numpy as jnp
import numpy as np
from brainpy import math as bm, check
from brainpy.dnn.base import Layer
# Names exported by ``from <module> import *``: the generic pooling layers,
# their fixed-rank (1d/2d/3d) variants, and the adaptive pooling layers.
__all__ = [
    'MaxPool',
    'MinPool',
    'AvgPool',
    'AvgPool1d',
    'AvgPool2d',
    'AvgPool3d',
    'MaxPool1d',
    'MaxPool2d',
    'MaxPool3d',
    'AdaptiveAvgPool1d',
    'AdaptiveAvgPool2d',
    'AdaptiveAvgPool3d',
    'AdaptiveMaxPool1d',
    'AdaptiveMaxPool2d',
    'AdaptiveMaxPool3d',
]
class Pool(Layer):
    """Generic window-pooling layer built on ``jax.lax.reduce_window``.

    Subclasses supply the reduction (``init_value`` and ``computation``);
    this class expands ``kernel_size``, ``stride`` and ``padding`` to the
    rank of the input and applies the windowed reduction.

    Parameters::

    init_value: scalar
      Initial value of the reduction, forwarded to ``jax.lax.reduce_window``.
    computation: callable
      Binary reduction function, forwarded to ``jax.lax.reduce_window``.
    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window stride.
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        init_value,
        computation,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]],
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        super().__init__(mode=mode, name=name)

        self.init_value = init_value
        self.computation = computation
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.channel_axis = channel_axis

        # Validate with real exceptions: ``assert`` statements disappear
        # when Python runs with ``-O``.
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        else:
            if not all(isinstance(x, (tuple, list)) for x in padding):
                raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
            if not all(len(x) == 2 for x in padding):
                raise ValueError(f"each entry in padding {padding} must be length 2")

    def update(self, x):
        """Apply the windowed reduction to ``x`` and return the result."""
        x = bm.as_jax(x)
        window_shape = self._infer_shape(x.ndim, self.kernel_size)
        stride = self._infer_shape(x.ndim, self.stride)
        if isinstance(self.padding, str):
            padding = self.padding
        else:
            # Dimensions that are not pooled receive ``(0, 0)`` padding.
            padding = self._infer_shape(x.ndim, self.padding,
                                        element=(0, 0),
                                        element_type=(tuple, list))
        return jax.lax.reduce_window(x,
                                     init_value=self.init_value,
                                     computation=self.computation,
                                     window_dimensions=window_shape,
                                     window_strides=stride,
                                     padding=padding)

    def _infer_shape(self,
                     x_dim: int,
                     size: Union[Any, Sequence[Any]],
                     element: Any = 1,
                     element_type: Union[type, Sequence[type]] = int):
        """Expand ``size`` to one entry per input dimension.

        Parameters::

        x_dim: int
          Number of dimensions of the input.
        size: scalar, or sequence of ``element_type``
          The user-provided window, stride or padding specification.
        element: Any
          Neutral entry used for dimensions that are not pooled
          (the batch axis in batching mode and the channel axis).
        element_type: type, sequence of type
          Type of a single entry of ``size`` when it is given as a sequence.

        Returns::

        A tuple with ``x_dim`` entries.

        Raises::

        ValueError: if the channel axis is out of range, or if a sequence
          ``size`` is empty or cannot be matched to the input rank.
        """
        # Normalize the channel axis to a non-negative index. ``-x_dim`` is a
        # valid negative index (it refers to axis 0) and must be accepted.
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis += x_dim

        if isinstance(size, (tuple, list)) and len(size) == 0:
            raise ValueError('size must not be an empty sequence.')

        if isinstance(size, (tuple, list)) and isinstance(size[0], element_type):
            size = tuple(size)
            if len(size) > x_dim:
                raise ValueError(f'Invalid size {size}. Its dimension is bigger than its input.')
            if len(size) == x_dim:
                return size

            if isinstance(self.mode, bm.BatchingMode):
                # The leading axis is the batch axis and is never pooled.
                size = (element,) + size
                if len(size) == x_dim:
                    # Only the batch entry was missing: nothing left to infer.
                    return size

            if len(size) + 1 != x_dim:
                raise ValueError(f'size {size} is invalid. Please provide more elements.')
            if channel_axis is None:
                raise ValueError('"channel_axis" should be provided.')
            # Exactly one entry is missing: the channel axis.
            return size[:channel_axis] + (element,) + size[channel_axis:]

        # Scalar specification: broadcast it over every pooled dimension.
        if isinstance(self.mode, bm.BatchingMode):
            return (element,) + tuple((size if d != channel_axis else element)
                                      for d in range(1, x_dim))
        return tuple((size if d != channel_axis else element)
                     for d in range(0, x_dim))
class MaxPool(Pool):
    """Max pooling: each output entry is the maximum of its input window.

    The reduction starts from ``-inf`` and combines entries with
    ``jax.lax.max``.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window
      stride (default: ``1``, i.e. `(1, ..., 1)`).
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        super().__init__(
            init_value=-jnp.inf,
            computation=jax.lax.max,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class MinPool(Pool):
    """Min pooling: each output entry is the minimum of its input window.

    The reduction starts from ``+inf`` and combines entries with
    ``jax.lax.min``.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window
      stride (default: ``1``, i.e. `(1, ..., 1)`).
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        super().__init__(
            init_value=jnp.inf,
            computation=jax.lax.min,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class AvgPool(Pool):
    """Average pooling: each output entry is the mean of its input window.

    Window sums are computed with ``jax.lax.add`` starting from ``0.`` and
    then divided by the number of entries that contributed to each window.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window
      stride (default: ``1``, i.e. `(1, ..., 1)`).
    padding: str, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, or a sequence
      of n `(low, high)` integer pairs that give the padding to apply before
      and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped,
      used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )

    def update(self, x):
        """Return the windowed average of ``x``."""
        x = bm.as_jax(x)
        ndim = x.ndim
        window_shape = self._infer_shape(ndim, self.kernel_size)
        strides = self._infer_shape(ndim, self.stride)
        if isinstance(self.padding, str):
            padding = self.padding
        else:
            padding = self._infer_shape(ndim, self.padding,
                                        element=(0, 0),
                                        element_type=(tuple, list))

        def window_sum(operand):
            # Same windowed sum for the data and for the entry counter.
            return jax.lax.reduce_window(operand,
                                         init_value=self.init_value,
                                         computation=self.computation,
                                         window_dimensions=window_shape,
                                         window_strides=strides,
                                         padding=padding)

        pooled = window_sum(x)
        if padding == "VALID":
            # No padding: every window is full, so the second
            # reduce_window can be skipped.
            return pooled / np.prod(window_shape)

        # Count the valid (non-padded) entries of every window by pooling an
        # array of ones; this relies on two arrays of the same shape being
        # padded identically.
        window_counts = window_sum(jnp.ones_like(x))
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
class _MaxPoolNd(Layer):
    """Shared implementation of the fixed-rank (1d/2d/3d) pooling layers.

    The ``pool_dim`` pooled dimensions are the last ``pool_dim`` axes of the
    input once the channel axis (if any) has been removed; every other axis
    gets a window of 1, a stride of 1 and ``(0, 0)`` padding.

    Parameters::

    init_value: scalar
      Initial value of the reduction, forwarded to ``jax.lax.reduce_window``.
    computation: callable
      Binary reduction function, forwarded to ``jax.lax.reduce_window``.
    pool_dim: int
      Number of pooled dimensions.
    kernel_size: int, sequence of int
      Window size; an integer is repeated for every pooled dimension.
    stride: int, sequence of int, optional
      Inter-window stride; ``None`` means "same as ``kernel_size``".
    padding: str, int, sequence of int, sequence of tuple
      `'SAME'`, `'VALID'`, a single integer (same padding on both sides of
      every pooled dimension), one integer per pooled dimension, or one
      `(low, high)` pair per pooled dimension (a single pair is repeated).
    channel_axis: int, optional
      Axis for which pooling is skipped. If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.

    Raises::

    TypeError: if ``kernel_size`` or ``stride`` is neither an int nor a sequence.
    ValueError: if ``kernel_size``, ``stride`` or ``padding`` does not match
      ``pool_dim`` or is malformed.
    """

    def __init__(
        self,
        init_value,
        computation,
        pool_dim: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name, mode=mode)

        self.init_value = init_value
        self.computation = computation
        self.pool_dim = pool_dim

        # kernel_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * pool_dim
        elif isinstance(kernel_size, Sequence):
            check.is_sequence(kernel_size, elem_type=int)
            if len(kernel_size) != pool_dim:
                raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, '
                                 f'but got {len(kernel_size)}')
        else:
            raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
        self.kernel_size = kernel_size

        # stride (defaults to non-overlapping windows)
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = (stride,) * pool_dim
        elif isinstance(stride, Sequence):
            check.is_sequence(stride, elem_type=int)
            if len(stride) != pool_dim:
                # Report the length of ``stride`` itself, not of ``kernel_size``.
                raise ValueError(f'stride should be a tuple with {pool_dim} ints, '
                                 f'but got {len(stride)}')
        else:
            raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
        self.stride = stride

        # padding
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        elif isinstance(padding, int):
            padding = [(padding, padding) for _ in range(pool_dim)]
        elif isinstance(padding, (list, tuple)):
            if len(padding) == 0:
                raise ValueError('padding must not be an empty sequence.')
            if isinstance(padding[0], int):
                # One symmetric padding value per pooled dimension.
                if len(padding) == pool_dim:
                    padding = [(x, x) for x in padding]
                else:
                    raise ValueError(f'If padding is a sequence of ints, it '
                                     f'should has the length of {pool_dim}.')
            else:
                if not all(isinstance(x, (tuple, list)) for x in padding):
                    raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
                if not all(len(x) == 2 for x in padding):
                    raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
                if len(padding) == 1:
                    # A single (low, high) pair applies to every pooled dimension.
                    padding = tuple(padding) * pool_dim
                if len(padding) != pool_dim:
                    raise ValueError(f'padding should has the length of {pool_dim}. {padding}')
        else:
            raise ValueError(f'padding should be a str, an int or a sequence, '
                             f'but got {type(padding).__name__}.')
        self.padding = padding

        # channel_axis
        self.channel_axis = check.is_integer(channel_axis, allow_none=True)

    def update(self, x):
        """Apply the windowed reduction to ``x`` and return the result."""
        x = bm.as_jax(x)
        # The input needs the pooled dimensions plus the channel axis, if any.
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
        window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        if isinstance(self.padding, str):
            padding = self.padding
        else:
            padding = self._infer_shape(x.ndim, self.padding, element=(0, 0))
        return jax.lax.reduce_window(x,
                                     init_value=self.init_value,
                                     computation=self.computation,
                                     window_dimensions=window_shape,
                                     window_strides=stride,
                                     padding=padding)

    def _infer_shape(self, x_dim, inputs, element):
        """Spread the per-pooled-dimension ``inputs`` over all ``x_dim`` axes.

        Axes that are not pooled are filled with ``element``. Returns a list
        with ``x_dim`` entries. Raises ``ValueError`` when the channel axis is
        out of range for an ``x_dim``-dimensional input.
        """
        channel_axis = self.channel_axis
        if channel_axis is not None:
            # ``-x_dim`` is a valid negative index (axis 0) and is accepted.
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis += x_dim

        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        # The pooled dimensions are the trailing non-channel axes.
        pool_dims = all_dims[-self.pool_dim:]

        results = [element] * x_dim
        for dim, value in zip(pool_dims, inputs):
            results[dim] = value
        return results
class MaxPool1d(_MaxPoolNd):
    """1D max pooling over one pooled dimension of the input.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
      An integer, or a sequence of integers, representing the inter-window
      stride. If ``None`` (the default), the stride equals ``kernel_size``.
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an integer, or a
      sequence of n `(low, high)` integer pairs that give the padding to
      apply before and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=-jnp.inf,
            computation=jax.lax.max,
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class MaxPool2d(_MaxPoolNd):
    """2D max pooling over two pooled dimensions of the input.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
      An integer, or a sequence of integers, representing the inter-window
      stride. If ``None`` (the default), the stride equals ``kernel_size``.
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an integer, or a
      sequence of n `(low, high)` integer pairs that give the padding to
      apply before and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=-jnp.inf,
            computation=jax.lax.max,
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class MaxPool3d(_MaxPoolNd):
    """3D max pooling over three pooled dimensions of the input.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int, optional
      An integer, or a sequence of integers, representing the inter-window
      stride. If ``None`` (the default), the stride equals ``kernel_size``.
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an integer, or a
      sequence of n `(low, high)` integer pairs that give the padding to
      apply before and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=-jnp.inf,
            computation=jax.lax.max,
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class _AvgPoolNd(_MaxPoolNd):
    """Shared implementation of the fixed-rank (1d/2d/3d) average pooling layers.

    Reuses the argument handling of ``_MaxPoolNd`` and overrides ``update``
    to divide each window sum by the number of entries that contributed to it.
    """

    def update(self, x):
        """Return the windowed average of ``x``.

        Raises::

        ValueError: if ``x`` has fewer dimensions than the pooled dimensions
          plus the channel axis (when one is configured).
        """
        x = bm.as_jax(x)
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
        dims = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        if isinstance(self.padding, str):
            padding = self.padding
        else:
            padding = self._infer_shape(x.ndim, self.padding, element=(0, 0))

        def window_sum(operand):
            # Same windowed sum for the data and for the entry counter.
            return jax.lax.reduce_window(operand,
                                         init_value=self.init_value,
                                         computation=self.computation,
                                         window_dimensions=dims,
                                         window_strides=stride,
                                         padding=padding)

        pooled = window_sum(x)
        if padding == "VALID":
            # No padding: every window is full, so the second
            # reduce_window can be skipped.
            return pooled / np.prod(dims)

        # Count the valid (non-padded) entries of every window by pooling an
        # array of ones; this relies on two arrays of the same shape being
        # padded identically.
        window_counts = window_sum(jnp.ones_like(x))
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
class AvgPool1d(_AvgPoolNd):
    """1D average pooling over one pooled dimension of the input.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window
      stride (default: ``1``, i.e. `(1, ..., 1)`).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an integer, or a
      sequence of n `(low, high)` integer pairs that give the padding to
      apply before and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class AvgPool2d(_AvgPoolNd):
    """2D average pooling over two pooled dimensions of the input.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window
      stride (default: ``1``, i.e. `(1, ..., 1)`).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an integer, or a
      sequence of n `(low, high)` integer pairs that give the padding to
      apply before and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
class AvgPool3d(_AvgPoolNd):
    """3D average pooling over three pooled dimensions of the input.

    Parameters::

    kernel_size: int, sequence of int
      An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
      An integer, or a sequence of integers, representing the inter-window
      stride (default: ``1``, i.e. `(1, ..., 1)`).
    padding: str, int, sequence of tuple
      Either the string `'SAME'`, the string `'VALID'`, an integer, or a
      sequence of n `(low, high)` integer pairs that give the padding to
      apply before and after each spatial dimension.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    mode: Mode
      The computation mode.
    name: optional, str
      The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(
            init_value=0.,
            computation=jax.lax.add,
            pool_dim=3,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            channel_axis=channel_axis,
            mode=mode,
            name=name,
        )
def _adaptive_pool1d(x, target_size: int, operation: Callable):
    """Adaptive pool 1D.

    The input is split into ``target_size`` contiguous blocks. When the input
    length is not a multiple of ``target_size``, the first
    ``size % target_size`` blocks hold one extra element.

    Args:
      x: The input. Should be a JAX array of shape `(dim,)`.
      target_size: The shape of the output after the pooling operation `(target_size,)`.
      operation: The pooling operation to be performed on each block.

    Returns:
      A JAX array of shape `(target_size, )`.

    Raises:
      ValueError: If ``target_size`` is not in ``[1, dim]``. Such values
        cannot be split into non-empty blocks and previously failed with an
        opaque division or reshape error.
    """
    x = bm.as_jax(x)
    size = jnp.size(x)
    if target_size <= 0 or target_size > size:
        raise ValueError(f'target_size must be in [1, {size}] for an input with '
                         f'{size} elements, but got {target_size}.')

    # num_block: base block length; num_head_arrays: how many leading blocks
    # are one element longer.
    num_block, num_head_arrays = divmod(size, target_size)
    pool_blocks = jax.vmap(operation)
    if num_head_arrays == 0:
        return pool_blocks(x.reshape(-1, num_block))

    head_end_index = num_head_arrays * (num_block + 1)
    heads = pool_blocks(x[:head_end_index].reshape(num_head_arrays, -1))
    tails = pool_blocks(x[head_end_index:].reshape(-1, num_block))
    return jnp.concatenate([heads, tails])
def _generate_vmap(fun: Callable, map_axes: List[int]):
    """Wrap ``fun`` in one ``jax.vmap`` per axis listed in ``map_axes``.

    ``fun`` is called with three positional arguments; only the first one is
    mapped, the other two are passed through unchanged. Axes are applied in
    ascending order, and each mapped axis is restored at the same position in
    the output.
    """
    mapped = fun
    for ax in sorted(map_axes):
        mapped = jax.vmap(mapped, in_axes=(ax, None, None), out_axes=ax)
    return mapped
class AdaptivePool(Layer):
    """General N dimensional adaptive down-sampling to a target shape.

    The pooled dimensions are the last ``num_spatial_dims`` axes of the input
    once the channel axis (if any) has been removed. Each of them is reduced
    independently to the requested target length.

    Parameters::

    target_shape: int, sequence of int
      The target output shape. An integer is repeated for every spatial
      dimension.
    num_spatial_dims: int
      The number of spatial dimensions.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    operation: Callable
      The down-sampling operation.
    name: str
      The class name.
    mode: Mode
      The computing mode.

    Raises::

    ValueError: if ``target_shape`` is neither an int nor a sequence of
      length ``num_spatial_dims``.
    """

    def __init__(
        self,
        target_shape: Union[int, Sequence[int]],
        num_spatial_dims: int,
        operation: Callable,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        self.channel_axis = channel_axis
        self.operation = operation
        if isinstance(target_shape, int):
            self.target_shape = (target_shape,) * num_spatial_dims
        elif isinstance(target_shape, Sequence) and (len(target_shape) == num_spatial_dims):
            self.target_shape = target_shape
        else:
            raise ValueError("`target_shape` must either be an int or tuple of length "
                             f"{num_spatial_dims} containing ints.")

    def update(self, x):
        """Input-output mapping.

        Parameters::

        x: Array
          Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)`
          or `(..., dim_1, dim_2)`.

        Raises::

        ValueError: if the channel axis is out of range for ``x`` or ``x`` has
          too few dimensions for the configured target shape.
        """
        x = bm.as_jax(x)

        # Normalize the channel axis to a non-negative index. Test against
        # ``None`` explicitly so that axis 0 is handled like any other axis;
        # ``-x.ndim`` is a valid negative index and is accepted.
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x.ndim <= channel_axis < x.ndim:
                raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
            if channel_axis < 0:
                channel_axis += x.ndim

        # input dimension
        if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape):
            raise ValueError(f"Invalid input dimension. Expect >={len(self.target_shape)} "
                             f"dimensions (channel_axis={self.channel_axis}). "
                             f"But got {x.ndim} dimensions.")

        # pooling dimensions: every axis except the channel axis
        pool_dims = [d for d in range(x.ndim) if d != channel_axis]

        # Pool one dimension at a time; all other axes are vmapped over.
        for target, di in zip(self.target_shape, pool_dims[-len(self.target_shape):]):
            mapped_axes = [j for j in range(x.ndim) if j != di]
            op = _generate_vmap(_adaptive_pool1d, mapped_axes)
            x = op(x, target, self.operation)
        return x
class AdaptiveAvgPool1d(AdaptivePool):
    """Adaptive average down-sampling over one spatial dimension.

    Each block of the input is reduced with ``jnp.mean``.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(
            target_shape=target_shape,
            num_spatial_dims=1,
            operation=jnp.mean,
            channel_axis=channel_axis,
            name=name,
            mode=mode,
        )
class AdaptiveAvgPool2d(AdaptivePool):
    """Adaptive average down-sampling over two spatial dimensions.

    Each block of the input is reduced with ``jnp.mean``.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(
            target_shape=target_shape,
            num_spatial_dims=2,
            operation=jnp.mean,
            channel_axis=channel_axis,
            name=name,
            mode=mode,
        )
class AdaptiveAvgPool3d(AdaptivePool):
    """Adaptive average down-sampling over three spatial dimensions.

    Each block of the input is reduced with ``jnp.mean``.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(
            target_shape=target_shape,
            num_spatial_dims=3,
            operation=jnp.mean,
            channel_axis=channel_axis,
            name=name,
            mode=mode,
        )
class AdaptiveMaxPool1d(AdaptivePool):
    """Adaptive maximum down-sampling over one spatial dimension.

    Each block of the input is reduced with ``jnp.max``.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(
            target_shape=target_shape,
            num_spatial_dims=1,
            operation=jnp.max,
            channel_axis=channel_axis,
            name=name,
            mode=mode,
        )
class AdaptiveMaxPool2d(AdaptivePool):
    """Adaptive maximum down-sampling over two spatial dimensions.

    Each block of the input is reduced with ``jnp.max``.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(
            target_shape=target_shape,
            num_spatial_dims=2,
            operation=jnp.max,
            channel_axis=channel_axis,
            name=name,
            mode=mode,
        )
class AdaptiveMaxPool3d(AdaptivePool):
    """Adaptive maximum down-sampling over three spatial dimensions.

    Each block of the input is reduced with ``jnp.max``.

    Parameters::

    target_shape: int, sequence of int
      The target output shape.
    channel_axis: int, optional
      Axis of the spatial channels for which pooling is skipped.
      If ``None``, there is no channel axis.
    name: str
      The class name.
    mode: Mode
      The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(
            target_shape=target_shape,
            num_spatial_dims=3,
            operation=jnp.max,
            channel_axis=channel_axis,
            name=name,
            mode=mode,
        )