# -*- coding: utf-8 -*-
import warnings
from typing import Union, Callable, Sequence, Optional, Tuple
import jax.numpy as jnp
import brainpy.math as bm
from brainpy.math import activations
from brainpy._src.dnn.base import Layer
from brainpy.check import (is_integer,
is_initializer)
from brainpy.initialize import (XavierNormal,
ZeroInit,
Orthogonal,
parameter,
variable,
variable_,
Initializer)
from brainpy.types import ArrayType
from brainpy._src.dnn.conv import _GeneralConv
__all__ = [
'RNNCell', 'GRUCell', 'LSTMCell',
'Conv1dLSTMCell', 'Conv2dLSTMCell', 'Conv3dLSTMCell',
]
class RNNCell(Layer):
r"""Basic fully-connected RNN core.
Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the
core computes
.. math::
h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)
The output is equal to the new state, :math:`h_t`.
Parameters
----------
num_in: int
The dimension of the input vector.
num_out: int
The number of hidden units in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The state initializer.
Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The input weight initializer.
Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The hidden weight initializer.
b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray
The bias weight initializer.
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
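Examples
--------
A minimal usage sketch; the sizes, the training mode, and the random input
below are illustrative assumptions rather than requirements of the class:

>>> import brainpy.math as bm
>>> cell = RNNCell(num_in=3, num_out=5, mode=bm.TrainingMode())
>>> cell.reset_state(2)             # allocate the hidden state for a batch of 2
>>> x = bm.random.rand(2, 3)        # input of shape (batch, num_in)
>>> h = cell.update(x)              # new hidden state of shape (batch, num_out)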
"""
def __init__(
self,
num_in: int,
num_out: int,
state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
Wi_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(),
Wh_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(),
b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
activation: str = 'relu',
mode: bm.Mode = None,
train_state: bool = False,
name: str = None,
):
super(RNNCell, self).__init__(mode=mode, name=name)
# parameters
self._state_initializer = state_initializer
is_initializer(state_initializer, 'state_initializer', allow_none=False)
self.num_out = num_out
is_integer(num_out, 'num_out', min_bound=1, allow_none=False)
self.train_state = train_state
# parameters
self.num_in = num_in
is_integer(num_in, 'num_in', min_bound=1, allow_none=False)
# initializers
self._Wi_initializer = Wi_initializer
self._Wh_initializer = Wh_initializer
self._b_initializer = b_initializer
is_initializer(Wi_initializer, 'Wi_initializer', allow_none=False)
is_initializer(Wh_initializer, 'Wh_initializer', allow_none=False)
is_initializer(b_initializer, 'b_initializer', allow_none=True)
# activation function
self.activation = getattr(activations, activation)
# weights
self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out))
self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out))
self.b = parameter(self._b_initializer, (self.num_out,))
if isinstance(self.mode, bm.TrainingMode):
self.Wi = bm.TrainVar(self.Wi)
self.Wh = bm.TrainVar(self.Wh)
self.b = None if (self.b is None) else bm.TrainVar(self.b)
# state
self.state = variable(jnp.zeros, self.mode, self.num_out)
if train_state and isinstance(self.mode, bm.TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False))
self.state[:] = self.state2train
def reset_state(self, batch_or_mode=None, **kwargs):
self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out,), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train
def update(self, x):
h = x @ self.Wi
h += self.state.value @ self.Wh
if self.b is not None:
h += self.b
self.state.value = self.activation(h)
return self.state.value
class GRUCell(Layer):
r"""Gated Recurrent Unit.
The implementation is based on (Chung, et al., 2014) [1]_ with biases.
Given :math:`x_t` and the previous state :math:`h_{t-1}` the core computes
.. math::
\begin{array}{ll}
z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\
r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\
a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\
h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t
\end{array}
where :math:`z_t` and :math:`r_t` are the update and reset gates, respectively.
The output is equal to the new hidden state, :math:`h_t`.
Warning: Backwards compatibility of GRU weights is currently unsupported.
Parameters
----------
num_in: int
The dimension of the input vector.
num_out: int
The number of hidden units in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The state initializer.
Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The input weight initializer.
Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The hidden weight initializer.
b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray
The bias weight initializer.
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
References
----------
.. [1] Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. Empirical
evaluation of gated recurrent neural networks on sequence modeling.
arXiv preprint arXiv:1412.3555.
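Examples
--------
A minimal usage sketch; the sizes, the training mode, and the random input
below are illustrative assumptions:

>>> import brainpy.math as bm
>>> cell = GRUCell(num_in=3, num_out=5, mode=bm.TrainingMode())
>>> cell.reset_state(2)             # allocate the hidden state for a batch of 2
>>> x = bm.random.rand(2, 3)        # input of shape (batch, num_in)
>>> h = cell.update(x)              # new hidden state of shape (batch, num_out)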
"""
def __init__(
self,
num_in: int,
num_out: int,
Wi_initializer: Union[ArrayType, Callable, Initializer] = Orthogonal(),
Wh_initializer: Union[ArrayType, Callable, Initializer] = Orthogonal(),
b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
activation: str = 'tanh',
mode: bm.Mode = None,
train_state: bool = False,
name: str = None,
):
super(GRUCell, self).__init__(mode=mode, name=name)
# parameters
self._state_initializer = state_initializer
is_initializer(state_initializer, 'state_initializer', allow_none=False)
self.num_out = num_out
is_integer(num_out, 'num_out', min_bound=1, allow_none=False)
self.train_state = train_state
self.num_in = num_in
is_integer(num_in, 'num_in', min_bound=1, allow_none=False)
# initializers
self._Wi_initializer = Wi_initializer
self._Wh_initializer = Wh_initializer
self._b_initializer = b_initializer
is_initializer(Wi_initializer, 'Wi_initializer', allow_none=False)
is_initializer(Wh_initializer, 'Wh_initializer', allow_none=False)
is_initializer(b_initializer, 'b_initializer', allow_none=True)
# activation function
self.activation = getattr(activations, activation)
# weights
self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 3), allow_none=False)
self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 3), allow_none=False)
self.b = parameter(self._b_initializer, (self.num_out * 3,))
if isinstance(self.mode, bm.TrainingMode):
self.Wi = bm.TrainVar(self.Wi)
self.Wh = bm.TrainVar(self.Wh)
self.b = bm.TrainVar(self.b) if (self.b is not None) else None
# state
self.state = variable(jnp.zeros, self.mode, self.num_out)
if train_state and isinstance(self.mode, bm.TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False))
self.state[:] = self.state2train
def reset_state(self, batch_or_mode=None, **kwargs):
self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
self.state[:] = self.state2train
def update(self, x):
gates_x = jnp.matmul(x, bm.as_jax(self.Wi))
zr_x, a_x = jnp.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1)
w_h_z, w_h_a = jnp.split(bm.as_jax(self.Wh), indices_or_sections=[2 * self.num_out], axis=-1)
zr_h = jnp.matmul(bm.as_jax(self.state), w_h_z)
zr = zr_x + zr_h
has_bias = (self.b is not None)
if has_bias:
b_z, b_a = jnp.split(bm.as_jax(self.b), indices_or_sections=[2 * self.num_out], axis=0)
zr += jnp.broadcast_to(b_z, zr_h.shape)
z, r = jnp.split(bm.sigmoid(zr), indices_or_sections=2, axis=-1)
a_h = jnp.matmul(r * self.state, w_h_a)
if has_bias:
a = self.activation(a_x + a_h + jnp.broadcast_to(b_a, a_h.shape))
else:
a = self.activation(a_x + a_h)
next_state = (1 - z) * self.state + z * a
self.state.value = next_state
return self.state.value
class LSTMCell(Layer):
r"""Long short-term memory (LSTM) RNN core.
The implementation is based on (Zaremba, et al., 2014) [1]_. Given
:math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
computes
.. math::
\begin{array}{ll}
i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
c_t = f_t c_{t-1} + i_t g_t \\
h_t = o_t \tanh(c_t)
\end{array}
where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
output gate activations, and :math:`g_t` is a vector of cell updates.
The output is equal to the new hidden, :math:`h_t`.
Notes
-----
Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0
to :math:`b_f` after initialization in order to reduce the scale of forgetting in
the beginning of the training.
Parameters
----------
num_in: int
The dimension of the input vector.
num_out: int
The number of hidden units in the node.
state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The state initializer.
Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The input weight initializer.
Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray
The hidden weight initializer.
b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray
The bias weight initializer.
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
References
----------
.. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
network regularization." arXiv preprint arXiv:1409.2329 (2014).
.. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
exploration of recurrent network architectures." In International conference
on machine learning, pp. 2342-2350. PMLR, 2015.
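Examples
--------
A minimal usage sketch; the sizes, the training mode, and the random input
below are illustrative assumptions. The hidden state and the memory cell are
kept in one concatenated state and can be read through the ``h`` and ``c``
properties:

>>> import brainpy.math as bm
>>> cell = LSTMCell(num_in=3, num_out=5, mode=bm.TrainingMode())
>>> cell.reset_state(2)             # allocate the (h, c) state for a batch of 2
>>> x = bm.random.rand(2, 3)        # input of shape (batch, num_in)
>>> h = cell.update(x)              # hidden state of shape (batch, num_out)
>>> c = cell.c                      # memory cell of shape (batch, num_out)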
"""
def __init__(
self,
num_in: int,
num_out: int,
Wi_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(),
Wh_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(),
b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
activation: str = 'tanh',
mode: bm.Mode = None,
train_state: bool = False,
name: str = None,
):
super(LSTMCell, self).__init__(mode=mode, name=name)
# parameters
self._state_initializer = state_initializer
is_initializer(state_initializer, 'state_initializer', allow_none=False)
self.num_out = num_out
is_integer(num_out, 'num_out', min_bound=1, allow_none=False)
self.train_state = train_state
self.num_in = num_in
is_integer(num_in, 'num_in', min_bound=1, allow_none=False)
# initializers
self._Wi_initializer = Wi_initializer
self._Wh_initializer = Wh_initializer
self._b_initializer = b_initializer
is_initializer(Wi_initializer, 'Wi_initializer', allow_none=False)
is_initializer(Wh_initializer, 'Wh_initializer', allow_none=False)
is_initializer(b_initializer, 'b_initializer', allow_none=True)
# activation function
self.activation = getattr(activations, activation)
# weights
self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 4))
self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 4))
self.b = parameter(self._b_initializer, (self.num_out * 4,))
if isinstance(self.mode, bm.TrainingMode):
self.Wi = bm.TrainVar(self.Wi)
self.Wh = bm.TrainVar(self.Wh)
self.b = None if (self.b is None) else bm.TrainVar(self.b)
# state
self.state = variable(jnp.zeros, self.mode, self.num_out * 2)
if train_state and isinstance(self.mode, bm.TrainingMode):
self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out * 2,), allow_none=False))
self.state[:] = self.state2train
def reset_state(self, batch_or_mode=None, **kwargs):
self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out * 2), allow_none=False)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False)
self.state[:] = self.state2train
def update(self, x):
h, c = bm.split(self.state.value, 2, axis=-1)
gated = x @ self.Wi
if self.b is not None:
gated += self.b
gated += h @ self.Wh
i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1)
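# the '+ 1.' below adds a forget-gate bias of 1.0 (Jozefowicz et al., 2015) to reduce forgetting early in training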
c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * self.activation(g)
h = bm.sigmoid(o) * self.activation(c)
self.state.value = bm.concatenate([h, c], axis=-1)
return h
@property
def h(self):
"""Hidden state."""
return jnp.split(self.state.value, 2, axis=-1)[0]
@h.setter
def h(self, value):
if self.state is None:
raise ValueError('Cannot set "h" state. Because the state is not initialized.')
self.state[..., :self.state.shape[-1] // 2] = value
@property
def c(self):
"""Memory cell."""
return jnp.split(self.state.value, 2, axis=-1)[1]
@c.setter
def c(self, value):
if self.state is None:
raise ValueError('Cannot set "c" state. Because the state is not initialized.')
self.state[..., self.state.shape[-1] // 2:] = value
class _ConvNDLSTMCell(Layer):
r"""``num_spatial_dims``-D convolutional LSTM.
The implementation is based on :cite:`xingjian2015convolutional`.
Given :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})`
the core computes
.. math::
\begin{array}{ll}
i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\
f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\
g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\
o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\
c_t = f_t c_{t-1} + i_t g_t \\
h_t = o_t \tanh(c_t)
\end{array}
where :math:`*` denotes the convolution operator; :math:`i_t`,
:math:`f_t`, :math:`o_t` are input, forget and output gate activations,
and :math:`g_t` is a vector of cell updates.
The output is equal to the new hidden state, :math:`h_t`.
Notes:
Forget gate initialization:
Following :cite:`jozefowicz2015empirical` we add 1.0 to :math:`b_f`
after initialization in order to reduce the scale of forgetting in
the beginning of the training.
"""
def __init__(
self,
input_shape: Tuple[int, ...],
# convolution parameters
num_spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
# recurrent parameters
state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
train_state: bool = False,
# others
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
"""Constructs a convolutional LSTM.
Args:
num_spatial_dims: Number of spatial dimensions of the input.
input_shape: Shape of the inputs excluding batch size.
out_channels: Number of output channels.
kernel_size: Sequence of kernel sizes (of length ``num_spatial_dims``),
or an int. An int ``kernel_size`` will be expanded to define the same
kernel size in all spatial dimensions.
name: Name of the module.
"""
super().__init__(name=name, mode=mode)
# parameters
self._state_initializer = state_initializer
is_initializer(state_initializer, 'state_initializer', allow_none=False)
self.train_state = train_state
self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.input_shape = tuple(input_shape)
self.input_to_hidden = _GeneralConv(num_spatial_dims=num_spatial_dims,
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=kernel_size,
stride=stride,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_initializer=w_initializer,
b_initializer=b_initializer,
mode=mode)
self.hidden_to_hidden = _GeneralConv(num_spatial_dims=num_spatial_dims,
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=kernel_size,
stride=stride,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_initializer=w_initializer,
b_initializer=b_initializer,
mode=mode)
self.reset_state()
def reset_state(self, batch_or_mode: int = 1, **kwargs):
if self.mode.is_a(bm.NonBatchingMode):
shape = self.input_shape + (self.out_channels,)
self.h = variable_(self._state_initializer, shape)
self.c = variable_(self._state_initializer, shape)
else:
shape = self.input_shape + (self.out_channels,)
self.h = variable_(self._state_initializer, shape, batch_or_mode)
self.c = variable_(self._state_initializer, shape, batch_or_mode)
if self.mode.is_a(bm.TrainingMode) and self.train_state:
h_to_train = parameter(self._state_initializer, shape, allow_none=False)
c_to_train = parameter(self._state_initializer, shape, allow_none=False)
self.h_to_train = bm.TrainVar(h_to_train)
self.c_to_train = bm.TrainVar(c_to_train)
self.h[:] = self.h_to_train
self.c[:] = self.c_to_train
def update(self, x):
gates = self.input_to_hidden(x) + self.hidden_to_hidden(self.h)
i, g, f, o = bm.split(gates, indices_or_sections=4, axis=-1)
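# shift the forget-gate pre-activation by +1 (Jozefowicz et al., 2015) to reduce forgetting early in training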
f = bm.sigmoid(f + 1)
c = f * self.c + bm.sigmoid(i) * bm.tanh(g)
h = bm.sigmoid(o) * bm.tanh(c)
self.h.value = h
self.c.value = c
return h
class Conv1dLSTMCell(_ConvNDLSTMCell): # pylint: disable=empty-docstring
__doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "1")
def __init__(
self,
input_shape: Tuple[int, ...],
# convolution parameters
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
# recurrent parameters
state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
train_state: bool = False,
# others
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
"""Constructs a 1-D convolutional LSTM.
Input: [Batch_Size, Input_Data_Size, Input_Channel_Size]
Output: [Batch_Size, Output_Data_Size, Output_Channel_Size]
Args:
input_shape: Shape of the inputs excluding batch size.
out_channels: Number of output channels.
kernel_size: Sequence of kernel sizes (of length 1), or an int. An int
``kernel_size`` will be expanded to define the same kernel size in all
spatial dimensions.
name: Name of the module.
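Example:
A minimal usage sketch; the spatial size, channel counts, kernel size, and
training mode below are illustrative assumptions:

>>> import brainpy.math as bm
>>> cell = Conv1dLSTMCell(input_shape=(16,), in_channels=3, out_channels=4,
...                       kernel_size=3, mode=bm.TrainingMode())
>>> cell.reset_state(2)                  # allocate (h, c) states for a batch of 2
>>> x = bm.random.rand(2, 16, 3)         # (batch, length, in_channels)
>>> h = cell.update(x)                   # (batch, length, out_channels)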
"""
super().__init__(
num_spatial_dims=1,
input_shape=input_shape,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_initializer=w_initializer,
b_initializer=b_initializer,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name
)
class Conv2dLSTMCell(_ConvNDLSTMCell): # pylint: disable=empty-docstring
__doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "2")
def __init__(
self,
input_shape: Tuple[int, ...],
# convolution parameters
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
# recurrent parameters
state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
train_state: bool = False,
# others
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
"""Constructs a 2-D convolutional LSTM.
Input: [Batch_Size, Input_Data_Size_Dim1, Input_Data_Size_Dim2, Input_Channel_Size]
Output: [Batch_Size, Output_Data_Size_Dim1, Output_Data_Size_Dim2, Output_Channel_Size]
Args:
input_shape: Shape of the inputs excluding batch size.
out_channels: Number of output channels.
kernel_size: Sequence of kernel sizes (of length 2), or an int. An int
``kernel_size`` will be expanded to define the same kernel size in all
spatial dimensions.
name: Name of the module.
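Example:
A minimal usage sketch; the spatial size, channel counts, kernel size, and
training mode below are illustrative assumptions:

>>> import brainpy.math as bm
>>> cell = Conv2dLSTMCell(input_shape=(8, 8), in_channels=3, out_channels=4,
...                       kernel_size=(3, 3), mode=bm.TrainingMode())
>>> cell.reset_state(2)                  # allocate (h, c) states for a batch of 2
>>> x = bm.random.rand(2, 8, 8, 3)       # (batch, H, W, in_channels)
>>> h = cell.update(x)                   # (batch, H, W, out_channels)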
"""
super().__init__(
num_spatial_dims=2,
input_shape=input_shape,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_initializer=w_initializer,
b_initializer=b_initializer,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name
)
class Conv3dLSTMCell(_ConvNDLSTMCell): # pylint: disable=empty-docstring
__doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "3")
def __init__(
self,
input_shape: Tuple[int, ...],
# convolution parameters
in_channels: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(),
b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(),
# recurrent parameters
state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(),
train_state: bool = False,
# others
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
"""Constructs a 3-D convolutional LSTM.
Input: [Batch_Size, Input_Data_Size_Dim1, Input_Data_Size_Dim2, Input_Data_Size_Dim3, Input_Channel_Size]
Output: [Batch_Size, Output_Data_Size_Dim1, Output_Data_Size_Dim2, Output_Data_Size_Dim3, Output_Channel_Size]
Args:
input_shape: Shape of the inputs excluding batch size.
out_channels: Number of output channels.
kernel_size: Sequence of kernel sizes (of length 3), or an int. An int
``kernel_size`` will be expanded to define the same kernel size in all
spatial dimensions.
name: Name of the module.
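Example:
A minimal usage sketch; the spatial size, channel counts, kernel size, and
training mode below are illustrative assumptions:

>>> import brainpy.math as bm
>>> cell = Conv3dLSTMCell(input_shape=(4, 4, 4), in_channels=3, out_channels=4,
...                       kernel_size=(3, 3, 3), mode=bm.TrainingMode())
>>> cell.reset_state(2)                  # allocate (h, c) states for a batch of 2
>>> x = bm.random.rand(2, 4, 4, 4, 3)    # (batch, D1, D2, D3, in_channels)
>>> h = cell.update(x)                   # (batch, D1, D2, D3, out_channels)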
"""
super().__init__(
num_spatial_dims=3,
input_shape=input_shape,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_initializer=w_initializer,
b_initializer=b_initializer,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name
)