# -*- coding: utf-8 -*-
from typing import Union, Tuple, Sequence, Optional, Callable, List, Any
import jax
import jax.numpy as jnp
import numpy as np
from brainpy import math as bm, check
from brainpy._src.dnn.base import Layer
# Public API of this module (order kept stable for documentation tools).
__all__ = [
    # Generic pooling layers whose window shape is inferred from the input.
    'MaxPool', 'MinPool', 'AvgPool',
    # Fixed-dimension average pooling.
    'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
    # Fixed-dimension max pooling.
    'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
    # Adaptive pooling to a requested output shape.
    'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
    'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
]
class Pool(Layer):
    """Pooling functions are implemented using the ReduceWindow XLA op.

    The reduction performed inside every window is defined by ``init_value``
    and ``computation``; subclasses such as :class:`MaxPool` fix these two.

    Parameters
    ----------
    init_value: scalar
        Initial value of the window reduction (forwarded to
        ``jax.lax.reduce_window``).
    computation: callable
        Binary reduction function applied within each window.
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride.
    padding: str, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped,
        used to infer ``kernel_size`` or ``stride`` if they are an integer.
        Negative values count from the last axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.

    Raises
    ------
    ValueError
        If ``padding`` is an unknown string or a malformed sequence of pairs.
    """

    def __init__(
        self,
        init_value,
        computation,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]],
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        super(Pool, self).__init__(mode=mode, name=name)
        self.init_value = init_value
        self.computation = computation
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.channel_axis = channel_axis

        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        else:
            # Explicit exceptions instead of ``assert`` so that the validation
            # is not stripped when Python runs with ``-O``.
            if not all(isinstance(p, (tuple, list)) for p in padding):
                raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
            if not all(len(p) == 2 for p in padding):
                raise ValueError(f"each entry in padding {padding} must be length 2")

    def update(self, x):
        """Apply the window reduction to ``x`` and return the pooled array."""
        x = bm.as_jax(x)
        window_shape = self._infer_shape(x.ndim, self.kernel_size)
        stride = self._infer_shape(x.ndim, self.stride)
        if isinstance(self.padding, str):
            padding = self.padding
        else:
            # Expand the per-spatial-dimension (low, high) pairs to one pair per input axis.
            padding = self._infer_shape(x.ndim, self.padding, element=(0, 0), element_type=(tuple, list))
        return jax.lax.reduce_window(x,
                                     init_value=self.init_value,
                                     computation=self.computation,
                                     window_dimensions=window_shape,
                                     window_strides=stride,
                                     padding=padding)

    def _infer_shape(self,
                     x_dim: int,
                     size: Union[Any, Sequence[Any]],
                     element: Any = 1,
                     element_type: Union[type, Sequence[type]] = int):
        """Infer shape for pooling window or stride.

        Expands ``size`` into a tuple with one entry per input axis. The channel
        axis (if any) and, in batching mode, the leading batch axis receive
        ``element``.

        Raises
        ------
        ValueError
            If ``channel_axis`` is outside ``[-x_dim, x_dim)``, or ``size`` has an
            incompatible length.
        """
        # Normalize the channel axis to a non-negative index. Every value in
        # [-x_dim, x_dim) is valid, including 0 and -x_dim.
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis

        if isinstance(size, (tuple, list)) and isinstance(size[0], element_type):
            size = tuple(size)
            if len(size) > x_dim:
                raise ValueError(f'Invalid size {size}. Its dimension is bigger than its input.')
            elif len(size) == x_dim:
                return size
            else:
                if isinstance(self.mode, bm.BatchingMode):
                    # the batch axis is never pooled
                    size = (element,) + size
                if len(size) + 1 == x_dim:
                    if channel_axis is None:
                        raise ValueError('"channel_axis" should be provided.')
                    # the channel axis is never pooled
                    size = size[:channel_axis] + (element,) + size[channel_axis:]
                else:
                    raise ValueError(f'size {size} is invalid. Please provide more elements.')
                return size
        else:
            # Scalar ``size``: broadcast to every axis except batch/channel.
            if isinstance(self.mode, bm.BatchingMode):
                return (element,) + tuple((size if d != channel_axis else element) for d in range(1, x_dim))
            else:
                return tuple((size if d != channel_axis else element) for d in range(0, x_dim))
class MaxPool(Pool):
    """Pools the input by taking the maximum over a window.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped,
        used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        # ``-inf`` is the identity element of ``max``.
        super(MaxPool, self).__init__(init_value=-jax.numpy.inf,
                                      computation=jax.lax.max,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      channel_axis=channel_axis,
                                      mode=mode,
                                      name=name)
class MinPool(Pool):
    """Pools the input by taking the minimum over a window.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped,
        used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        # ``+inf`` is the identity element of ``min``.
        super(MinPool, self).__init__(init_value=jax.numpy.inf,
                                      computation=jax.lax.min,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      channel_axis=channel_axis,
                                      mode=mode,
                                      name=name)
