Integrate BrainPy Models into Flax (Example 2)
In this example, we use brainpy.dyn.Conv2dLSTMCell as a recurrent cell in the Flax computation. Unlike brainpy.neurons.LIF, brainpy.dyn.Conv2dLSTMCell has trainable parameters.
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import brainpy as bp
import brainpy.math as bm
bm.set(mode=bm.training_mode, dt=1.)  # enable training mode globally; dt is the integration step (ms)
bp.__version__
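Since the convolutional LSTM cell owns trainable weights (unlike a LIF neuron), you can confirm this on the BrainPy side before wrapping it for Flax. Below is a minimal sketch using BrainPy's train_vars() collector; the exact variable names printed depend on the BrainPy version:

# sketch: a throwaway cell, just to list its trainable variables
lstm = bp.dyn.Conv2dLSTMCell((28, 28), in_channels=1, out_channels=32,
                             kernel_size=(3, 3))
print(list(lstm.train_vars().keys()))  # non-empty: the cell has trainable weights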
num_time = 10

# the recurrent cells with trainable parameters, converted to Flax RNN cells
cell1 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((28, 28),
                                                   in_channels=1,
                                                   out_channels=32,
                                                   kernel_size=(3, 3)))
cell2 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((14, 14),
                                                   in_channels=32,
                                                   out_channels=64,
                                                   kernel_size=(3, 3)))
class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        # x: (batch, time, height, width, channels)
        x = nn.RNN(cell1, (28, 28, 32))(x)  # use nn.RNN to unroll the transformed recurrent cell
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.RNN(cell2, (14, 14, 64))(x)  # use nn.RNN to unroll the transformed recurrent cell
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], x.shape[1], -1))  # flatten space/channels, keep (batch, time)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
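Since both nn.RNN wrappers preserve the leading (batch, time) axes, the model maps (batch, num_time, 28, 28, 1) inputs to (batch, num_time, 10) logits. A quick sanity check, sketched here by initializing the model on a dummy batch:

# sketch: confirm the time axis survives through the network
model = CNN()
dummy = jnp.ones((1, num_time, 28, 28, 1))
variables = model.init(jax.random.PRNGKey(0), dummy)
print(model.apply(variables, dummy).shape)  # expected: (1, num_time, 10)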
@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""
    # repeat each static image along a new time axis so the recurrent
    # network sees the same frame for num_time steps
    images = jnp.expand_dims(images, axis=1)
    images = jnp.tile(images, (1, num_time, 1, 1, 1))

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        logits = bm.max(logits, axis=1).value  # max readout over the time axis
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy
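The bm.max(logits, axis=1) readout reduces the per-step logits to one prediction per image: for each class, it keeps the largest logit seen across the num_time steps. On a toy array (a sketch):

# sketch: max-over-time readout on a (batch=1, time=3, classes=2) toy array
toy = bm.asarray([[[0.1, 0.9], [0.5, 0.2], [0.3, 0.4]]])
print(bm.max(toy, axis=1).value)  # [[0.5 0.9]]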
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)
def train_epoch(state, train_ds, batch_size, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []
    for perm in perms:
        batch_images = train_ds['image'][perm, ...]
        batch_labels = train_ds['label'][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)
    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)
    return state, train_loss, train_accuracy
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.asarray(train_ds['image']) / 255.
    test_ds['image'] = jnp.asarray(test_ds['image']) / 255.
    return train_ds, test_ds
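For orientation (a sketch, not part of the training loop): with batch_size=-1, TFDS returns each MNIST split as a single array, so the shapes are the standard ones:

train_ds, test_ds = get_datasets()
print(train_ds['image'].shape, train_ds['label'].shape)  # (60000, 28, 28, 1) (60000,)
print(test_ds['image'].shape, test_ds['label'].shape)    # (10000, 28, 28, 1) (10000,)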
def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, num_time, 28, 28, 1]))['params']
    tx = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> train_state.TrainState:
    """Execute model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.

    Returns:
      The train state (which includes the `.params`).
    """
    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state,
                                                        train_ds,
                                                        config.batch_size,
                                                        input_rng)
        # evaluate on the test set batch by batch
        test_losses, test_accs = [], []
        for i in range(0, test_ds['image'].shape[0], config.batch_size):
            _, test_loss, test_accuracy = apply_model(state,
                                                      test_ds['image'][i: i + config.batch_size],
                                                      test_ds['label'][i: i + config.batch_size])
            test_losses.append(test_loss)
            test_accs.append(test_accuracy)
        test_loss = np.mean(test_losses)
        test_accuracy = np.mean(test_accs)

        print(
            'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
            % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)
        )

        summary_writer.scalar('train_loss', train_loss, epoch)
        summary_writer.scalar('train_accuracy', train_accuracy, epoch)
        summary_writer.scalar('test_loss', test_loss, epoch)
        summary_writer.scalar('test_accuracy', test_accuracy, epoch)

    summary_writer.flush()
    return state
config = ml_collections.ConfigDict()
config.learning_rate = 0.1
config.momentum = 0.9
config.batch_size = 128
config.num_epochs = 10
train_and_evaluate(config, './ckpt')