Use Flax modules as a part of the BrainPy program

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
from functools import partial
from flax import linen as nn
bm.set(mode=bm.training_mode, dt=1.)
bp.__version__
'2.4.1'

In this example, we use Flax, a library for building deep neural networks, to define a convolutional neural network (CNN). Then, we integrate this CNN model into an RNN model defined with BrainPy's syntax.

Here, we first use Flax to define a CNN.

class CNN(nn.Module):
  """A CNN model implemented by using Flax."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    return x
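
Before wrapping it, it may help to see that this is a plain Flax module. The following sketch (not part of the original tutorial) initializes and applies it directly with Flax's own init/apply API; the input shape (batch, 4, 28, 1) is assumed from the example input used for the BrainPy wrapper below.

import jax
import jax.numpy as jnp

cnn = CNN()
rng = jax.random.PRNGKey(0)
variables = cnn.init(rng, jnp.ones([1, 4, 28, 1]))        # create the parameters
features = cnn.apply(variables, jnp.ones([8, 4, 28, 1]))  # expected shape: (8, 256)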

Then, we define an RNN model using the BrainPy interface. Note that the Flax module is used here as a module operating on a single time step.

class Network(bp.DynamicalSystemNS):
  def __init__(self):
    super(Network, self).__init__()
    self.cnn = bp.dnn.FromFlax(
      CNN(), # the model
      bm.ones([1, 4, 28, 1])  # an example of the input used to initialize the model parameters
    )
    self.rnn = bp.dyn.GRUCell(256, 100)
    self.linear = bp.dnn.Dense(100, 10)

  def update(self, x):
    x = self.cnn(x)
    x = self.rnn(x)
    x = self.linear(x)
    return x
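
As a quick sanity check (hypothetical, not part of the original tutorial), the wrapped CNN can be called on a single time slice like any other BrainPy layer; the input shape (batch, 4, 28, 1) is assumed from the example input passed to FromFlax above.

net = Network()
x_t = bm.ones([8, 4, 28, 1])   # one time slice: 8 samples, 4 image rows each
feat = net.cnn(x_t)            # the FromFlax-wrapped CNN behaves like a BrainPy layer
print(feat.shape)              # expected: (8, 256), matching the GRU input size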

Next, we create the network and the optimizer. The loss function and the BP trainer are defined below.

net = Network()
opt = bp.optim.Momentum(0.1)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

We load the MNIST dataset and reshape each 28 × 28 image into a sequence of 7 time steps, each containing 4 rows of pixels.

data = bd.vision.MNIST(r'D:\data', download=True)
data.data = data.data.reshape(-1, 7, 4, 28, 1) / 255


def get_data(batch_size):
  # shuffle images and labels with the same key so they stay aligned
  key = bm.random.split_key()
  data.data = bm.random.permutation(data.data, key=key)
  data.targets = bm.random.permutation(data.targets, key=key)

  for i in range(0, len(data), batch_size):
    yield data.data[i: i + batch_size], data.targets[i: i + batch_size]
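
A quick shape check (hypothetical, not part of the original tutorial): each yielded batch is a sequence of 7 time steps, each carrying 4 rows of a 28 × 28 image.

xs, ys = next(get_data(batch_size=256))
print(xs.shape)   # expected: (256, 7, 4, 28, 1)
print(ys.shape)   # expected: (256,)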

We also define the loss function. The trainer runs the network over all time steps, so the predictions carry a time axis; taking the maximum over that axis gives the final classification logits.

def loss_func(predictions, targets):
  # reduce over the time axis to get one set of logits per sample
  logits = bm.max(predictions, axis=1)
  loss = bp.losses.cross_entropy_loss(logits, targets)
  accuracy = bm.mean(bm.argmax(logits, -1) == targets)
  return loss, {'accuracy': accuracy}

Finally, we train the model using the BPTT.fit() function.

trainer = bp.BPTT(net, loss_fun=loss_func, optimizer=opt, loss_has_aux=True)
trainer.fit(partial(get_data, batch_size=256), num_epoch=10)
Train 0 epoch, use 104.2070 s, loss 1.0793957710266113, accuracy 0.616583526134491
Train 1 epoch, use 85.4961 s, loss 0.4177210330963135, accuracy 0.8495622277259827
Train 2 epoch, use 85.1781 s, loss 0.27014848589897156, accuracy 0.9093307256698608
Train 3 epoch, use 85.4031 s, loss 0.23874548077583313, accuracy 0.9184618592262268
Train 4 epoch, use 86.0905 s, loss 0.21281874179840088, accuracy 0.925542950630188
Train 5 epoch, use 85.5581 s, loss 0.19409772753715515, accuracy 0.9322085380554199
Train 6 epoch, use 85.9805 s, loss 0.18303607404232025, accuracy 0.9356383085250854
Train 7 epoch, use 85.0740 s, loss 0.16687186062335968, accuracy 0.9404421448707581
Train 8 epoch, use 85.7086 s, loss 0.1607382893562317, accuracy 0.9421210289001465
Train 9 epoch, use 87.4538 s, loss 0.15550467371940613, accuracy 0.9443760514259338