# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import warnings
from typing import Union, Tuple, Optional, Sequence, Callable

import brainstate
from jax import lax

from brainpy import math as bm, tools
from brainpy.dnn.base import Layer
from brainpy.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
# Public API of this module.
__all__ = [
    # Forward convolutions.
    'Conv1d',
    'Conv2d',
    'Conv3d',
    # Upper-case aliases of the forward convolutions.
    'Conv1D',
    'Conv2D',
    'Conv3D',
    # Transposed convolutions.
    'ConvTranspose1d',
    'ConvTranspose2d',
    'ConvTranspose3d',
]
def to_dimension_numbers(num_spatial_dims: int,
                         channels_last: bool,
                         transpose: bool) -> lax.ConvDimensionNumbers:
    """Build the ``lax.ConvDimensionNumbers`` describing input/kernel/output layouts.

    Args:
      num_spatial_dims: Number of spatial dimensions ``N`` of the convolution.
      channels_last: If True, images are laid out as ``(batch, *spatial, channels)``;
        otherwise as ``(batch, channels, *spatial)``.
      transpose: Selects the order of the two trailing kernel axes. When False the
        last kernel axis is the output-feature axis and the one before it the
        input-feature axis; when True that order is swapped.

    Returns:
      A ``lax.ConvDimensionNumbers`` whose ``lhs_spec`` and ``out_spec`` are the
      same image layout and whose ``rhs_spec`` describes the kernel layout.
    """
    total = num_spatial_dims + 2
    spatial_axes = tuple(range(num_spatial_dims))

    # Image spec: (batch axis, feature axis, spatial axes...).
    if channels_last:
        image_spec = (0, total - 1) + tuple(axis + 1 for axis in spatial_axes)
    else:
        image_spec = (0, 1) + tuple(axis + 2 for axis in spatial_axes)

    # Kernel spec: (out-feature axis, in-feature axis, spatial axes...); spatial axes lead.
    last, second_last = total - 1, total - 2
    feature_axes = (second_last, last) if transpose else (last, second_last)
    kernel_spec = feature_axes + spatial_axes

    return lax.ConvDimensionNumbers(lhs_spec=image_spec,
                                    rhs_spec=kernel_spec,
                                    out_spec=image_spec)
