get_sharding

brainpy.math.sharding.get_sharding(axis_names=None, mesh=None)

Get the sharding for an array from the names of its axes and a device mesh.

Parameters:
  • axis_names (Optional[Sequence[str]]) – A list or tuple of str giving the name of each axis of the array.

  • mesh (Optional[Mesh]) – The device mesh over which the array is sharded.

Return type:

Optional[NamedSharding]

Returns:

A NamedSharding instance. As the Optional return type indicates, the result may also be None.
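The sketch below illustrates, in plain JAX, the kind of object this function returns: a `jax.sharding.NamedSharding` built from a `Mesh` and a `PartitionSpec` whose entries are axis names. The helper `make_named_sharding` is hypothetical and is not BrainPy's implementation. Two behaviors are assumptions of this sketch: it returns None when either argument is None, and it treats any name that is not an axis of the mesh as unsharded (None).

```python
from typing import Optional, Sequence

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_named_sharding(
    axis_names: Optional[Sequence[Optional[str]]] = None,
    mesh: Optional[Mesh] = None,
) -> Optional[NamedSharding]:
    """Hypothetical stand-in for get_sharding, written against plain JAX."""
    # Assumption: with no axis names or no mesh there is nothing to shard over.
    if axis_names is None or mesh is None:
        return None
    # Assumption: names that are not axes of the mesh are left unsharded (None).
    spec = [name if name in mesh.axis_names else None for name in axis_names]
    return NamedSharding(mesh, PartitionSpec(*spec))


# Build a 1-D mesh over all available devices, with one axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the first array axis over "data"; "feature" is not a mesh axis,
# so the second array axis is replicated.
sharding = make_named_sharding(["data", "feature"], mesh)

# The sharded dimension must be divisible by the number of devices on "data".
x = jax.device_put(jnp.zeros((2 * len(jax.devices()), 4)), sharding)
print(sharding.spec)
```

On a single CPU device the mesh has one entry, so the array is placed whole on that device, but the same code distributes rows across devices when more are available.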