class AvgPool(Pool):
    """Pools the input by taking the average over a window.

    With ``'VALID'`` padding each window sum is divided by the window size.
    Otherwise each window sum is divided by the number of valid (non-padding)
    input entries inside that window.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, or a sequence
        of n `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped,
        used to infer ``kernel_size`` or ``stride`` if they are an integer.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = None,
        mode: bm.Mode = None,
        name: Optional[str] = None,
    ):
        # Windows are summed (0 is the identity of ``add``) and divided in ``update``.
        super(AvgPool, self).__init__(init_value=0.,
                                      computation=jax.lax.add,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      channel_axis=channel_axis,
                                      mode=mode,
                                      name=name)

    def update(self, x):
        """Return the window averages of ``x``."""
        x = bm.as_jax(x)
        window_shape = self._infer_shape(x.ndim, self.kernel_size)
        strides = self._infer_shape(x.ndim, self.stride)
        padding = (self.padding if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0), element_type=(tuple, list)))
        pooled = jax.lax.reduce_window(x,
                                       init_value=self.init_value,
                                       computation=self.computation,
                                       window_dimensions=window_shape,
                                       window_strides=strides,
                                       padding=padding)
        if padding == "VALID":
            # Avoid the extra reduce_window: every window is fully inside the input.
            return pooled / np.prod(window_shape)
        # Count the number of valid entries at each input point, then use that for
        # computing average. Assumes that any two arrays of same shape will be
        # padded the same.
        window_counts = jax.lax.reduce_window(jnp.ones_like(x),
                                              init_value=self.init_value,
                                              computation=self.computation,
                                              window_dimensions=window_shape,
                                              window_strides=strides,
                                              padding=padding)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
class _MaxPoolNd(Layer):
    """Base class of the fixed-dimensional pooling layers (``MaxPool1d`` etc.).

    The reduction ``computation`` (starting from ``init_value``) is applied with
    ``jax.lax.reduce_window`` over the last ``pool_dim`` non-channel axes of the
    input. Every other axis (the channel axis and any leading axes) uses a
    window and stride of 1 and ``(0, 0)`` padding.

    Parameters
    ----------
    init_value: scalar
        Initial value of the window reduction.
    computation: callable
        Binary reduction function applied within each window.
    pool_dim: int
        Number of pooled dimensions.
    kernel_size: int, sequence of int
        Window size; an integer is repeated for each pooled dimension.
    stride: int, sequence of int, optional
        Inter-window stride; an integer is repeated for each pooled dimension.
        Defaults to ``kernel_size``.
    padding: str, int, sequence of int, sequence of tuple
        ``'SAME'``, ``'VALID'``, one integer (symmetric padding for every pooled
        dimension), ``pool_dim`` integers (symmetric padding per dimension), or
        ``(low, high)`` pairs -- either a single pair shared by all pooled
        dimensions or ``pool_dim`` pairs.
    channel_axis: int, optional
        Axis that is not pooled. ``None`` means there is no channel axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.

    Raises
    ------
    TypeError
        If ``kernel_size`` or ``stride`` is neither an int nor a sequence.
    ValueError
        If a sequence argument has the wrong length or ``padding`` is malformed.
    """

    def __init__(
        self,
        init_value,
        computation,
        pool_dim: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(name=name, mode=mode)
        self.init_value = init_value
        self.computation = computation
        self.pool_dim = pool_dim

        # kernel_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * pool_dim
        elif isinstance(kernel_size, Sequence):
            check.is_sequence(kernel_size, elem_type=int)
            if len(kernel_size) != pool_dim:
                raise ValueError(f'kernel_size should be a tuple with {pool_dim} ints, but got {len(kernel_size)}')
        else:
            raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.')
        self.kernel_size = kernel_size

        # stride (defaults to non-overlapping windows)
        if stride is None:
            stride = kernel_size
        if isinstance(stride, int):
            stride = (stride,) * pool_dim
        elif isinstance(stride, Sequence):
            check.is_sequence(stride, elem_type=int)
            if len(stride) != pool_dim:
                raise ValueError(f'stride should be a tuple with {pool_dim} ints, but got {len(stride)}')
        else:
            raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.')
        self.stride = stride

        # padding: normalized to a string or to ``pool_dim`` (low, high) pairs
        if isinstance(padding, str):
            if padding not in ("SAME", "VALID"):
                raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
        elif isinstance(padding, int):
            padding = [(padding, padding) for _ in range(pool_dim)]
        elif isinstance(padding, (list, tuple)):
            if len(padding) == 0:
                raise ValueError(f'padding should not be an empty sequence. {padding}')
            if isinstance(padding[0], int):
                if len(padding) == pool_dim:
                    padding = [(x, x) for x in padding]
                else:
                    raise ValueError(f'If padding is a sequence of ints, it '
                                     f'should have the length of {pool_dim}.')
            else:
                if not all(isinstance(x, (tuple, list)) for x in padding):
                    raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}')
                if not all(len(x) == 2 for x in padding):
                    raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ")
                if len(padding) == 1:
                    # one (low, high) pair shared by all pooled dimensions
                    padding = tuple(padding) * pool_dim
                if len(padding) != pool_dim:
                    raise ValueError(f'padding should have the length of {pool_dim}. {padding}')
        else:
            raise ValueError(f'Unsupported padding {padding!r}. Expected a str, an int, '
                             f'or a sequence of ints / (low, high) pairs.')
        self.padding = padding

        # channel_axis
        self.channel_axis = check.is_integer(channel_axis, allow_none=True)

    def update(self, x):
        """Apply the window reduction to ``x`` and return the pooled array.

        Raises
        ------
        ValueError
            If ``x`` has fewer than ``pool_dim`` (+1 with a channel axis) dimensions.
        """
        x = bm.as_jax(x)
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
        window_shape = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding
                   if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))
        return jax.lax.reduce_window(x,
                                     init_value=self.init_value,
                                     computation=self.computation,
                                     window_dimensions=window_shape,
                                     window_strides=stride,
                                     padding=padding)

    def _infer_shape(self, x_dim, inputs, element):
        """Spread the per-pooled-dimension ``inputs`` over all ``x_dim`` axes.

        The last ``pool_dim`` non-channel axes receive ``inputs`` in order;
        every other axis receives ``element``.
        """
        # Normalize the channel axis; any value in [-x_dim, x_dim) is valid.
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x_dim <= channel_axis < x_dim:
                raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
            if channel_axis < 0:
                channel_axis = x_dim + channel_axis

        all_dims = list(range(x_dim))
        if channel_axis is not None:
            all_dims.pop(channel_axis)
        pool_dims = all_dims[-self.pool_dim:]
        results = [element] * x_dim
        for i, dim in enumerate(pool_dims):
            results[dim] = inputs[i]
        return results
class MaxPool1d(_MaxPoolNd):
    """Applies a 1D max pooling over an input signal composed of several input
    planes.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride
        (default: ``kernel_size``).
    padding: str, int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer (symmetric
        padding), or a sequence of n `(low, high)` integer pairs that give the
        padding to apply before and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class MaxPool2d(_MaxPoolNd):
    """Applies a 2D max pooling over an input signal composed of several input
    planes.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride
        (default: ``kernel_size``).
    padding: str, int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer (symmetric
        padding), or a sequence of n `(low, high)` integer pairs that give the
        padding to apply before and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class MaxPool3d(_MaxPoolNd):
    """Applies a 3D max pooling over an input signal composed of several input
    planes.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride
        (default: ``kernel_size``).
    padding: str, int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer (symmetric
        padding), or a sequence of n `(low, high)` integer pairs that give the
        padding to apply before and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = None,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=-jax.numpy.inf,
                         computation=jax.lax.max,
                         pool_dim=3,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class _AvgPoolNd(_MaxPoolNd):
    """Base class of the fixed-dimensional average pooling layers.

    Reuses the argument handling of ``_MaxPoolNd``; ``update`` sums each window
    and divides by the window size (``'VALID'`` padding) or by the number of
    valid, non-padding entries in the window (any other padding).
    """

    def update(self, x):
        """Return the window averages of ``x``.

        Raises
        ------
        ValueError
            If ``x`` has fewer than ``pool_dim`` (+1 with a channel axis) dimensions.
        """
        x = bm.as_jax(x)
        x_dim = self.pool_dim + (0 if self.channel_axis is None else 1)
        if x.ndim < x_dim:
            raise ValueError(f'Expected input with >= {x_dim} dimensions, but got {x.ndim}.')
        dims = self._infer_shape(x.ndim, self.kernel_size, 1)
        stride = self._infer_shape(x.ndim, self.stride, 1)
        padding = (self.padding
                   if isinstance(self.padding, str) else
                   self._infer_shape(x.ndim, self.padding, element=(0, 0)))
        pooled = jax.lax.reduce_window(x,
                                       init_value=self.init_value,
                                       computation=self.computation,
                                       window_dimensions=dims,
                                       window_strides=stride,
                                       padding=padding)
        if padding == "VALID":
            # Avoid the extra reduce_window: every window is fully inside the input.
            return pooled / np.prod(dims)
        # Count the number of valid entries at each input point, then use that for
        # computing average. Assumes that any two arrays of same shape will be
        # padded the same.
        window_counts = jax.lax.reduce_window(jnp.ones_like(x),
                                              init_value=self.init_value,
                                              computation=self.computation,
                                              window_dimensions=dims,
                                              window_strides=stride,
                                              padding=padding)
        assert pooled.shape == window_counts.shape
        return pooled / window_counts
