Interoperation with other JAX frameworks#


BrainPy is designed to interoperate easily with other JAX frameworks.

import brainpy as bp
import brainpy_datasets as bd
import brainpy.math as bm
import jax.numpy as jnp
import numpy as np

1. Data are exchangeable among different frameworks.#

This is possible because an Array can be directly converted to a JAX ndarray or a NumPy ndarray.

Convert an Array into a JAX ndarray.

b = bm.random.randint(10, size=5)
# Array.value is a JAX's DeviceArray
b.value
DeviceArray([9, 9, 0, 4, 7], dtype=int32)

Convert an Array into a NumPy ndarray.

# Array can be easily converted to a numpy ndarray
np.asarray(b)
array([9, 9, 0, 4, 7])

Convert a NumPy ndarray into an Array.

bm.asarray(np.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)

Convert a JAX ndarray into an Array.

bm.asarray(jnp.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)
bm.Array(jnp.arange(5))
Array([0, 1, 2, 3, 4], dtype=int32)
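
Putting these together: because .value exposes the underlying JAX ndarray, BrainPy data can flow straight into JAX transformations. Below is a minimal round-trip sketch; it assumes only the imports above plus jax itself.

import jax

a = bm.asarray(np.arange(5))   # NumPy ndarray -> BrainPy Array
j = a.value                    # BrainPy Array -> JAX ndarray (no copy)
n = np.asarray(a)              # BrainPy Array -> NumPy ndarray

# JAX transformations work directly on the underlying JAX ndarray
double = jax.jit(lambda x: x * 2)
double(j)  # -> DeviceArray([0, 2, 4, 6, 8], dtype=int32)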

2. Other JAX frameworks can be integrated into a BrainPy program.#

In this example, we use Flax, a library for building deep neural networks, to define a convolutional neural network (CNN). Then, we integrate this CNN into an RNN model defined with BrainPy's syntax.

Here, we first use Flax to define a CNN.

from flax import linen as nn

class CNN(nn.Module):
  """A CNN model implemented by using Flax."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    return x
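
Before wrapping this module with BrainPy, we can sanity-check it on a dummy input using Flax's standard init/apply API (a minimal sketch; the (1, 4, 28, 1) shape matches the dummy input passed to FromFlax below):

import jax

model = CNN()
dummy = jnp.ones([1, 4, 28, 1])                    # (batch, height, width, channels)
params = model.init(jax.random.PRNGKey(0), dummy)  # initialize the Flax parameters
out = model.apply(params, dummy)                   # run one forward pass
out.shape  # (1, 256): each sample is encoded into 256 features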

Then, we define an RNN model by using our BrainPy interface.

class Network(bp.DynamicalSystemNS):
  """A network model implemented by BrainPy"""

  def __init__(self):
    super(Network, self).__init__()
    # wrap the Flax CNN as a BrainPy layer; the second argument is a
    # dummy input used to initialize the Flax parameters
    self.cnn = bp.interop.FromFlax(CNN(), bm.ones([1, 4, 28, 1]))
    self.rnn = bp.layers.GRUCell(256, 100)  # 256 CNN features -> 100 hidden units
    self.linear = bp.layers.Dense(100, 10)  # 100 hidden units -> 10 classes

  def update(self, x):
    x = self.cnn(x)
    x = self.rnn(x)
    x = self.linear(x)
    return x

We initialize the network, optimizer, loss function, and BP trainer.

net = Network()
opt = bp.optim.Momentum(0.1)

def loss_func(predictions, targets):
  # predictions carry a time dimension; pool it with a max before the loss
  logits = bm.max(predictions, axis=1)
  loss = bp.losses.cross_entropy_loss(logits, targets)
  accuracy = bm.mean(bm.argmax(logits, -1) == targets)
  return loss, {'accuracy': accuracy}
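
As a quick sanity check of the loss signature, we can call it with random stand-in data (a sketch; it assumes predictions shaped (batch, time, classes), matching the 7-step sequences built below):

preds = bm.random.randn(8, 7, 10)        # (batch, time, classes)
tgts = bm.random.randint(10, size=(8,))  # integer class targets
loss, aux = loss_func(preds, tgts)
loss, aux['accuracy']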

trainer = bp.BPTT(net, loss_fun=loss_func, optimizer=opt, loss_has_aux=True)

We get the MNIST dataset.

train_dataset = bd.vision.MNIST(r'D:\data', download=True)
# split each 28x28 image into a sequence of 7 horizontal strips of 4 rows,
# matching the (4, 28, 1) frames the CNN expects, and scale pixels to [0, 1]
X = train_dataset.data.reshape((-1, 7, 4, 28, 1)) / 255
Y = train_dataset.targets

Finally, we train the model with the BPTT.fit() function.

trainer.fit([X, Y], batch_size=256, num_epoch=10)
Train 100 steps, use 32.5824 s, train loss 0.96465, accuracy 0.66015625
Train 200 steps, use 30.9035 s, train loss 0.38974, accuracy 0.89453125
Train 300 steps, use 33.1075 s, train loss 0.31525, accuracy 0.890625
Train 400 steps, use 31.4062 s, train loss 0.23846, accuracy 0.91015625
Train 500 steps, use 32.3371 s, train loss 0.21995, accuracy 0.9296875
Train 600 steps, use 32.5692 s, train loss 0.20885, accuracy 0.92578125
Train 700 steps, use 33.0139 s, train loss 0.24748, accuracy 0.90625
Train 800 steps, use 31.9635 s, train loss 0.14563, accuracy 0.953125
Train 900 steps, use 31.8845 s, train loss 0.17017, accuracy 0.94140625
Train 1000 steps, use 32.0537 s, train loss 0.09413, accuracy 0.95703125
Train 1100 steps, use 32.3714 s, train loss 0.06015, accuracy 0.984375
Train 1200 steps, use 31.6957 s, train loss 0.12061, accuracy 0.94921875
Train 1300 steps, use 31.8346 s, train loss 0.13908, accuracy 0.953125
Train 1400 steps, use 31.5252 s, train loss 0.10718, accuracy 0.953125
Train 1500 steps, use 31.7274 s, train loss 0.07869, accuracy 0.96875
Train 1600 steps, use 32.3928 s, train loss 0.08295, accuracy 0.96875
Train 1700 steps, use 31.7718 s, train loss 0.07569, accuracy 0.96484375
Train 1800 steps, use 31.9243 s, train loss 0.08607, accuracy 0.9609375
Train 1900 steps, use 32.2454 s, train loss 0.04332, accuracy 0.984375
Train 2000 steps, use 31.6231 s, train loss 0.02369, accuracy 0.9921875
Train 2100 steps, use 31.7800 s, train loss 0.03862, accuracy 0.9765625
Train 2200 steps, use 31.5431 s, train loss 0.01871, accuracy 0.9921875
Train 2300 steps, use 32.1064 s, train loss 0.03255, accuracy 0.9921875