partition_by_axname

class brainpy.math.sharding.partition_by_axname(x, axis_names=None, mesh=None)

Place the given arrays onto the devices of a mesh. Each array axis is partitioned according to the name given for it in axis_names.

Parameters:
  • x (Any) – The array (or arrays) to place on the mesh devices.

  • axis_names (Optional[Sequence[str]]) – A name for each axis of the array, one entry per dimension.

  • mesh (Optional[Mesh]) – The device mesh to place the arrays on.

Returns:

The re-sharded arrays.
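The page does not show how axis names turn into a device layout. The following is a minimal sketch of the general technique, written directly against `jax.sharding` rather than BrainPy. It assumes that each entry of axis_names that matches a mesh axis name is split across that mesh axis, and that axes whose names are not in the mesh are replicated. The functions `partition_by_axname_sketch` and `spec_entries` and the default one-axis mesh named `"data"` are hypothetical and are not part of BrainPy's API. The sketch also handles only a single array, not a collection of arrays.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def spec_entries(axis_names, mesh_axis_names):
    """Map each array axis name to a mesh axis, or None (replicate) if absent."""
    return tuple(n if n in mesh_axis_names else None for n in axis_names)


def partition_by_axname_sketch(x, axis_names=None, mesh=None):
    """Hypothetical illustration only, not BrainPy's implementation."""
    if axis_names is None:
        # Without axis names there is nothing to shard on: return x unchanged.
        return x
    if np.ndim(x) != len(axis_names):
        # One name is required per array dimension.
        raise ValueError(
            f"expected {np.ndim(x)} axis names, got {len(axis_names)}")
    if mesh is None:
        # Assumption: fall back to a 1-D mesh over all devices, named "data".
        mesh = Mesh(np.array(jax.devices()), ("data",))
    spec = PartitionSpec(*spec_entries(axis_names, mesh.axis_names))
    # device_put moves the array and splits it according to the named sharding.
    return jax.device_put(x, NamedSharding(mesh, spec))


# Axis 0 ("data") is split across the mesh; axis 1 ("hidden") is replicated.
x = jnp.arange(16.0).reshape(8, 2)
y = partition_by_axname_sketch(x, ("data", "hidden"))
print(y.sharding)
```

With `jax.device_put` and a `NamedSharding`, the size of each sharded array axis generally has to be divisible by the number of devices along the matching mesh axis. This is why the example uses 8 rows.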