class _GeneralConv(Layer):
    """Apply an N-dimensional, channels-last convolution to the inputs.

    The kernel has shape ``kernel_size + (in_channels // groups, out_channels)``
    and the optional bias has shape ``(1,) * num_spatial_dims + (out_channels,)``.
    Inputs may be batched (``num_spatial_dims + 2`` dims) or unbatched
    (``num_spatial_dims + 1`` dims); the channel axis is always the last one.

    Parameters
    ----------
    num_spatial_dims: int
        The number of spatial dimensions of the input.
    in_channels: int
        The number of input channels. Must be divisible by ``groups``.
    out_channels: int
        The number of output channels. Must be divisible by ``groups``.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel. Normalized with ``tools.replicate``
        to ``num_spatial_dims`` entries.
    stride: int, sequence of int
        An integer or a sequence of ``n`` integers, representing the
        inter-window strides (default: 1).
    padding: str, int, tuple of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int (the same
        amount before and after every spatial dimension), a single
        ``(low, high)`` pair (applied to every spatial dimension), a length-1
        sequence of ``(low, high)`` pairs (broadcast to every spatial
        dimension), or a sequence of ``num_spatial_dims`` ``(low, high)`` pairs.
    lhs_dilation: int, sequence of int
        Dilation factor applied to each spatial dimension of the inputs
        (default: 1). Convolution with input dilation ``d`` is equivalent to
        transposed convolution with stride ``d``.
    rhs_dilation: int, sequence of int
        Dilation factor applied to each spatial dimension of the kernel
        (default: 1), also known as 'atrous convolution'.
    groups: int
        Divides the input features into groups (``feature_group_count``). Default 1.
    w_initializer: Callable, ArrayType, Initializer
        The initializer for the convolutional kernel.
    b_initializer: Optional, Callable, ArrayType, Initializer
        The initializer for the bias. ``None`` disables the bias.
    mask: ArrayType, Optional
        Optional mask multiplied into the kernel on every call. Its shape must
        broadcast against the kernel shape.
    mode: Mode
        The computation mode, forwarded to ``Layer.__init__``. In
        ``bm.TrainingMode`` the kernel and bias are wrapped in ``bm.TrainVar``.
    name: str, Optional
        The name of the object.
    """

    supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode)

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super(_GeneralConv, self).__init__(name=name, mode=mode)
        self.num_spatial_dims = num_spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = tools.replicate(stride, num_spatial_dims, 'stride')
        self.kernel_size = tools.replicate(kernel_size, num_spatial_dims, 'kernel_size')
        self.lhs_dilation = tools.replicate(lhs_dilation, num_spatial_dims, 'lhs_dilation')
        self.rhs_dilation = tools.replicate(rhs_dilation, num_spatial_dims, 'rhs_dilation')
        self.groups = groups
        self.w_initializer = w_initializer
        self.b_initializer = b_initializer
        self.mask = mask
        # Channels-last images and a (..., in_channels // groups, out_channels) kernel.
        self.dimension_numbers = to_dimension_numbers(num_spatial_dims, channels_last=True, transpose=False)
        self.padding = self._normalize_padding(padding, num_spatial_dims)

        assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
        assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'

        kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
        # Leading singleton axes let the bias broadcast over batch and spatial dims.
        bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
        self.w = parameter(self.w_initializer, kernel_shape, allow_none=False)
        self.b = parameter(self.b_initializer, bias_shape, allow_none=True)
        if isinstance(self.mode, bm.TrainingMode):
            self.w = bm.TrainVar(self.w)
            if self.b is not None:
                self.b = bm.TrainVar(self.b)

    @staticmethod
    def _normalize_padding(padding, num_spatial_dims):
        """Convert the user-facing ``padding`` argument into the form ``lax`` expects.

        Returns ``'SAME'``/``'VALID'`` unchanged, or a tuple of
        ``num_spatial_dims`` ``(low, high)`` pairs. Values of any other type are
        passed through untouched and left for ``lax`` to validate.

        Raises:
          ValueError: if ``padding`` is an empty sequence, a sequence of pairs of
            the wrong length, or a sequence whose elements are neither ints nor pairs.
        """
        if isinstance(padding, str):
            assert padding in ['SAME', 'VALID']
            return padding
        if isinstance(padding, int):
            return tuple((padding, padding) for _ in range(num_spatial_dims))
        if isinstance(padding, (tuple, list)):
            if len(padding) == 0:
                raise ValueError(f"Padding {padding} must not be an empty sequence.")
            first = padding[0]
            if isinstance(first, int):
                # A single (low, high) pair shared by every spatial dimension.
                return (padding,) * num_spatial_dims
            if isinstance(first, (tuple, list)):
                if len(padding) == 1:
                    return tuple(padding) * num_spatial_dims
                if len(padding) != num_spatial_dims:
                    raise ValueError(f"Padding {padding} must be a Tuple[int, int], "
                                     f"or sequence of Tuple[int, int] with length 1, "
                                     f"or sequence of Tuple[int, int] with length {num_spatial_dims}.")
                return tuple(padding)
            raise ValueError(f"Padding {padding} contains elements of unsupported type "
                             f"{type(first).__name__}; expected int or Tuple[int, int].")
        return padding

    def _check_input_dim(self, x):
        """Validate the rank and channel count of ``x``; raise ``ValueError`` otherwise."""
        if x.ndim != self.num_spatial_dims + 2 and x.ndim != self.num_spatial_dims + 1:
            raise ValueError(
                f"expected {self.num_spatial_dims + 2}D or {self.num_spatial_dims + 1}D input (got {x.ndim}D input)")
        if self.in_channels != x.shape[-1]:
            raise ValueError(f"input channels={x.shape[-1]} needs to have "
                             f"the same size as in_channels={self.in_channels}.")

    def update(self, x):
        """Convolve ``x`` with the (optionally masked) kernel and add the bias.

        Unbatched inputs get a temporary leading batch axis that is removed
        from the result before returning.
        """
        self._check_input_dim(x)
        nonbatching = x.ndim == self.num_spatial_dims + 1
        if nonbatching:
            x = bm.unsqueeze(x, 0)
        # Same accessor as the transposed-convolution layer: unwraps a State to its
        # value and passes any other object through, so this does not depend on
        # ``self.w`` having a ``.value`` attribute (it is only wrapped in TrainVar
        # in training mode).
        w = brainstate.maybe_state(self.w)
        if self.mask is not None:
            try:
                lax.broadcast_shapes(self.w.shape, self.mask.shape)
            except (ValueError, TypeError) as err:
                raise ValueError(
                    f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}") from err
            w = w * self.mask
        y = lax.conv_general_dilated(lhs=bm.as_jax(x),
                                     rhs=bm.as_jax(w),
                                     window_strides=self.stride,
                                     padding=self.padding,
                                     lhs_dilation=self.lhs_dilation,
                                     rhs_dilation=self.rhs_dilation,
                                     feature_group_count=self.groups,
                                     dimension_numbers=self.dimension_numbers)
        if self.b is not None:
            y = y + self.b
        return y[0] if nonbatching else y

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, kernel_size={self.kernel_size}, '
                f'stride={self.stride}, padding={self.padding}, groups={self.groups})')
