Interoperation with other JAX frameworks
BrainPy can be easily interoperated with other JAX frameworks.
import numpy as np
import brainpy.math as bm
1. Data are exchangeable in different frameworks
This is possible because a JaxArray can be directly converted to a JAX ndarray or a NumPy ndarray.
Convert a JaxArray into a JAX ndarray:
# JaxArray.value is a JAX ndarray
b = bm.array([5, 1, 2, 3, 4])
b.value
DeviceArray([5, 1, 2, 3, 4], dtype=int32)
Convert a JaxArray into a NumPy ndarray:
# A JaxArray can be passed to np.asarray() directly
np.asarray(b)
array([5, 1, 2, 3, 4])
Convert a NumPy ndarray into a JaxArray:
bm.asarray(np.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
Convert a JAX ndarray into a JaxArray:
import jax.numpy as jnp
bm.asarray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
bm.JaxArray(jnp.arange(5))
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
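Because conversion is cheap in both directions, data can flow from NumPy through a pure-JAX computation and back without any framework-specific glue. A minimal sketch using only jax.numpy and NumPy (no BrainPy-specific API is assumed here):

```python
import numpy as np
import jax.numpy as jnp

# Round trip: NumPy data -> JAX computation -> NumPy result.
x_np = np.arange(5)
x_jax = jnp.asarray(x_np)    # NumPy ndarray -> JAX ndarray
doubled = x_jax * 2          # computed by JAX
back = np.asarray(doubled)   # JAX ndarray -> NumPy ndarray
print(back.tolist())         # [0, 2, 4, 6, 8]
```

The same pattern applies to a JaxArray: unwrap with `.value` (or `np.asarray`) before calling into another library, and re-wrap the result with `bm.asarray` if needed.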
2. Transformations in brainpy.math also work on functions
APIs from other JAX frameworks can be naturally integrated into BrainPy. Let’s take the gradient-based optimization library Optax as an example to illustrate how to use other JAX frameworks within BrainPy.
import optax
# First create several useful functions.
network = bm.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))
def compute_loss(params, x, y):
    y_pred = network(params, x)
    loss = bm.mean(optax.l2_loss(y_pred, y))
    return loss

@bm.jit
def train(params, opt_state, xs, ys):
    grads = bm.grad(compute_loss)(params, xs.value, ys)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state
# Generate some data
bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
# Initialize parameters of the model + optimizer
params = bm.array([0.0, 0.0])
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(params)
# A simple update loop
for _ in range(1000):
    params, opt_state = train(params, opt_state, xs, ys)

assert bm.allclose(params, target_params), \
    'Optimization should retrieve the target params used to generate the data.'
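As a sanity check on the data-generation step (independent of BrainPy and Optax), the same fitting problem has a closed-form solution: `ys` is an exact linear function of `xs` with weight 0.5 per feature, so ordinary least squares recovers the target parameters directly. A hypothetical NumPy-only cross-check:

```python
import numpy as np

# Regenerate data the same way as above, then solve by least squares.
xs = np.random.default_rng(42).normal(size=(16, 2))
ys = np.sum(xs * 0.5, axis=-1)
params_closed, *_ = np.linalg.lstsq(xs, ys, rcond=None)
print(params_closed)  # approximately [0.5, 0.5]
```

This confirms that the gradient-descent loop above is converging to the unique least-squares solution.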