Interoperation with other JAX frameworks

import brainpy.math as bm

BrainPy interoperates easily with other JAX-based frameworks.

1. Data are exchangeable between frameworks.

This works because a JaxArray can be converted directly into a JAX ndarray or a NumPy ndarray, and vice versa.
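Underneath the JaxArray wrapper, the exchange relies on plain JAX and NumPy arrays converting into each other. The minimal sketch below uses only `numpy` and `jax.numpy` (no BrainPy objects, assuming a recent JAX install) to show that kind of round trip. Note that raw JAX arrays are immutable, so the element update goes through `.at[...].set(...)`, which returns a new array.

```python
import numpy as np
import jax.numpy as jnp

# A NumPy array living on the host.
np_arr = np.arange(5, dtype=np.int32)

# NumPy -> JAX: jnp.asarray turns it into a JAX array.
jax_arr = jnp.asarray(np_arr)

# JAX arrays are immutable: .at[0].set(5) returns an updated copy
# instead of modifying jax_arr (or np_arr) in place.
jax_arr = jax_arr.at[0].set(5)

# JAX -> NumPy: np.asarray brings the data back as a NumPy array.
back = np.asarray(jax_arr)
print(back)  # [5 1 2 3 4]
```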

Convert a JaxArray into a JAX ndarray.

# JaxArray.value is a JAX ndarray
b = bm.arange(5)
b[0] = 5
b.value
DeviceArray([5, 1, 2, 3, 4], dtype=int32)

Convert a JaxArray into a numpy ndarray.

# JaxArray can be easily converted to a numpy ndarray
b = bm.arange(5)
b[0] = 5
b.numpy()
array([5, 1, 2, 3, 4])

Convert a numpy ndarray into a JaxArray.

import numpy as np
bm.asarray(np.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))

Convert a JAX ndarray into a JaxArray.

import jax.numpy as jnp
bm.asarray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
bm.JaxArray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
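To illustrate why such a wrapper converts so easily in both directions, here is a toy container. `WrappedArray` is a hypothetical name invented for this sketch; it is not BrainPy's JaxArray implementation, only an assumption-based illustration of a class that keeps a JAX array in `.value`, offers a `.numpy()` method, and implements NumPy's `__array__` protocol.

```python
import numpy as np
import jax.numpy as jnp


class WrappedArray:
    """Hypothetical JaxArray-like wrapper for illustration (NOT BrainPy's code)."""

    def __init__(self, value):
        # Store the data as a JAX array, whatever the input type (NumPy, JAX, list).
        self.value = jnp.asarray(value)

    def numpy(self):
        # Hand the data over to NumPy.
        return np.asarray(self.value)

    def __array__(self, dtype=None, copy=None):
        # NumPy protocol hook: makes np.asarray(wrapped) work directly.
        return np.asarray(self.value, dtype=dtype)

    def __setitem__(self, index, item):
        # Emulate in-place mutation by swapping in an updated immutable JAX array.
        self.value = self.value.at[index].set(item)

    def __repr__(self):
        return f"WrappedArray({self.value!r})"


w = WrappedArray(np.arange(5, dtype=np.int32))
w[0] = 5
print(w.numpy())  # [5 1 2 3 4]
```

The design choice being illustrated is that the wrapper owns a plain JAX array, so converting to JAX is just attribute access and converting to NumPy goes through the standard NumPy conversion path.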

2. Transformations in brainpy.math also work on functions.

APIs from other JAX frameworks integrate naturally with BrainPy. Here we take the gradient-based optimization library Optax as an example of how to use another JAX framework inside BrainPy.
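The Optax example batches a per-example model with `vmap` and `in_axes=(None, 0)`: the parameters are shared across the batch while the inputs are mapped over their first axis. A minimal sketch of those semantics in plain JAX, assuming `bm.vmap` batches the same way for this simple case (`predict_one` and `predict_batch` are hypothetical names):

```python
import jax
import jax.numpy as jnp

# Per-example model: dot product of the parameter vector with one input vector.
def predict_one(params, x):
    return jnp.dot(params, x)

# in_axes=(None, 0): params are not batched; x is batched over axis 0.
predict_batch = jax.vmap(predict_one, in_axes=(None, 0))

params = jnp.array([0.5, 0.5])
xs = jnp.arange(6.0).reshape(3, 2)  # three examples with two features each

batched = predict_batch(params, xs)
print(batched)  # [0.5 2.5 4.5]
```

For this linear model the batched result is the same as the matrix-vector product `xs @ params`.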

import optax
# First create several useful functions.

network = bm.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = bm.mean(optax.l2_loss(y_pred, y))
  return loss

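def train(params, opt_state, xs, ys):
  # Differentiate the loss w.r.t. params with BrainPy's grad transformation.
  grads = bm.grad(compute_loss)(params, xs.value, ys)
  # Let the Optax optimizer (defined below) turn gradients into updates.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state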
# Generate some data

target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
# Initialize parameters of the model + optimizer

params = bm.array([0.0, 0.0])
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(params)
# A simple update loop

for _ in range(1000):
  params, opt_state = train(params, opt_state, xs, ys)

assert bm.allclose(params, target_params), \
  'Optimization should retrieve the target params used to generate the data.'
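For comparison, the same fit can be sketched in plain JAX with a hand-rolled gradient-descent step in place of Optax's Adam. This is a swapped-in technique chosen for illustration: `jax.vmap` and `jax.grad` stand in for `bm.vmap` and `bm.grad`, `sgd_step` is a hypothetical helper, and the learning rate, step count, and random seed are assumptions picked so that this small least-squares problem converges.

```python
import jax
import jax.numpy as jnp

# Same linear model, batched over examples with vmap.
network = jax.vmap(lambda params, x: jnp.dot(params, x), in_axes=(None, 0))

def compute_loss(params, x, y):
    y_pred = network(params, x)
    # 0.5 * squared error, matching the definition of optax.l2_loss.
    return jnp.mean(0.5 * (y_pred - y) ** 2)

@jax.jit
def sgd_step(params, xs, ys, lr=0.1):
    # Plain gradient descent (an assumption here, NOT Adam): p <- p - lr * dL/dp.
    grads = jax.grad(compute_loss)(params, xs, ys)
    return params - lr * grads

# Generate data from known parameters.
target_params = 0.5
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)

params = jnp.zeros(2)
for _ in range(2000):
    params = sgd_step(params, xs, ys)

assert jnp.allclose(params, target_params, atol=1e-4), \
    'Gradient descent should recover the parameters used to generate the data.'
```

Apart from `optax.l2_loss`, Optax appears in the original example only through `optimizer.init`, `optimizer.update` and `optax.apply_updates`, so changing the optimizer leaves the model and loss code untouched.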