class AvgPool1d(_AvgPoolNd):
    """Applies a 1D average pooling over an input signal composed of several input
    planes.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer (symmetric
        padding), or a sequence of n `(low, high)` integer pairs that give the
        padding to apply before and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=0.,
                         computation=jax.lax.add,
                         pool_dim=1,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class AvgPool2d(_AvgPoolNd):
    """Applies a 2D average pooling over an input signal composed of several input
    planes.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer (symmetric
        padding), or a sequence of n `(low, high)` integer pairs that give the
        padding to apply before and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=0.,
                         computation=jax.lax.add,
                         pool_dim=2,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
class AvgPool3d(_AvgPoolNd):
    """Applies a 3D average pooling over an input signal composed of several input
    planes.

    Parameters
    ----------
    kernel_size: int, sequence of int
        An integer, or a sequence of integers defining the window to reduce over.
    stride: int, sequence of int
        An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`).
    padding: str, int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, an integer (symmetric
        padding), or a sequence of n `(low, high)` integer pairs that give the
        padding to apply before and after each spatial dimension.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    mode: Mode
        The computation mode.
    name: optional, str
        The object name.
    """

    def __init__(
        self,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID",
        channel_axis: Optional[int] = -1,
        mode: bm.Mode = None,
        name: Optional[str] = None
    ):
        super().__init__(init_value=0.,
                         computation=jax.lax.add,
                         pool_dim=3,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         channel_axis=channel_axis,
                         name=name,
                         mode=mode)