[docs]
class Conv1d(_GeneralConv):
    """One-dimensional convolution.

    The input should be a 2d array with the shape of ``[H, C]``, or a 3d array
    with the shape of ``[B, H, C]``, where ``H`` is the spatial (length) size,
    ``C`` the number of channels and ``B`` the batch size.

    Parameters
    ----------
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel: an integer or a sequence of one integer.
    stride: int, sequence of int
        The inter-window stride (default: 1).
    strides: int, sequence of int
        Deprecated alias of ``stride``. Passing it emits a ``DeprecationWarning``;
        passing both ``stride`` and ``strides`` raises ``ValueError``.
    padding: str, int, tuple of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int, a single
        ``(low, high)`` pair, or a sequence of ``(low, high)`` pairs giving the
        padding to apply before and after the spatial dimension.
    lhs_dilation: int, sequence of int
        Dilation factor applied to the spatial dimension of the inputs
        (default: 1). Convolution with input dilation ``d`` is equivalent to
        transposed convolution with stride ``d``.
    rhs_dilation: int, sequence of int
        Dilation factor applied to the spatial dimension of the kernel
        (default: 1), also known as 'atrous convolution'.
    groups: int
        Divides the input features into groups. Default 1.
    w_initializer: Callable, ArrayType, Initializer
        The initializer for the convolutional kernel.
    b_initializer: Callable, ArrayType, Initializer, Optional
        The initializer for the bias. ``None`` disables the bias.
    mask: ArrayType, Optional
        The optional mask of the weights.
    mode: Mode, Optional
        The computation mode of the current object.
    name: str, Optional
        The name of the object.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = None,
        strides: Union[int, Tuple[int, ...]] = None,  # deprecated
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        # Resolve the deprecated ``strides`` alias; default stride is 1.
        if strides is not None:
            if stride is not None:
                raise ValueError('Cannot provide "stride" and "strides" both.')
            warnings.warn('"strides" is deprecated; use "stride" instead.',
                          DeprecationWarning, stacklevel=2)
            stride = strides
        elif stride is None:
            stride = 1
        super(Conv1d, self).__init__(num_spatial_dims=1,
                                     in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     lhs_dilation=lhs_dilation,
                                     rhs_dilation=rhs_dilation,
                                     groups=groups,
                                     w_initializer=w_initializer,
                                     b_initializer=b_initializer,
                                     mask=mask,
                                     mode=mode,
                                     name=name)

    def _check_input_dim(self, x):
        """Require a 2D ``[H, C]`` or 3D ``[B, H, C]`` input with ``in_channels`` channels."""
        if x.ndim != 3 and x.ndim != 2:
            raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)")
        if self.in_channels != x.shape[-1]:
            raise ValueError(f"input channels={x.shape[-1]} needs to have "
                             f"the same size as in_channels={self.in_channels}.")
[docs]
class Conv2d(_GeneralConv):
    """Two-dimensional convolution.

    The input should be a 3d array with the shape of ``[H, W, C]``, or a 4d
    array with the shape of ``[B, H, W, C]``.

    Parameters
    ----------
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel: an integer or a sequence of two integers.
    stride: int, sequence of int
        An integer or a sequence of two integers, representing the inter-window
        strides (default: 1).
    strides: int, sequence of int
        Deprecated alias of ``stride``. Passing it emits a ``DeprecationWarning``;
        passing both ``stride`` and ``strides`` raises ``ValueError``.
    padding: str, int, tuple of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int, a single
        ``(low, high)`` pair, or a sequence of ``(low, high)`` pairs giving the
        padding to apply before and after each spatial dimension.
    lhs_dilation: int, sequence of int
        Dilation factor applied to each spatial dimension of the inputs
        (default: 1). Convolution with input dilation ``d`` is equivalent to
        transposed convolution with stride ``d``.
    rhs_dilation: int, sequence of int
        Dilation factor applied to each spatial dimension of the kernel
        (default: 1), also known as 'atrous convolution'.
    groups: int
        Divides the input features into groups. Default 1.
    w_initializer: Callable, ArrayType, Initializer
        The initializer for the convolutional kernel.
    b_initializer: Callable, ArrayType, Initializer, Optional
        The initializer for the bias. ``None`` disables the bias.
    mask: ArrayType, Optional
        The optional mask of the weights.
    mode: Mode, Optional
        The computation mode of the current object.
    name: str, Optional
        The name of the object.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = None,
        strides: Union[int, Tuple[int, ...]] = None,  # deprecated
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        # Resolve the deprecated ``strides`` alias; default stride is 1.
        if strides is not None:
            if stride is not None:
                raise ValueError('Cannot provide "stride" and "strides" both.')
            warnings.warn('"strides" is deprecated; use "stride" instead.',
                          DeprecationWarning, stacklevel=2)
            stride = strides
        elif stride is None:
            stride = 1
        super(Conv2d, self).__init__(num_spatial_dims=2,
                                     in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     lhs_dilation=lhs_dilation,
                                     rhs_dilation=rhs_dilation,
                                     groups=groups,
                                     w_initializer=w_initializer,
                                     b_initializer=b_initializer,
                                     mask=mask,
                                     mode=mode,
                                     name=name)

    def _check_input_dim(self, x):
        """Require a 3D ``[H, W, C]`` or 4D ``[B, H, W, C]`` input with ``in_channels`` channels."""
        if x.ndim != 4 and x.ndim != 3:
            raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)")
        if self.in_channels != x.shape[-1]:
            raise ValueError(f"input channels={x.shape[-1]} needs to have "
                             f"the same size as in_channels={self.in_channels}.")
