partition_by_sharding

class brainpy.math.sharding.partition_by_sharding(x, sharding=None)

Partition inputs with the given sharding strategy.

Parameters:
  • x (Any) – The input array, or a pytree of arrays.

  • sharding (Optional[Sharding]) – The jax.sharding.Sharding instance.

Returns:

The sharded x, partitioned according to the given sharding strategy.
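
Example:

The following is a minimal sketch of the idea in plain JAX, not BrainPy's actual implementation. The helper name partition_sketch, the mesh axis name "data", and the example pytree are hypothetical. Returning x unchanged when sharding is None is an assumption based on the None default in the signature.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def partition_sketch(x, sharding=None):
    """Hypothetical stand-in for partition_by_sharding (illustration only)."""
    if sharding is None:
        # Assumption: with no sharding given, the input is returned as-is.
        return x
    # Place every array leaf of the pytree according to the same sharding.
    return jax.tree_util.tree_map(lambda leaf: jax.device_put(leaf, sharding), x)


# Build a 1D device mesh over whatever devices are available (one is enough).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the leading axis across the "data" mesh axis; other axes are replicated.
sharding = NamedSharding(mesh, PartitionSpec("data"))

# The leading dimension is chosen to be divisible by the device count.
n = 4 * len(devices)
tree = {"a": jnp.arange(n, dtype=jnp.float32), "b": jnp.ones((n, 2))}

sharded = partition_sketch(tree, sharding)
```

Each leaf of sharded keeps its shape and values; only its placement across devices changes, which can be inspected through the leaf's .sharding attribute.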