def _adaptive_pool1d(x, target_size: int, operation: Callable):
    """Adaptive pool 1D.

    Splits ``x`` into ``target_size`` contiguous, non-overlapping chunks and
    reduces each chunk with ``operation``. When the input size is not divisible
    by ``target_size``, the first ``size % target_size`` chunks hold one extra
    element.

    Args:
      x: The input. Should be a JAX array of shape `(dim,)`.
      target_size: The shape of the output after the pooling operation `(target_size,)`.
      operation: The pooling operation to be performed on the input array.

    Returns:
      A JAX array of shape `(target_size, )`.

    Raises:
      ValueError: If ``target_size`` is not in ``[1, dim]`` (every output
        element needs at least one input element).
    """
    x = bm.as_jax(x)
    size = jnp.size(x)
    # Guard against ZeroDivisionError (target_size == 0) and the opaque reshape
    # failure that occurs when there are fewer inputs than outputs.
    if not 0 < target_size <= size:
        raise ValueError(f'target_size must be in [1, {size}] for an input of size {size}, '
                         f'but got {target_size}.')
    num_head_arrays = size % target_size
    num_block = size // target_size
    if num_head_arrays != 0:
        # The first ``num_head_arrays`` chunks have ``num_block + 1`` elements.
        head_end_index = num_head_arrays * (num_block + 1)
        heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
        tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block))
        outs = jnp.concatenate([heads, tails])
    else:
        outs = jax.vmap(operation)(x.reshape(-1, num_block))
    return outs
