Integrate BrainPy models into Flax (Example 1)

Integrate BrainPy models into Flax (Example 1)

Open in Colab | Open in Kaggle

In this example, we use `brainpy.neurons.LIF` as a recurrent cell inside a Flax model. `brainpy.neurons.LIF` only has recurrent state variables and has no trainable parameters.

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state

import brainpy as bp
import brainpy.math as bm
# Global BrainPy settings: training mode and dt=1. NOTE(review): this runs
# before the LIF models below are built, presumably so they pick up these
# defaults - keep the ordering.
bm.set(mode=bm.training_mode, dt=1.)

bp.__version__  # notebook-style display of the installed BrainPy version

# Number of time steps each static image is repeated for (see apply_model).
num_time = 10

# Hyper-parameters shared by every LIF population.
pars = {
  'tau': 10,
  'V_reset': 0,
  'V_rest': 0,
  'V_th': 0.1,
  'keep_size': True,
  'input_var': False,
}

# LIF neurons can be viewed as a recurrent cell without trainable parameters.
# Each population is wrapped so flax's ``nn.RNN`` can unroll it over time; the
# sizes match the feature maps the cells receive inside ``CNN``.
cell1, cell2, cell3 = (
  bp.dnn.ToFlaxRNNCell(bp.neurons.LIF(size, **pars))
  for size in ((28, 28, 32), (14, 14, 64), 256)
)
class CNN(nn.Module):
  """Convolutional network with recurrent LIF layers, unrolled over time.

  The input is a batch of image sequences, e.g. ``(batch, num_time, 28, 28, 1)``
  as built in ``create_train_state``. The output keeps the ``(batch, time)``
  axes and ends in 10 logits per step.
  """

  @nn.compact
  def __call__(self, x):
    # Two conv -> LIF -> 2x2 average-pool stages. Submodules are created in
    # exactly the same order as a hand-unrolled version, so Flax's
    # auto-generated parameter names do not change.
    for features, lif in ((32, cell1), (64, cell2)):
      x = nn.Conv(features=features, kernel_size=(3, 3))(x)
      x = nn.RNN(lif, lif.model.varshape)(x)  # unroll the recurrent LIF
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # Flatten everything after the (batch, time) axes.
    batch, steps = x.shape[0], x.shape[1]
    x = x.reshape((batch, steps, -1))
    x = nn.Dense(features=256)(x)
    x = nn.RNN(cell3, cell3.model.varshape)(x)  # unroll the recurrent LIF
    return nn.Dense(features=10)(x)
@jax.jit
def apply_model(state, images, labels):
  """Compute gradients, loss and accuracy for a single batch.

  Each image is repeated ``num_time`` times along a new axis 1 to form an input
  sequence. The model's per-step outputs are max-reduced over that time axis
  to get one logit vector per example, which feeds a softmax cross-entropy
  against 10 one-hot classes.

  Returns:
    ``(grads, loss, accuracy)``; ``grads`` has the structure of
    ``state.params``.
  """
  # (batch, ...) -> (batch, num_time, ...): the same frame at every step.
  sequence = jnp.repeat(images[:, None], num_time, axis=1)
  targets = jax.nn.one_hot(labels, 10)

  def loss_fn(params):
    outputs = state.apply_fn({'params': params}, sequence)
    logits = jnp.max(outputs, axis=1)  # strongest response over time
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return jnp.mean(per_example), logits

  value_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = value_and_grad(state.params)
  correct = jnp.argmax(logits, -1) == labels
  return grads, loss, jnp.mean(correct)
@jax.jit
def update_model(state, grads):
  """Return a new train state with one optimizer step applied using ``grads``."""
  new_state = state.apply_gradients(grads=grads)
  return new_state
