partition_by_sharding

brainpy.math.sharding.partition_by_sharding(x, sharding=None)

Partition inputs with the given sharding strategy.

Parameters:
  • x (Any) – The input arrays. This can be a single array or a pytree of arrays.

  • sharding (Optional[Sharding]) – The jax.sharding.Sharding instance.

Returns:

The sharded x, partitioned according to the given sharding strategy.
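The sketch below illustrates the idea with plain JAX rather than BrainPy: it builds a jax.sharding.Sharding (here a NamedSharding over a one-axis mesh) and places every array leaf of a pytree according to it. The helper place_with_sharding is hypothetical and is not BrainPy's implementation; in particular, returning the input unchanged when sharding is None is an assumption made for illustration. The mesh axis name "data" is also an arbitrary choice.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def place_with_sharding(x, sharding=None):
    """Hypothetical helper: put every array leaf of the pytree ``x``
    onto devices according to ``sharding`` (not BrainPy's actual code)."""
    if sharding is None:
        # Assumption: with no sharding given, leave the inputs untouched.
        return x
    # Apply the same sharding to each leaf of the pytree.
    return jax.tree_util.tree_map(lambda a: jax.device_put(a, sharding), x)


# A 1-D mesh over all visible devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Split the leading axis of each array across the "data" mesh axis.
sharding = NamedSharding(mesh, P("data"))

# Leading dimensions are multiples of the device count so they divide evenly.
n = len(devices)
params = {"w": jnp.ones((4 * n, 3)), "b": jnp.zeros((4 * n,))}
sharded = place_with_sharding(params, sharding)
```

With BrainPy installed, the same sharding object would be passed as the sharding argument of partition_by_sharding, e.g. partition_by_sharding(params, sharding).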