brainpy.nn.runners.BPTT
- class brainpy.nn.runners.BPTT(target, loss, optimizer=None, max_grad_norm=None, shuffle_data=True, jit=True, **kwargs)
Trainer that implements the backpropagation through time (BPTT) algorithm for training recurrent neural networks.
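To make the algorithm concrete, here is a minimal pure-Python sketch of BPTT for a single-unit recurrent network h_t = tanh(w * h_{t-1} + u * x_t) with a summed squared-error loss. It is not BrainPy's implementation, which differentiates the full unrolled model automatically. The model, the function names forward, loss and bptt_grads, and the sample data are illustrative assumptions. The backward pass walks the sequence in reverse and adds each step's local error to the gradient carried back from later steps.

```python
import math

def forward(w, u, xs):
    """Unroll h_t = tanh(w * h_{t-1} + u * x_t) from h_0 = 0 and return [h_0, ..., h_T]."""
    hs = [0.0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def loss(w, u, xs, ys):
    """Summed squared error 0.5 * (h_t - y_t)**2 over every time step."""
    hs = forward(w, u, xs)
    return sum(0.5 * (h - y) ** 2 for h, y in zip(hs[1:], ys))

def bptt_grads(w, u, xs, ys):
    """Return (dL/dw, dL/du) by backpropagating through the unrolled sequence."""
    hs = forward(w, u, xs)
    dw = du = 0.0
    dh_next = 0.0  # gradient reaching h_t from step t + 1 (zero after the last step)
    for t in reversed(range(len(xs))):
        h, h_prev = hs[t + 1], hs[t]
        dh = (h - ys[t]) + dh_next   # local error plus gradient from later steps
        dpre = dh * (1.0 - h * h)    # derivative of tanh at this step
        dw += dpre * h_prev          # w is shared across steps, so contributions add up
        du += dpre * xs[t]
        dh_next = dpre * w           # send the gradient back to h_{t-1}
    return dw, du

# Compare against central finite differences on the same loss.
xs, ys = [0.5, -1.0, 0.3, 0.8], [0.1, -0.2, 0.0, 0.4]
w, u, eps = 0.7, -0.4, 1e-6
dw, du = bptt_grads(w, u, xs, ys)
dw_fd = (loss(w + eps, u, xs, ys) - loss(w - eps, u, xs, ys)) / (2 * eps)
du_fd = (loss(w, u + eps, xs, ys) - loss(w, u - eps, xs, ys)) / (2 * eps)
print(dw, dw_fd)
print(du, du_fd)
```

The finite-difference comparison at the end is a quick way to validate a hand-written backward pass: both pairs of printed numbers should agree to several decimal places.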
- __init__(target, loss, optimizer=None, max_grad_norm=None, shuffle_data=True, jit=True, **kwargs)
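The max_grad_norm argument suggests gradient-norm clipping, a standard companion to BPTT because gradients accumulated over long sequences can explode. The source does not say which clipping rule BrainPy applies. The sketch below assumes clipping by the global L2 norm across all gradient values, with None meaning no clipping (matching the default in the signature). clip_by_global_norm is a hypothetical helper, not a BrainPy function.

```python
import math

def clip_by_global_norm(grads, max_grad_norm=None):
    """Rescale gradient values so their combined L2 norm is at most max_grad_norm.

    Assumption: None disables clipping; this mirrors a common convention,
    not confirmed BrainPy behavior.
    """
    if max_grad_norm is None:
        return list(grads)
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_grad_norm:
        return list(grads)           # already small enough, leave unchanged
    scale = max_grad_norm / norm     # one factor for every value keeps the direction
    return [g * scale for g in grads]

print(clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0))
print(clip_by_global_norm([3.0, 4.0]))
```

Scaling every gradient by the same factor preserves the update direction and only shortens the step, which is why global-norm clipping is usually preferred over clipping each value independently.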
Methods
__init__(target, loss[, optimizer, ...])
f_grad([shared_kwargs]): Get gradient function.
f_loss([shared_kwargs]): Get loss function.
f_train([shared_kwargs]): Get training function.
fit(train_data[, test_data, num_batch, ...]): Fit the target model according to the given training and testing data.
predict(xs[, forced_states, ...]): Predict a series of input data with the given target model.
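f_loss, f_grad and f_train each return a function rather than a value, and each layer can be built on the previous one. The sketch below imitates that layering with hypothetical make_loss, make_grad and make_train helpers on a toy scalar model y = w * x with a hand-derived gradient; a recurrent model trained by BPTT would supply its own loss and backward pass instead. The shared_kwargs argument of the real methods is omitted because its contents are not documented here.

```python
def make_loss():
    """Hypothetical analogue of f_loss: return a mean-squared-error function for y = w * x."""
    def loss_fn(w, xs, ys):
        return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return loss_fn

def make_grad():
    """Hypothetical analogue of f_grad: return a function giving (dL/dw, loss)."""
    loss_fn = make_loss()
    def grad_fn(w, xs, ys):
        g = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        return g, loss_fn(w, xs, ys)
    return grad_fn

def make_train(lr=0.05):
    """Hypothetical analogue of f_train: return one plain gradient-descent step."""
    grad_fn = make_grad()
    def train_step(w, xs, ys):
        g, current_loss = grad_fn(w, xs, ys)
        return w - lr * g, current_loss  # loss is reported before the update
    return train_step

xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # noiseless data generated with w = 2
train_step = make_train(lr=0.05)
w = 0.0
for _ in range(200):
    w, current_loss = train_step(w, xs, ys)
print(w)
```

Building the gradient and training functions on top of the loss function guarantees that all three use the same loss definition.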
Attributes
mapping_type: Mapping type for the output and the target.
train_losses: Training loss.
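fit records the training loss in train_losses, and the constructor's shuffle_data flag controls sample order. The sketch below is a minimal loop assuming one mean loss is appended per epoch and samples are reshuffled every epoch when shuffle_data is true; the source confirms neither detail. The fit function here is hypothetical and trains the toy model y = w * x by per-sample gradient descent rather than a recurrent network.

```python
import random

def fit(w, data, lr=0.05, num_epochs=50, shuffle_data=True, seed=0):
    """Hypothetical fit loop: returns the final w and a list of per-epoch mean losses."""
    rng = random.Random(seed)  # seeded so the sample order is reproducible
    data = list(data)          # copy so shuffling does not mutate the caller's list
    train_losses = []
    for _ in range(num_epochs):
        if shuffle_data:
            rng.shuffle(data)  # new sample order every epoch
        epoch_losses = []
        for x, y in data:
            err = w * x - y
            epoch_losses.append(err * err)
            w -= lr * 2 * err * x  # gradient of err**2 with respect to w
        train_losses.append(sum(epoch_losses) / len(epoch_losses))
    return w, train_losses

w, train_losses = fit(0.0, [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)])
print(w, train_losses[0], train_losses[-1])
```

On this noiseless data the recorded losses fall toward zero as w approaches 2, so plotting train_losses is a simple convergence check.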