def train_epoch(state, train_ds, batch_size, rng):
  """Train for a single epoch over shuffled, full-size minibatches.

  Args:
    state: current train state, passed to ``apply_model``/``update_model``.
    train_ds: dict with ``'image'`` and ``'label'`` arrays indexed along axis 0.
    batch_size: number of examples per step. A trailing partial batch is
      dropped.
    rng: PRNG key used to shuffle the example order.

  Returns:
    ``(state, train_loss, train_accuracy)``. The metrics are the means of the
    per-step values; all steps use equal-size batches.

  Raises:
    ValueError: if ``batch_size`` is not positive, or if ``train_ds`` has fewer
      than ``batch_size`` examples. In that case no step could be taken and
      the metrics would otherwise be NaN.
  """
  if batch_size < 1:
    raise ValueError(f'batch_size must be positive, got {batch_size}.')
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  if steps_per_epoch == 0:
    raise ValueError(
      f'train_ds has {train_ds_size} examples, fewer than '
      f'batch_size={batch_size}; no training step can be taken.')

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []
  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
def get_datasets():
  """Load the MNIST train and test splits fully into memory.

  Images are converted to JAX arrays and divided by 255 (presumably mapping
  uint8 pixels into [0, 1]). Labels are left as returned by ``tfds.as_numpy``.

  Returns:
    ``(train_ds, test_ds)``, each a dict holding ``'image'`` and ``'label'``.
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()
  # batch_size=-1 returns each split as one in-memory batch.
  train_ds, test_ds = (
    tfds.as_numpy(builder.as_dataset(split=split, batch_size=-1))
    for split in ('train', 'test')
  )
  for split_ds in (train_ds, test_ds):
    split_ds['image'] = jnp.asarray(split_ds['image']) / 255.
  return train_ds, test_ds
def create_train_state(rng, config):
  """Build the initial ``TrainState`` for ``CNN``.

  Parameters are initialised by tracing the model on an all-ones input of
  shape ``(1, num_time, 28, 28, 1)``. The optimizer is ``optax.sgd`` with
  ``config.learning_rate`` and ``config.momentum``.
  """
  model = CNN()
  optimizer = optax.sgd(config.learning_rate, config.momentum)
  dummy_input = jnp.ones([1, num_time, 28, 28, 1])
  variables = model.init(rng, dummy_input)
  return train_state.TrainState.create(
    apply_fn=model.apply, params=variables['params'], tx=optimizer)
def _evaluate(state, test_ds, batch_size):
  """Return the dataset-level ``(loss, accuracy)`` of ``state`` on ``test_ds``.

  ``apply_model`` reports per-batch means, so the batches are averaged with
  weights equal to their sizes. This keeps the result exact when the last
  batch is smaller than ``batch_size``.
  """
  losses, accuracies, sizes = [], [], []
  for start in range(0, test_ds['image'].shape[0], batch_size):
    images = test_ds['image'][start:start + batch_size]
    labels = test_ds['label'][start:start + batch_size]
    # apply_model also computes gradients; they are discarded here.
    _, loss, accuracy = apply_model(state, images, labels)
    losses.append(loss)
    accuracies.append(accuracy)
    sizes.append(images.shape[0])
  return (np.average(losses, weights=sizes),
          np.average(accuracies, weights=sizes))


def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  After every epoch, the model is evaluated on the whole test split. Train and
  test loss/accuracy are printed and written as TensorBoard scalars.

  Args:
    config: Hyperparameter configuration for training and evaluation (reads
      ``learning_rate``, ``momentum``, ``batch_size`` and ``num_epochs``).
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
  train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  rng, init_rng = jax.random.split(rng)
  state = create_train_state(init_rng, config)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(state,
                                                    train_ds,
                                                    config.batch_size,
                                                    input_rng)
    # Average over the WHOLE test set. The previous code averaged only the
    # last batch's loss and weighted batches of unequal size equally.
    test_loss, test_accuracy = _evaluate(state, test_ds, config.batch_size)

    print(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
    )

    summary_writer.scalar('train_loss', train_loss, epoch)
    summary_writer.scalar('train_accuracy', train_accuracy, epoch)
    summary_writer.scalar('test_loss', test_loss, epoch)
    summary_writer.scalar('test_accuracy', test_accuracy, epoch)

  summary_writer.flush()
  return state
# Training hyper-parameters (SGD with momentum, see create_train_state).
config = ml_collections.ConfigDict({
  'learning_rate': 0.1,
  'momentum': 0.9,
  'batch_size': 128,
  'num_epochs': 10,
})

# Train and evaluate, writing TensorBoard summaries to ./ckpt.
train_and_evaluate(config, './ckpt')