[docs]
class Conv3d(_GeneralConv):
    """Three-dimensional convolution.

    The input should be a 4d array with the shape of ``[H, W, D, C]``, or a 5d
    array with the shape of ``[B, H, W, D, C]``.

    Parameters
    ----------
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel: an integer or a sequence of three integers.
    stride: int, sequence of int
        An integer or a sequence of three integers, representing the
        inter-window strides (default: 1).
    strides: int, sequence of int
        Deprecated alias of ``stride``. Passing it emits a ``DeprecationWarning``;
        passing both ``stride`` and ``strides`` raises ``ValueError``.
    padding: str, int, tuple of int, sequence of tuple
        Either the string ``'SAME'``, the string ``'VALID'``, an int, a single
        ``(low, high)`` pair, or a sequence of ``(low, high)`` pairs giving the
        padding to apply before and after each spatial dimension.
    lhs_dilation: int, sequence of int
        Dilation factor applied to each spatial dimension of the inputs
        (default: 1). Convolution with input dilation ``d`` is equivalent to
        transposed convolution with stride ``d``.
    rhs_dilation: int, sequence of int
        Dilation factor applied to each spatial dimension of the kernel
        (default: 1), also known as 'atrous convolution'.
    groups: int
        Divides the input features into groups. Default 1.
    w_initializer: Callable, ArrayType, Initializer
        The initializer for the convolutional kernel.
    b_initializer: Callable, ArrayType, Initializer, Optional
        The initializer for the bias. ``None`` disables the bias.
    mask: ArrayType, Optional
        The optional mask of the weights.
    mode: Mode, Optional
        The computation mode of the current object.
    name: str, Optional
        The name of the object.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = None,
        strides: Union[int, Tuple[int, ...]] = None,  # deprecated
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        # Resolve the deprecated ``strides`` alias; default stride is 1.
        if strides is not None:
            if stride is not None:
                raise ValueError('Cannot provide "stride" and "strides" both.')
            warnings.warn('"strides" is deprecated; use "stride" instead.',
                          DeprecationWarning, stacklevel=2)
            stride = strides
        elif stride is None:
            stride = 1
        super(Conv3d, self).__init__(num_spatial_dims=3,
                                     in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     lhs_dilation=lhs_dilation,
                                     rhs_dilation=rhs_dilation,
                                     groups=groups,
                                     w_initializer=w_initializer,
                                     b_initializer=b_initializer,
                                     mask=mask,
                                     mode=mode,
                                     name=name)

    def _check_input_dim(self, x):
        """Require a 4D ``[H, W, D, C]`` or 5D ``[B, H, W, D, C]`` input with ``in_channels`` channels."""
        if x.ndim != 5 and x.ndim != 4:
            raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)")
        if self.in_channels != x.shape[-1]:
            raise ValueError(f"input channels={x.shape[-1]} needs to have "
                             f"the same size as in_channels={self.in_channels}.")
# Upper-case aliases: each name is bound to the very same class object as its
# lower-case counterpart.
Conv1D, Conv2D, Conv3D = Conv1d, Conv2d, Conv3d
class _GeneralConvTranspose(Layer):
    """Apply an N-dimensional, channels-last transposed convolution via ``lax.conv_transpose``.

    The kernel has shape ``kernel_size + (in_channels, out_channels)`` and the
    optional bias has shape ``(1,) * num_spatial_dims + (out_channels,)``.
    Inputs may be batched (``num_spatial_dims + 2`` dims) or unbatched
    (``num_spatial_dims + 1`` dims). Subclasses must implement ``_check_input_dim``.

    Parameters
    ----------
    num_spatial_dims: int
        The number of spatial dimensions of the input.
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the kernel, normalized with ``tools.replicate`` to
        ``num_spatial_dims`` entries.
    stride: int, sequence of int
        The stride(s) passed to ``lax.conv_transpose`` (default: 1).
    padding: str, int, tuple of int, sequence of tuple
        Either ``'SAME'``, ``'VALID'``, an int (the same amount before and after
        every spatial dimension), a single ``(low, high)`` pair, a length-1
        sequence of pairs (broadcast to every spatial dimension), or a sequence
        of ``num_spatial_dims`` ``(low, high)`` pairs.
    w_initializer: Callable, ArrayType, Initializer
        Kernel initializer; defaults to ``XavierNormal(in_axis=-1, out_axis=-2)``.
    b_initializer: Callable, ArrayType, Initializer, Optional
        Bias initializer; defaults to ``ZeroInit()``. ``None`` disables the bias.
    mask: ArrayType, Optional
        Optional mask multiplied into the kernel on every call. Its shape must
        broadcast against the kernel shape.
    precision: lax.Precision, Optional
        Forwarded to ``lax.conv_transpose``.
    mode: Mode, Optional
        Forwarded to ``Layer.__init__``. In ``bm.TrainingMode`` the kernel and
        bias are wrapped in ``bm.TrainVar``.
    name: str, Optional
        The name of the object.
    """

    supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode)

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        precision: Optional[lax.Precision] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name, mode=mode)
        self.num_spatial_dims = num_spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = tools.replicate(stride, num_spatial_dims, 'stride')
        self.kernel_size = tools.replicate(kernel_size, num_spatial_dims, 'kernel_size')
        self.w_initializer = w_initializer
        self.b_initializer = b_initializer
        self.precision = precision
        self.mask = mask
        # Channels-last images; kernel stored as (..., in_channels, out_channels).
        self.dimension_numbers = to_dimension_numbers(num_spatial_dims, channels_last=True, transpose=False)
        self.padding = self._normalize_padding(padding, num_spatial_dims)

        kernel_shape = tuple(self.kernel_size) + (self.in_channels, self.out_channels)
        # Leading singleton axes let the bias broadcast over batch and spatial dims.
        bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
        self.w = parameter(self.w_initializer, kernel_shape, allow_none=False)
        self.b = parameter(self.b_initializer, bias_shape, allow_none=True)
        if isinstance(self.mode, bm.TrainingMode):
            self.w = bm.TrainVar(self.w)
            if self.b is not None:
                self.b = bm.TrainVar(self.b)

    @staticmethod
    def _normalize_padding(padding, num_spatial_dims):
        """Convert the user-facing ``padding`` argument into the form ``lax`` expects.

        Returns ``'SAME'``/``'VALID'`` unchanged, or a tuple of
        ``num_spatial_dims`` ``(low, high)`` pairs. Values of any other type are
        passed through untouched and left for ``lax`` to validate.

        Raises:
          ValueError: if ``padding`` is an empty sequence, a sequence of pairs of
            the wrong length, or a sequence whose elements are neither ints nor pairs.
        """
        if isinstance(padding, str):
            assert padding in ['SAME', 'VALID']
            return padding
        if isinstance(padding, int):
            return tuple((padding, padding) for _ in range(num_spatial_dims))
        if isinstance(padding, (tuple, list)):
            if len(padding) == 0:
                raise ValueError(f"Padding {padding} must not be an empty sequence.")
            first = padding[0]
            if isinstance(first, int):
                # A single (low, high) pair shared by every spatial dimension.
                return (padding,) * num_spatial_dims
            if isinstance(first, (tuple, list)):
                if len(padding) == 1:
                    return tuple(padding) * num_spatial_dims
                if len(padding) != num_spatial_dims:
                    raise ValueError(f"Padding {padding} must be a Tuple[int, int], "
                                     f"or sequence of Tuple[int, int] with length 1, "
                                     f"or sequence of Tuple[int, int] with length {num_spatial_dims}.")
                return tuple(padding)
            raise ValueError(f"Padding {padding} contains elements of unsupported type "
                             f"{type(first).__name__}; expected int or Tuple[int, int].")
        return padding

    def _check_input_dim(self, x):
        """Validate ``x``; must be implemented by each concrete subclass."""
        raise NotImplementedError

    def update(self, x):
        """Apply the transposed convolution to ``x`` and add the bias.

        Unbatched inputs get a temporary leading batch axis that is removed
        from the result before returning.
        """
        self._check_input_dim(x)
        nonbatching = x.ndim == self.num_spatial_dims + 1
        if nonbatching:
            x = bm.unsqueeze(x, 0)
        w = brainstate.maybe_state(self.w)
        if self.mask is not None:
            try:
                lax.broadcast_shapes(self.w.shape, self.mask.shape)
            except (ValueError, TypeError) as err:
                raise ValueError(
                    f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}") from err
            w = w * self.mask
        y = lax.conv_transpose(lhs=bm.as_jax(x),
                               rhs=bm.as_jax(w),
                               strides=self.stride,
                               padding=self.padding,
                               precision=self.precision,
                               rhs_dilation=None,
                               dimension_numbers=self.dimension_numbers)
        if self.b is not None:
            y = y + self.b
        return y[0] if nonbatching else y

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, kernel_size={self.kernel_size}, '
                f'stride={self.stride}, padding={self.padding})')
[docs]
class ConvTranspose1d(_GeneralConvTranspose):
    """One dimensional transposed convolution (aka. deconvolution).

    Accepts a 2D ``[H, C]`` or 3D ``[B, H, C]`` channels-last input.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        precision: Optional[lax.Precision] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        """Initializes the module.

        Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
          kernel_size: The shape of the kernel. Either an integer or a sequence of
            length 1.
          stride: Optional stride for the kernel. Either an integer or a sequence of
            length 1. Defaults to 1.
          padding: ``'SAME'`` (default), ``'VALID'``, an int, a ``(low, high)`` pair,
            or a sequence of ``(low, high)`` pairs. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          w_initializer: Kernel initializer. Defaults to
            ``XavierNormal(in_axis=-1, out_axis=-2)``.
          b_initializer: Bias initializer. Defaults to ``ZeroInit()``; ``None``
            disables the bias.
          mask: Optional mask of the weights.
          precision: Optional ``lax.Precision`` forwarded to ``lax.conv_transpose``.
          mode: The computation mode of the module.
          name: The name of the module.
        """
        super().__init__(
            num_spatial_dims=1,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            w_initializer=w_initializer,
            b_initializer=b_initializer,
            precision=precision,
            mode=mode,
            mask=mask,
            name=name
        )

    def _check_input_dim(self, x):
        """Require a 2D or 3D input whose last axis has ``in_channels`` entries."""
        if x.ndim != 3 and x.ndim != 2:
            raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)")
        if self.in_channels != x.shape[-1]:
            raise ValueError(f"input channels={x.shape[-1]} needs to have "
                             f"the same size as in_channels={self.in_channels}.")