def _generate_vmap(fun: Callable, map_axes: List[int]):
map_axes = sorted(map_axes)
for axis in map_axes:
fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis)
return fun
class AdaptivePool(Layer):
    """General N dimensional adaptive down-sampling to a target shape.

    Each of the last ``len(target_shape)`` non-channel axes is reduced, one
    axis at a time, to the requested size with ``operation``.

    Parameters
    ----------
    target_shape: int, sequence of int
        The target output shape.
    num_spatial_dims: int
        The number of spatial dimensions.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    operation: Callable
        The down-sampling operation.
    name: str
        The object name.
    mode: Mode
        The computing mode.

    Raises
    ------
    ValueError
        If ``target_shape`` is neither an int nor a sequence of length
        ``num_spatial_dims``.
    """

    def __init__(
        self,
        target_shape: Union[int, Sequence[int]],
        num_spatial_dims: int,
        operation: Callable,
        channel_axis: Optional[int] = -1,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(name=name, mode=mode)
        self.channel_axis = channel_axis
        self.operation = operation
        if isinstance(target_shape, int):
            self.target_shape = (target_shape,) * num_spatial_dims
        elif isinstance(target_shape, Sequence) and (len(target_shape) == num_spatial_dims):
            self.target_shape = target_shape
        else:
            raise ValueError("`target_shape` must either be an int or tuple of length "
                             f"{num_spatial_dims} containing ints.")

    def update(self, x):
        """Input-output mapping.

        Parameters
        ----------
        x: Array
            Inputs, e.g. a JAX array of shape `(..., dim_1, dim_2, channels)`
            (``channel_axis=-1``) or `(..., dim_1, dim_2)` (``channel_axis=None``).
            The last ``len(target_shape)`` non-channel axes are pooled.

        Returns
        -------
        Array
            ``x`` with each pooled axis resized to the matching ``target_shape`` entry.

        Raises
        ------
        ValueError
            If ``channel_axis`` is out of range or ``x`` has too few dimensions.
        """
        x = bm.as_jax(x)

        # channel axis: any value in [-ndim, ndim) is valid, including 0
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not -x.ndim <= channel_axis < x.ndim:
                raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
            if channel_axis < 0:
                channel_axis = x.ndim + channel_axis

        # input dimension
        num_pool = len(self.target_shape)
        if (x.ndim - (0 if channel_axis is None else 1)) < num_pool:
            raise ValueError(f"Invalid input dimension. Expected >={num_pool} "
                             f"dimensions (channel_axis={self.channel_axis}). "
                             f"But got {x.ndim} dimensions.")

        # pooling dimensions: every axis except the channel axis
        pool_dims = [d for d in range(x.ndim) if d != channel_axis]

        # pool one axis at a time, vmapping the 1D routine over all other axes
        for target, dim in zip(self.target_shape, pool_dims[-num_pool:]):
            other_axes = [j for j in range(x.ndim) if j != dim]
            op = _generate_vmap(_adaptive_pool1d, other_axes)
            x = op(x, target, self.operation)
        return x
class AdaptiveAvgPool1d(AdaptivePool):
    """Adaptive one-dimensional average down-sampling.

    Parameters
    ----------
    target_shape: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The object name.
    mode: Mode
        The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape,
                         channel_axis=channel_axis,
                         num_spatial_dims=1,
                         operation=jnp.mean,
                         name=name,
                         mode=mode)
class AdaptiveAvgPool2d(AdaptivePool):
    """Adaptive two-dimensional average down-sampling.

    Parameters
    ----------
    target_shape: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The object name.
    mode: Mode
        The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape,
                         channel_axis=channel_axis,
                         num_spatial_dims=2,
                         operation=jnp.mean,
                         name=name,
                         mode=mode)
class AdaptiveAvgPool3d(AdaptivePool):
    """Adaptive three-dimensional average down-sampling.

    Parameters
    ----------
    target_shape: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The object name.
    mode: Mode
        The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape,
                         channel_axis=channel_axis,
                         num_spatial_dims=3,
                         operation=jnp.mean,
                         name=name,
                         mode=mode)
class AdaptiveMaxPool1d(AdaptivePool):
    """Adaptive one-dimensional maximum down-sampling.

    Parameters
    ----------
    target_shape: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The object name.
    mode: Mode
        The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape,
                         channel_axis=channel_axis,
                         num_spatial_dims=1,
                         operation=jnp.max,
                         name=name,
                         mode=mode)
class AdaptiveMaxPool2d(AdaptivePool):
    """Adaptive two-dimensional maximum down-sampling.

    Parameters
    ----------
    target_shape: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The object name.
    mode: Mode
        The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape,
                         channel_axis=channel_axis,
                         num_spatial_dims=2,
                         operation=jnp.max,
                         name=name,
                         mode=mode)
class AdaptiveMaxPool3d(AdaptivePool):
    """Adaptive three-dimensional maximum down-sampling.

    Parameters
    ----------
    target_shape: int, sequence of int
        The target output shape.
    channel_axis: int, optional
        Axis of the spatial channels for which pooling is skipped.
        If ``None``, there is no channel axis.
    name: str
        The object name.
    mode: Mode
        The computing mode.
    """

    def __init__(self,
                 target_shape: Union[int, Sequence[int]],
                 channel_axis: Optional[int] = -1,
                 name: Optional[str] = None,
                 mode: Optional[bm.Mode] = None):
        super().__init__(target_shape,
                         channel_axis=channel_axis,
                         num_spatial_dims=3,
                         operation=jnp.max,
                         name=name,
                         mode=mode)