Interoperation with other JAX frameworks

BrainPy is designed to interoperate easily with other JAX frameworks, such as Optax and Flax.

import jax
import brainpy as bp
# math libraries of BrainPy, JAX, and NumPy
import brainpy.math as bm
import jax.numpy as jnp
import numpy as np

1. Data are exchangeable among different frameworks.

This is possible because a JaxArray can be directly converted to a JAX ndarray or a NumPy ndarray, and vice versa.
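How can `np.asarray` accept a wrapper object at all? One mechanism is NumPy's `__array__` protocol. `MiniArray` below is a hypothetical stand-in written only for illustration. It is not BrainPy's JaxArray, and we do not claim that JaxArray is implemented this way. It wraps a JAX array in a `.value` attribute and hands that data to NumPy on request.

```python
import numpy as np
import jax.numpy as jnp


class MiniArray:
    """Hypothetical JaxArray-like wrapper (illustration only, not BrainPy's class)."""

    def __init__(self, value):
        # Store the data as a JAX array, mirroring the `.value` attribute idea.
        self.value = jnp.asarray(value)

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this hook from np.asarray(...) and np.array(...).
        arr = np.asarray(self.value)
        return arr if dtype is None else arr.astype(dtype)


w = MiniArray([9, 9, 0, 4, 7])
as_numpy = np.asarray(w)        # wrapper -> NumPy ndarray via __array__
as_jax = jnp.asarray(as_numpy)  # NumPy ndarray -> JAX array
print(type(as_numpy).__name__, as_numpy.tolist())
```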

Convert a JaxArray into a JAX ndarray.

b = bm.random.randint(10, size=5)
# JaxArray.value is a JAX DeviceArray
b.value
DeviceArray([9, 9, 0, 4, 7], dtype=int32)

Convert a JaxArray into a NumPy ndarray.

# A JaxArray can be easily converted to a NumPy ndarray
np.asarray(b)
array([9, 9, 0, 4, 7])

Convert a numpy ndarray into a JaxArray.

bm.asarray(np.arange(5))
JaxArray([0, 1, 2, 3, 4], dtype=int32)

Convert a JAX ndarray into a JaxArray.

bm.asarray(jnp.arange(5))
JaxArray([0, 1, 2, 3, 4], dtype=int32)
bm.JaxArray(jnp.arange(5))
JaxArray([0, 1, 2, 3, 4], dtype=int32)
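The `dtype=int32` in the outputs above deserves a note. `np.arange(5)` may produce a 64-bit integer array, which is NumPy's default integer on most 64-bit Linux and macOS builds. JAX, however, uses 32-bit types by default unless 64-bit mode (`jax_enable_x64`) is turned on, so the converted array comes back as int32. The snippet below checks this with plain JAX and NumPy only and does not call BrainPy.

```python
import numpy as np
import jax
import jax.numpy as jnp

host = np.arange(5, dtype=np.int64)   # explicitly 64-bit on the NumPy side
device = jnp.asarray(host)            # NumPy -> JAX

# With 64-bit mode off (JAX's default), int64 is canonicalized to int32.
expected = jax.dtypes.canonicalize_dtype(np.int64)
print(device.dtype, expected)

back = np.asarray(device)             # JAX -> NumPy keeps JAX's dtype
```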

2. Transformations in brainpy.math also work on functions that use other JAX frameworks.

APIs from other JAX frameworks integrate naturally with BrainPy. As an example, we use the gradient-based optimization library Optax. The model and the loss are written with brainpy.math, Optax supplies the Adam optimizer, and bm.jit and bm.grad transform the training step that combines them.
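Optax can be plugged in because its optimizers are pure functions over pytrees of arrays. An optimizer provides `init(params)`, which returns a state, and `update(grads, state)`, which returns `(updates, new_state)`. `optax.apply_updates` then adds the updates to the parameters. The sketch below imitates that contract with a hypothetical `sgd` helper and a hypothetical `apply_updates` function. Both are our own code, not Optax, and they show that nothing framework-specific is needed.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class MiniTransformation(NamedTuple):
    """Hypothetical (init, update) pair modelled on Optax's interface."""
    init: Callable
    update: Callable


def sgd(learning_rate):
    # Plain SGD keeps no state, so init returns an empty tuple.
    def init(params):
        return ()

    def update(grads, state):
        # Updates are the negative scaled gradients, leaf by leaf.
        updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
        return updates, state

    return MiniTransformation(init, update)


def apply_updates(params, updates):
    # Add each update to the matching parameter leaf.
    return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)


# Minimize ||w - 3||^2 for a pytree (dict) of parameters.
loss = lambda p: jnp.sum((p["w"] - 3.0) ** 2)
opt = sgd(0.1)
params = {"w": jnp.zeros(2)}
state = opt.init(params)
for _ in range(200):
    grads = jax.grad(loss)(params)
    updates, state = opt.update(grads, state)
    params = apply_updates(params, updates)
```

Optax applies the same split between computing updates and applying them, which is why its `optimizer.update` and `optax.apply_updates` can be called inside a BrainPy training step.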

import optax
# First create several useful functions.

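# linear model: dot(params, x) for every sample, vectorized over the batch axis
network = jax.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))
optimizer = optax.adam(learning_rate=1e-1)

def compute_loss(params, x, y):
  y_pred = network(params, x)
  # mean of optax.l2_loss, i.e. 0.5 * squared error
  loss = bm.mean(optax.l2_loss(y_pred, y))
  return loss

@bm.jit
def train(params, opt_state, xs, ys):
  # gradient of the loss with respect to params (the first argument)
  grads = bm.grad(compute_loss)(params, xs.value, ys)
  # Optax turns the gradients into parameter updates and a new optimizer state
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state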
# Generate some data

bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
# Initialize the model parameters and the optimizer state

params = bm.array([0.0, 0.0])
opt_state = optimizer.init(params)
# A simple update loop

for _ in range(1000):
  params, opt_state = train(params, opt_state, xs, ys)

assert bm.allclose(params, target_params), \
  'Optimization should retrieve the target params used to generate the data.'
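For comparison, the same fit can be written with JAX's own transformations. Here `jax.jit` and `jax.grad` play the roles that `bm.jit` and `bm.grad` play in the training step above. This sketch rests on our own choices. It replaces Adam with plain gradient descent (learning rate 0.1) so that it needs no Optax. It draws the data with `jax.random`. It also runs the loop inside `jax.lax.fori_loop` so that the whole optimization is compiled once.

```python
import jax
import jax.numpy as jnp

# Same linear model: one dot product per sample, batched with vmap.
network = jax.vmap(lambda params, x: jnp.dot(params, x), in_axes=(None, 0))


def compute_loss(params, x, y):
    # 0.5 * squared error, matching the definition of optax.l2_loss.
    return jnp.mean(0.5 * (network(params, x) - y) ** 2)


@jax.jit
def fit(params, xs, ys):
    def step(_, p):
        # One plain gradient-descent step (no Adam here).
        return p - 0.1 * jax.grad(compute_loss)(p, xs, ys)
    return jax.lax.fori_loop(0, 1000, step, params)


target_params = 0.5
xs = jax.random.normal(jax.random.PRNGKey(42), (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)
params = fit(jnp.zeros(2), xs, ys)
```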

3. Other JAX frameworks can be integrated into a BrainPy program.

In this example, we use Flax, a library for deep neural networks, to define a convolutional neural network (CNN). Then we integrate this CNN into a recurrent neural network (RNN) model defined with BrainPy's syntax.

First, we use Flax to define the CNN.

from flax import linen as nn

class CNN(nn.Module):
  """A CNN model implemented by using Flax."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    return x
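The network later feeds this CNN a 4×28 strip of each MNIST image at every time step, so it helps to trace the shapes. Assume Flax's defaults: 'SAME' padding for `nn.Conv` and 'VALID' padding for `nn.avg_pool`. Then the convolutions keep the spatial size, and each 2×2 pooling with stride 2 halves it, rounding down. The flattened feature vector that enters `nn.Dense` therefore has 1 × 7 × 64 = 448 entries. The helpers below (`pooled`, `cnn_feature_size`) are hypothetical plain-Python functions that only do this arithmetic.

```python
def pooled(size, window=2, stride=2):
    # Output length of a 'VALID' pooling window along one axis.
    return (size - window) // stride + 1


def cnn_feature_size(height, width):
    # Conv layers use 'SAME' padding, so only the two pooling layers shrink the map.
    h, w = pooled(pooled(height)), pooled(pooled(width))
    channels = 64  # features of the second nn.Conv
    return h, w, h * w * channels


print(cnn_feature_size(4, 28))
```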

Next, we define the RNN model with BrainPy's interface.

from jax.tree_util import tree_flatten, tree_map, tree_unflatten

class Network(bp.dyn.DynamicalSystem):
  """A network model implemented by BrainPy"""

  def __init__(self):
    super(Network, self).__init__()

    # CNN and its parameters: initialize with a dummy 4x28 single-channel input
    self.cnn = CNN()
    rng = bm.random.DEFAULT.split_key()
    params = self.cnn.init(rng, jnp.ones([1, 4, 28, 1]))['params']
    # flatten the Flax parameter tree and register each leaf as a trainable variable
    leaves, self.tree = tree_flatten(params)
    self.implicit_vars.update(tree_map(bm.TrainVar, leaves))

    # RNN: 256 CNN features -> 100 hidden units
    self.rnn = bp.layers.GRU(256, 100)

    # readout: 100 hidden units -> 10 classes
    self.linear = bp.layers.Dense(100, 10)

  def update(self, sha, x):
    # rebuild the Flax parameter tree from the current TrainVar values
    params = tree_unflatten(self.tree, [v.value for v in self.implicit_vars.values()])
    x = self.cnn.apply({'params': params}, bm.as_jax(x))
    x = self.rnn(sha, x)
    x = self.linear(sha, x)
    return x
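The `Network` class stores the Flax parameters in two steps. It flattens the nested parameter dictionary into a list of leaves, each wrapped as `bm.TrainVar` so that BrainPy can train it. Before each `cnn.apply` call, it rebuilds the dictionary with `tree_unflatten`. This round trip can be checked in isolation with plain `jax.tree_util`. The nested dictionary below is a made-up placeholder shaped like a Flax params tree, not the real CNN parameters.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

# Hypothetical params tree with a Flax-like nesting ({layer: {kernel, bias}}).
params = {
    "Conv_0": {"kernel": jnp.ones((3, 3, 1, 32)), "bias": jnp.zeros(32)},
    "Dense_0": {"kernel": jnp.ones((448, 256)), "bias": jnp.zeros(256)},
}

leaves, treedef = tree_flatten(params)  # list of arrays + structure description

# Stand-in for "the stored variables were updated": scale every leaf.
new_leaves = [0.5 * leaf for leaf in leaves]

rebuilt = tree_unflatten(treedef, new_leaves)  # same nesting as `params`
```

`tree_unflatten` matches leaves to positions in the tree by order. The class therefore relies on `implicit_vars.values()` returning the variables in the same order in which they were added.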

Next, we initialize the network, the optimizer, the loss function, and the BPTT trainer.
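The trainer used here, `bp.train.BPTT`, trains the network with backpropagation through time. The model is unrolled over the time axis of each input sequence, and gradients flow back through every step. The core idea fits in a few lines of JAX with `jax.lax.scan` and `jax.grad`. The tiny tanh RNN, its parameter names (`W`, `U`), and its sizes are hypothetical and unrelated to BrainPy's GRU. This is a generic illustration, not how BrainPy implements its trainer.

```python
import jax
import jax.numpy as jnp


def rnn_loss(params, xs, target):
    # xs: (time, features). Hypothetical vanilla RNN, not BrainPy's GRU.
    def step(h, x):
        h = jnp.tanh(params["W"] @ h + params["U"] @ x)
        return h, h

    h0 = jnp.zeros(params["W"].shape[0])
    h_last, _ = jax.lax.scan(step, h0, xs)  # unroll over time
    return jnp.sum((h_last - target) ** 2)


key_w, key_u, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W": 0.1 * jax.random.normal(key_w, (8, 8)),
    "U": 0.1 * jax.random.normal(key_u, (8, 4)),
}
xs = jax.random.normal(key_x, (7, 4))  # 7 time steps, as in the MNIST example

# jax.grad differentiates through every scan step: this is BPTT.
grads = jax.grad(rnn_loss)(params, xs, jnp.zeros(8))
```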

net = Network()
opt = bp.optim.Momentum(0.1)

def loss_func(predictions, targets):
  # predictions are assumed to be (batch, time, 10); keep each class's maximum score over time
  logits = bm.max(predictions, axis=1)
  loss = bp.losses.cross_entropy_loss(logits, targets)
  accuracy = bm.mean(bm.argmax(logits, -1) == targets)
  return loss, {'accuracy': accuracy}

trainer = bp.train.BPTT(net, loss_fun=loss_func, optimizer=opt, loss_has_aux=True)
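`loss_func` reduces over time before it computes the loss. `bm.max(..., axis=1)` keeps, for each class, its largest score across the time steps. The sketch below reproduces that reduction in plain JAX. We assume that `bp.losses.cross_entropy_loss` computes the mean softmax cross-entropy for integer labels, and we write that out with `jax.nn.log_softmax`. Treat `loss_func_sketch` as a hypothetical approximation of BrainPy's function, not its source.

```python
import jax
import jax.numpy as jnp


def loss_func_sketch(predictions, targets):
    # predictions: (batch, time, classes); targets: (batch,) integer labels.
    logits = jnp.max(predictions, axis=1)             # (batch, classes)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Assumed: mean softmax cross-entropy with integer targets.
    nll = -jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    accuracy = jnp.mean(jnp.argmax(logits, -1) == targets)
    return jnp.mean(nll), {"accuracy": accuracy}


# All-zero scores: every class is equally likely, so the loss is log(10).
loss, aux = loss_func_sketch(jnp.zeros((2, 7, 10)), jnp.array([0, 1]))
```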

Then we load the MNIST training set.

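# the data path is machine-specific; adjust it for your system
train_dataset = bp.datasets.MNIST(r'D:\data\mnist', train=True, download=True)
# split each 28x28 image into 7 time steps of 4 rows each, and scale pixels to [0, 1]
X = train_dataset.data.reshape((-1, 7, 4, 28, 1)) / 255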
Y = train_dataset.targets
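The reshape turns each 28×28 image into a sequence of 7 time steps. Each step is a 4×28 strip of consecutive rows with a trailing channel axis, which matches the input shape the CNN was initialized with. Dividing by 255 maps 8-bit pixel values into [0, 1]. The NumPy snippet below checks this row splitting on random fake images, not on MNIST data.

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 28, 28))  # 3 fake 8-bit "images"

X = images.reshape((-1, 7, 4, 28, 1)) / 255

# Time step t of image i holds rows 4*t to 4*t+3 of that image.
t = 5
strip = X[0, t, :, :, 0]
same_rows = np.allclose(strip, images[0, 4 * t:4 * t + 4] / 255)
```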

Finally, we train the model with BPTT.fit().

trainer.fit([X, Y], batch_size=256, num_epoch=10)
Train 100 steps, use 32.5824 s, train loss 0.96465, accuracy 0.66015625
Train 200 steps, use 30.9035 s, train loss 0.38974, accuracy 0.89453125
Train 300 steps, use 33.1075 s, train loss 0.31525, accuracy 0.890625
Train 400 steps, use 31.4062 s, train loss 0.23846, accuracy 0.91015625
Train 500 steps, use 32.3371 s, train loss 0.21995, accuracy 0.9296875
Train 600 steps, use 32.5692 s, train loss 0.20885, accuracy 0.92578125
Train 700 steps, use 33.0139 s, train loss 0.24748, accuracy 0.90625
Train 800 steps, use 31.9635 s, train loss 0.14563, accuracy 0.953125
Train 900 steps, use 31.8845 s, train loss 0.17017, accuracy 0.94140625
Train 1000 steps, use 32.0537 s, train loss 0.09413, accuracy 0.95703125
Train 1100 steps, use 32.3714 s, train loss 0.06015, accuracy 0.984375
Train 1200 steps, use 31.6957 s, train loss 0.12061, accuracy 0.94921875
Train 1300 steps, use 31.8346 s, train loss 0.13908, accuracy 0.953125
Train 1400 steps, use 31.5252 s, train loss 0.10718, accuracy 0.953125
Train 1500 steps, use 31.7274 s, train loss 0.07869, accuracy 0.96875
Train 1600 steps, use 32.3928 s, train loss 0.08295, accuracy 0.96875
Train 1700 steps, use 31.7718 s, train loss 0.07569, accuracy 0.96484375
Train 1800 steps, use 31.9243 s, train loss 0.08607, accuracy 0.9609375
Train 1900 steps, use 32.2454 s, train loss 0.04332, accuracy 0.984375
Train 2000 steps, use 31.6231 s, train loss 0.02369, accuracy 0.9921875
Train 2100 steps, use 31.7800 s, train loss 0.03862, accuracy 0.9765625
Train 2200 steps, use 31.5431 s, train loss 0.01871, accuracy 0.9921875
Train 2300 steps, use 32.1064 s, train loss 0.03255, accuracy 0.9921875
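As a rough check on the length of this log, note that the standard MNIST training split has 60,000 images. With batch_size=256, one epoch is about 60000 / 256 ≈ 234.4 batches. Ten epochs therefore come to roughly 2,340 to 2,350 steps, depending on whether the trainer keeps the final partial batch; we have not confirmed which it does. This is consistent with the last logged line at 2,300 steps. At about 32 s per 100 steps on this CPU, the full run takes on the order of 12 to 13 minutes. The arithmetic:

```python
import math

num_images = 60_000      # standard MNIST training split
batch_size = 256
num_epoch = 10

full_batches = num_images // batch_size            # drop the partial batch
all_batches = math.ceil(num_images / batch_size)   # keep the partial batch

steps_low, steps_high = full_batches * num_epoch, all_batches * num_epoch
seconds_per_100_steps = 32.0                       # rough average from the log
minutes = steps_high / 100 * seconds_per_100_steps / 60
```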