[docs]
class ConvTranspose2d(_GeneralConvTranspose):
    """Two dimensional transposed convolution (aka. deconvolution).

    Accepts a 3D ``[H, W, C]`` or 4D ``[B, H, W, C]`` channels-last input.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        precision: Optional[lax.Precision] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        """Initializes the module.

        Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
          kernel_size: The shape of the kernel. Either an integer or a sequence of
            length 2.
          stride: Optional stride for the kernel. Either an integer or a sequence of
            length 2. Defaults to 1.
          padding: ``'SAME'`` (default), ``'VALID'``, an int, a ``(low, high)`` pair,
            or a sequence of ``(low, high)`` pairs. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          w_initializer: Kernel initializer. Defaults to
            ``XavierNormal(in_axis=-1, out_axis=-2)``.
          b_initializer: Bias initializer. Defaults to ``ZeroInit()``; ``None``
            disables the bias.
          mask: Optional mask of the weights.
          precision: Optional ``lax.Precision`` forwarded to ``lax.conv_transpose``.
          mode: The computation mode of the module.
          name: The name of the module.
        """
        super().__init__(
            num_spatial_dims=2,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            w_initializer=w_initializer,
            b_initializer=b_initializer,
            precision=precision,
            mode=mode,
            mask=mask,
            name=name
        )

    def _check_input_dim(self, x):
        """Require a 3D or 4D input whose last axis has ``in_channels`` entries."""
        if x.ndim != 4 and x.ndim != 3:
            raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)")
        if self.in_channels != x.shape[-1]:
            raise ValueError(f"input channels={x.shape[-1]} needs to have "
                             f"the same size as in_channels={self.in_channels}.")
[docs]
class ConvTranspose3d(_GeneralConvTranspose):
    """Three dimensional transposed convolution (aka. deconvolution).

    Accepts a 4D ``[H, W, D, C]`` or 5D ``[B, H, W, D, C]`` channels-last input.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2),
        b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
        mask: Optional[ArrayType] = None,
        precision: Optional[lax.Precision] = None,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        """Initializes the module.

        Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
          kernel_size: The shape of the kernel. Either an integer or a sequence of
            length 3.
          stride: Optional stride for the kernel. Either an integer or a sequence of
            length 3. Defaults to 1.
          padding: ``'SAME'`` (default), ``'VALID'``, an int, a ``(low, high)`` pair,
            or a sequence of ``(low, high)`` pairs. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          w_initializer: Kernel initializer. Defaults to
            ``XavierNormal(in_axis=-1, out_axis=-2)``.
          b_initializer: Bias initializer. Defaults to ``ZeroInit()``; ``None``
            disables the bias.
          mask: Optional mask of the weights.
          precision: Optional ``lax.Precision`` forwarded to ``lax.conv_transpose``.
          mode: The computation mode of the module.
          name: The name of the module.
        """
        super().__init__(
            num_spatial_dims=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            w_initializer=w_initializer,
            b_initializer=b_initializer,
            precision=precision,
            mode=mode,
            mask=mask,
            name=name
        )

    def _check_input_dim(self, x):
        """Require a 4D or 5D input whose last axis has ``in_channels`` entries."""
        if x.ndim != 5 and x.ndim != 4:
            raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)")
        if self.in_channels != x.shape[-1]:
            raise ValueError(f"input channels={x.shape[-1]} needs to have "
                             f"the same size as in_channels={self.in_channels}.")