brainpy.nn.runners.BPTT

class brainpy.nn.runners.BPTT(target, loss, optimizer=None, max_grad_norm=None, shuffle_data=True, jit=True, **kwargs)

A trainer that implements the back-propagation through time (BPTT) algorithm for recurrent neural networks.
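BPTT unrolls the recurrent network over the input sequence, runs a forward pass, and then propagates the loss gradient backward through every time step. Because the recurrent weights are shared across steps, their per-step contributions are summed. The plain-Python sketch below illustrates this on a hypothetical one-unit tanh RNN, h_t = tanh(w * h_{t-1} + u * x_t), with a squared-error loss. It does not use BrainPy and is not the library's implementation. The functions `forward`, `loss`, and `bptt_grads` are illustrative names only. To confirm that the backward pass is correct, the sketch compares its gradients against central finite differences.

```python
import math

def forward(w, u, xs):
    """Run the scalar RNN h_t = tanh(w * h_{t-1} + u * x_t) starting from h_0 = 0."""
    hs = [0.0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def loss(w, u, xs, ys):
    """Half the sum of squared errors between hidden states h_1..h_T and targets."""
    hs = forward(w, u, xs)
    return 0.5 * sum((h - y) ** 2 for h, y in zip(hs[1:], ys))

def bptt_grads(w, u, xs, ys):
    """Return (dL/dw, dL/du) computed by back-propagation through time."""
    hs = forward(w, u, xs)
    dw, du = 0.0, 0.0
    dh_next = 0.0  # gradient reaching h_t from step t + 1
    for t in range(len(xs), 0, -1):
        dh = (hs[t] - ys[t - 1]) + dh_next  # local error plus error from later steps
        da = dh * (1.0 - hs[t] ** 2)         # back through tanh
        dw += da * hs[t - 1]                 # w is shared, so accumulate over steps
        du += da * xs[t - 1]
        dh_next = da * w                     # pass the gradient on to h_{t-1}
    return dw, du

# Check the BPTT gradients against central finite differences.
xs, ys = [0.5, -1.0, 0.25, 0.8], [0.1, -0.3, 0.2, 0.4]
w, u, eps = 0.7, -0.4, 1e-6
dw, du = bptt_grads(w, u, xs, ys)
fd_dw = (loss(w + eps, u, xs, ys) - loss(w - eps, u, xs, ys)) / (2 * eps)
fd_du = (loss(w, u + eps, xs, ys) - loss(w, u - eps, xs, ys)) / (2 * eps)
print(dw, fd_dw)
print(du, fd_du)
```

Each printed pair should agree to many decimal places. A trainer such as this class presumably automates the same idea for full networks, using automatic differentiation in place of the hand-written backward loop.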

__init__(target, loss, optimizer=None, max_grad_norm=None, shuffle_data=True, jit=True, **kwargs)
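This page does not describe the `max_grad_norm` parameter. Parameters with this name usually enable gradient clipping by global norm, a common safeguard against exploding gradients in BPTT. Under that clipping scheme, when the joint L2 norm of all gradients exceeds the threshold, every gradient is rescaled so that the norm equals the threshold. The sketch below **assumes** that interpretation. It is plain Python over a hypothetical dict-of-lists gradient structure, not BrainPy's code, and the names `global_norm` and `clip_by_global_norm` are illustrative only.

```python
import math

def global_norm(grads):
    """Joint L2 norm over every gradient entry in a {name: [values]} dict."""
    return math.sqrt(sum(g * g for values in grads.values() for g in values))

def clip_by_global_norm(grads, max_grad_norm=None):
    """Rescale all gradients together so their joint norm is at most max_grad_norm.

    Assumption: None means "no clipping", mirroring the default in the signature above.
    """
    if max_grad_norm is None:
        return grads
    norm = global_norm(grads)
    if norm <= max_grad_norm:
        return grads
    scale = max_grad_norm / norm  # one shared factor preserves the gradient direction
    return {name: [g * scale for g in values] for name, values in grads.items()}

grads = {"w": [3.0, 0.0], "b": [4.0]}   # joint norm is 5
clipped = clip_by_global_norm(grads, max_grad_norm=1.0)
print(global_norm(clipped))             # approximately 1.0
```

Scaling every tensor by the same factor keeps the update direction unchanged and only shortens the step. Clipping each tensor independently would change the direction.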

Methods

__init__(target, loss[, optimizer, ...])

f_grad([shared_kwargs])

Get the gradient function.

f_loss([shared_kwargs])

Get the loss function.

f_train([shared_kwargs])

Get the training function.

fit(train_data[, test_data, num_batch, ...])

Fit the target model according to the given training and testing data.

predict(xs[, forced_states, ...])

Predict the outputs for a series of input data with the given target model.

Attributes

mapping_type

Mapping type for the output and the target.

train_losses

Training loss.