- brainpy.math.operators.segment_sum(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)
segment_sum operator for brainpy JaxArray and Variable: sums the entries of data that share the same id in segment_ids.
- Parameters
data (Array) – An array with the values to be reduced.
segment_ids (Array) – An array with integer dtype that indicates the segments of data (along its leading axis) to be summed. Values can be repeated and need not be sorted.
num_segments (Optional, int) – A nonnegative int giving the number of segments. By default it is the minimum number of segments that covers all indices in segment_ids, computed as max(segment_ids) + 1. Since num_segments determines the size of the output, a static value must be provided to use segment_sum in a JIT-compiled function.
indices_are_sorted (bool) – whether segment_ids is known to be sorted.
unique_indices (bool) – whether segment_ids is known to be free of duplicates.
bucket_size (int) – Size of the buckets to group indices into. segment_sum is performed on each bucket separately to improve the numerical stability of addition. The default, None, means no bucketing.
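As a minimal sketch of how these parameters interact, the example below calls jax.ops.segment_sum directly. It assumes that the brainpy operator, which takes the same parameters, behaves the same way on plain JAX arrays; the helper name jitted_segment_sum is hypothetical. Outside of jit, num_segments can be inferred from a concrete segment_ids; under jit it has to be a static Python int.

```python
from functools import partial

import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([2, 0, 2, 1, 0])  # repeated, unsorted ids are allowed

# num_segments omitted: inferred as max(segment_ids) + 1 == 3.
# This needs a concrete segment_ids, so it works eagerly but not under jit.
sums = jax.ops.segment_sum(data, segment_ids)
# sums holds [2 + 5, 4, 1 + 3], i.e. [7., 4., 4.]


# Hypothetical helper: under jit, num_segments fixes the output shape,
# so it is marked static and must be passed as a Python int.
@partial(jax.jit, static_argnames="num_segments")
def jitted_segment_sum(data, segment_ids, num_segments):
    return jax.ops.segment_sum(data, segment_ids, num_segments=num_segments)


jit_sums = jitted_segment_sum(data, segment_ids, num_segments=3)

# indices_are_sorted / unique_indices are promises about segment_ids that may
# let the backend run faster; treat them as assumptions, and only set them
# when they actually hold.
sorted_ids = jnp.array([0, 0, 1, 2, 2])
sorted_data = jnp.array([2.0, 5.0, 4.0, 1.0, 3.0])
sorted_sums = jax.ops.segment_sum(
    sorted_data, sorted_ids, num_segments=3, indices_are_sorted=True
)
```

Calling jax.ops.segment_sum without num_segments inside a jitted function would instead fail, because max(segment_ids) is not known at trace time.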
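To picture what bucket_size does, here is a NumPy reference sketch. It is an assumption about the general idea, not brainpy's or JAX's actual code: the leading axis of data is cut into consecutive buckets of bucket_size rows, each bucket is segment-summed on its own, and the per-bucket partial sums are then added. Both helper names are hypothetical, and the sketch assumes every id lies in [0, num_segments).

```python
import numpy as np


def segment_sum_reference(data, segment_ids, num_segments=None):
    """Hypothetical helper: out[k] = sum of data[i] over all i with segment_ids[i] == k."""
    data = np.asarray(data)
    segment_ids = np.asarray(segment_ids)
    if num_segments is None:
        num_segments = int(segment_ids.max()) + 1  # same default as documented above
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # unbuffered add, so repeated ids accumulate
    return out


def bucketed_segment_sum(data, segment_ids, num_segments, bucket_size):
    """Hypothetical helper: segment-sum each bucket of rows, then combine the buckets."""
    data = np.asarray(data)
    segment_ids = np.asarray(segment_ids)
    partials = [
        segment_sum_reference(
            data[start:start + bucket_size],
            segment_ids[start:start + bucket_size],
            num_segments,
        )
        for start in range(0, len(segment_ids), bucket_size)
    ]
    return np.sum(partials, axis=0)  # add the per-bucket partial sums
```

Adding a few partial sums instead of one long running sum is the usual reason bucketing helps floating-point accuracy; the result matches the unbucketed sum up to rounding.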
- Returns
output – An array with shape (num_segments,) + data.shape[1:] representing the segment sums.
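A short check of the output shape with two-dimensional data, again calling jax.ops.segment_sum directly under the assumption that the brainpy operator matches it: segments are formed along the leading axis only, trailing axes are kept, and a segment that receives no rows (segment 2 here) sums to zero.

```python
import jax
import jax.numpy as jnp

# Five rows of 2-feature data, grouped into 4 segments along the leading axis.
data = jnp.arange(10.0).reshape(5, 2)  # rows: [0,1], [2,3], [4,5], [6,7], [8,9]
segment_ids = jnp.array([0, 0, 1, 3, 3])

out = jax.ops.segment_sum(data, segment_ids, num_segments=4)
# out.shape == (num_segments,) + data.shape[1:] == (4, 2)
# rows of out: [2, 4], [4, 5], [0, 0] (segment 2 is empty), [14, 16]
```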
- Return type