brainpy.losses.ctc_loss_with_forward_probs

brainpy.losses.ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)

Computes the CTC loss and the CTC forward probabilities. The CTC loss is the negative log-likelihood of the label sequence under the model; it introduces a special blank symbol \(\phi\) so that variable-length output sequences can be aligned to the input frames. The forward probabilities, returned as auxiliary results, are split into two parts: the blank alpha-probability and the non-blank (label) alpha-probability. They are defined as follows:

\[
\begin{aligned}
\alpha_{\mathrm{BLANK}}(t, n) &= \sum_{\pi_{1:t-1}} p(\pi_t = \phi \mid \pi_{1:t-1}, y_{1:n-1}, \cdots), \\
\alpha_{\mathrm{LABEL}}(t, n) &= \sum_{\pi_{1:t-1}} p(\pi_t = y_n \mid \pi_{1:t-1}, y_{1:n-1}, \cdots).
\end{aligned}
\]

Here, \(\pi\) denotes the alignment sequence of [Graves et al., 2006], i.e., the blank-inserted representation of the labels. The function returns the logarithms of the above probabilities.
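To make the notion of an alignment concrete, the following minimal sketch computes the CTC loss of a single, unpadded sequence by brute force: it enumerates every alignment \(\pi\) of length T, keeps those that reduce to the labels under the usual CTC collapse rule (merge repeated symbols, then drop blanks), and sums their probabilities. `collapse` and `brute_force_ctc_loss` are hypothetical helpers, not part of BrainPy; turning logits into probabilities with a log-softmax is an assumption, and the enumeration is exponential in T, so it only suits tiny inputs.

```python
import itertools

import numpy as np


def collapse(path, blank_id=0):
    """Hypothetical helper: map an alignment to labels (merge repeats, drop blanks)."""
    out, prev = [], None
    for tok in path:
        if tok != prev and tok != blank_id:
            out.append(tok)
        prev = tok
    return out


def brute_force_ctc_loss(logits, labels, blank_id=0):
    """Hypothetical helper: -log p(labels | logits) summed over all K**T alignments.

    logits: (T, K) array for ONE unpadded sequence (no batch axis).
    Assumption: per-frame probabilities come from a log-softmax over K classes.
    """
    T, K = logits.shape
    logp = logits - logits.max(axis=-1, keepdims=True)
    logp = logp - np.log(np.exp(logp).sum(axis=-1, keepdims=True))
    total = 0.0
    for path in itertools.product(range(K), repeat=T):
        if collapse(path, blank_id) == list(labels):
            # Probability of one alignment = product of its per-frame probabilities.
            total += np.exp(sum(logp[t, k] for t, k in enumerate(path)))
    return -np.log(total)
```

For example, with all-zero logits, T = 2 and K = 2, the label sequence (1,) is produced by the alignments (1, φ), (φ, 1) and (1, 1), each with probability 1/4, so the loss is -log(3/4).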

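References

[Graves et al., 2006] Graves, A., Fernández, S., Gomez, F., and Schmidhuber, J. [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](https://dl.acm.org/doi/abs/10.1145/1143844.1143891). ICML 2006.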

Parameters
  • logits (TypeVar(Array, JaxArray, Variable, TrainVar, Array, ndarray)) – (B, T, K)-array containing the logits of each class, where B is the batch size, T is the maximum number of time frames, and K is the number of classes, including the blank class.

  • logit_paddings (TypeVar(Array, JaxArray, Variable, TrainVar, Array, ndarray)) – (B, T)-array of padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] is padding.

  • labels (TypeVar(Array, JaxArray, Variable, TrainVar, Array, ndarray)) – (B, N)-array containing the reference integer labels, where N is the maximum length of the label sequences.

  • label_paddings (TypeVar(Array, JaxArray, Variable, TrainVar, Array, ndarray)) – (B, N)-array of padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.

  • blank_id (int) – Index of the blank token. logits[b, :, blank_id] are used as the logits of the blank symbol.

  • log_epsilon (float) – Numerically stable approximation of log(+0), i.e. a large negative finite value used in place of \(-\infty\).
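The padding conventions above can be produced from variable-length data along the following lines. This is a minimal NumPy sketch: `pad_batch` is a hypothetical helper, and filling padded label positions with 0 is an assumption (those positions are flagged by label_paddings rather than by their value).

```python
import numpy as np


def pad_batch(label_seqs, frame_counts, max_frames, max_label_len=None):
    """Hypothetical helper returning (labels, label_paddings, logit_paddings).

    labels are right-padded (assumed pad value 0); label_paddings is 0.0 on
    real positions followed by 1.0 on padded ones; logit_paddings is 1.0 for
    every frame t >= frame_counts[b].
    """
    B = len(label_seqs)
    N = max(len(s) for s in label_seqs) if max_label_len is None else max_label_len
    labels = np.zeros((B, N), dtype=np.int32)
    label_paddings = np.ones((B, N), dtype=np.float32)
    for b, seq in enumerate(label_seqs):
        labels[b, :len(seq)] = seq
        label_paddings[b, :len(seq)] = 0.0  # zeros first, then ones (right padding)
    frames = np.arange(max_frames)[None, :]                  # (1, T)
    counts = np.asarray(frame_counts)[:, None]               # (B, 1)
    logit_paddings = (frames >= counts).astype(np.float32)   # (B, T)
    return labels, label_paddings, logit_paddings


labels, label_paddings, logit_paddings = pad_batch([[3, 1], [2]], [4, 2], max_frames=5)
```

Together with a (B, T, K) logits array, these three arrays match the argument layout of ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings).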

Return type

Tuple[TypeVar(Array, JaxArray, Variable, TrainVar, Array, ndarray), TypeVar(Array, JaxArray, Variable, TrainVar, Array, ndarray), TypeVar(Array, JaxArray, Variable, TrainVar, Array, ndarray)]

Returns

A tuple (loss_value, logalpha_blank, logalpha_nonblank). loss_value is a (B,)-array containing the loss of each sequence in the batch. logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays whose (t, b, n)-th elements are \(\log \alpha_{\mathrm{BLANK}}(t, n)\) and \(\log \alpha_{\mathrm{LABEL}}(t, n)\), respectively, for the b-th sequence in the batch.
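For intuition about what loss_value measures, here is a sketch of the textbook log-space forward recursion of [Graves et al., 2006] for a single unpadded sequence. It runs over the blank-interleaved sequence (φ, y_1, φ, ..., y_N, φ) of length 2N+1 and uses a finite log_epsilon in place of log(0). It is not BrainPy's implementation, which keeps separate blank and label forward variables (the arrays returned above) and handles batching and padding; `ctc_forward_loss` is a hypothetical name, and applying a log-softmax to the logits is an assumption. Under these assumptions the result is -log p(y | x).

```python
import numpy as np


def ctc_forward_loss(logits, labels, blank_id=0, log_epsilon=-100000.0):
    """Sketch of the textbook CTC forward recursion in log space.

    logits: (T, K) array for ONE unpadded sequence; labels: list of ints.
    Not BrainPy's implementation: a single alpha over the extended sequence
    replaces the separate blank / label alphas.
    """
    T, K = logits.shape
    # Assumption: logits are turned into log-probabilities by a log-softmax.
    logp = logits - logits.max(axis=-1, keepdims=True)
    logp = logp - np.log(np.exp(logp).sum(axis=-1, keepdims=True))

    # Extended sequence l' = (blank, y1, blank, y2, ..., yN, blank).
    ext = [blank_id]
    for y in labels:
        ext += [int(y), blank_id]
    S = len(ext)

    # log_epsilon stands in for log(0) at states that are not yet reachable.
    alpha = np.full(S, log_epsilon)
    alpha[0] = logp[0, ext[0]]        # start with a blank ...
    if S > 1:
        alpha[1] = logp[0, ext[1]]    # ... or with the first label

    for t in range(1, T):
        prev = alpha.copy()
        for s in range(S):
            acc = prev[s]                                # stay on the same symbol
            if s >= 1:
                acc = np.logaddexp(acc, prev[s - 1])     # advance by one state
            if s >= 2 and ext[s] != blank_id and ext[s] != ext[s - 2]:
                acc = np.logaddexp(acc, prev[s - 2])     # skip a blank between distinct labels
            alpha[s] = acc + logp[t, ext[s]]

    # Valid alignments end on the final blank or on the last label.
    log_p = alpha[-1] if S == 1 else np.logaddexp(alpha[-1], alpha[-2])
    return -log_p
```

The skip rule is what forces a blank between repeated labels: with all-zero logits, T = 3 and K = 3, the labels (1, 2) admit five alignments (loss -log(5/27)), while (1, 1) admits only (1, φ, 1) (loss -log(1/27)).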