Array operations with ein_rearrange, ein_reduce, and ein_repeat


We don’t write

y = x.transpose(0, 2, 3, 1)

We write comprehensible code

y = bm.ein_rearrange(x, 'b c h w -> b h w c')
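The correspondence between the two lines can be checked with numpy alone. The sketch below looks up each output axis name in the input pattern, which yields the permutation that `transpose` needs. The random array `x` is a stand-in for real data. The name lookup only illustrates the idea and is not how brainpy implements `ein_rearrange`.

```python
import numpy as np

x = np.random.default_rng(0).random((2, 3, 4, 5))  # stand-in batch: b=2, c=3, h=4, w=5

# Position of each output axis name inside the input pattern 'b c h w'
src = 'b c h w'.split()
dst = 'b h w c'.split()
perm = tuple(src.index(name) for name in dst)
print(perm)  # (0, 2, 3, 1)

y = x.transpose(perm)
print(y.shape)  # (2, 4, 5, 3)
```

The pattern states the permutation by name, while `transpose(0, 2, 3, 1)` makes the reader reconstruct it from positions.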

What’s in this tutorial?

  • fundamentals: reordering, composition and decomposition of axes

  • operations: ein_rearrange, ein_reduce, ein_repeat

  • how much you can do with a single operation!

Preparations

# Examples are given for numpy arrays;
# the ein_* operations themselves come from brainpy.math
import numpy

import brainpy.math as bm

Load a batch of images to play with

Please download the data (the test_images.npy file loaded below).

ims = numpy.load('./test_images.npy', allow_pickle=False)
# There are 6 images of shape 96x96 with 3 color channels packed into a tensor
print(ims.shape, ims.dtype)
(6, 96, 96, 3) float64
# shape of the first image (the whole 4d tensor can't be rendered as a single image)
ims[0].shape
(96, 96, 3)
# second image in a batch
ims[1].shape
(96, 96, 3)
# rearrange, as its name suggests, rearranges elements;
# below we swap height and width,
# in other words, we transpose the first two axes (dimensions)
bm.ein_rearrange(ims[0], 'h w c -> w h c').shape
(96, 96, 3)

Composition of axes

Transposition is very common and useful, but let’s move on to the other capabilities of the einops-style operations.

# einops allows seamlessly composing batch and height into a new height dimension
# collapsing to a 3d tensor lets the whole batch be rendered as one image!
bm.ein_rearrange(ims, 'b h w c -> (b h) w c').shape
(576, 96, 3)
# or compose a new dimension of batch and width
bm.ein_rearrange(ims, 'b h w c -> h (b w) c').shape
(96, 576, 3)
# resulting dimensions are computed very simply
# length of newly composed axis is a product of components
# [6, 96, 96, 3] -> [96, (6 * 96), 3]
bm.ein_rearrange(ims, 'b h w c -> h (b w) c').shape
(96, 576, 3)
# we can compose more than two axes.
# let's flatten the 4d array into 1d; the resulting array has as many elements as the original
bm.ein_rearrange(ims, 'b h w c -> (b h w c)').shape
(165888,)
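To show what composition means in terms of memory layout, here is a numpy-only sketch. It uses a small random stand-in batch and expresses each pattern above with `reshape` and `transpose`. This is one way to reproduce the results and is not necessarily how brainpy computes them. Axes that are already adjacent can be merged with a plain reshape. Otherwise they must first be moved next to each other.

```python
import numpy as np

ims = np.random.default_rng(0).random((6, 8, 10, 3))  # stand-in for the image batch
b, h, w, c = ims.shape

# 'b h w c -> (b h) w c': b and h are already adjacent, so a reshape suffices
stacked = ims.reshape(b * h, w, c)

# 'b h w c -> h (b w) c': move b next to w first, then merge them
side_by_side = ims.transpose(1, 0, 2, 3).reshape(h, b * w, c)

# 'b h w c -> (b h w c)': full flatten, the element count is preserved
flat = ims.reshape(-1)

print(stacked.shape, side_by_side.shape, flat.shape)
# (48, 10, 3) (8, 60, 3) (1440,)
```

In `side_by_side`, columns `w:2*w` hold exactly the second image, which is why the composed width places the images next to each other.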

Decomposition of axis

# decomposition is the inverse process: represent an axis as a combination of new axes
# several decompositions are possible, so b1=2 says to decompose 6 into b1=2 and b2=3
bm.ein_rearrange(ims, '(b1 b2) h w c -> b1 b2 h w c ', b1=2).shape
(2, 3, 96, 96, 3)
# finally, combine composition and decomposition:
bm.ein_rearrange(ims, '(b1 b2) h w c -> (b1 h) (b2 w) c ', b1=2).shape
(192, 288, 3)
# slightly different composition: b1 is merged with width, b2 with height
# ... so letters are ordered by w then by h
bm.ein_rearrange(ims, '(b1 b2) h w c -> (b2 h) (b1 w) c ', b1=2).shape
(288, 192, 3)
# move part of the width dimension to height.
# we could call this width-to-height, as the image width shrinks by 2 and the height doubles,
# but all pixels stay the same!
# Can you write the reverse operation (height-to-width)?
bm.ein_rearrange(ims, 'b h (w w2) c -> (h w2) (b w) c', w2=2).shape
(192, 288, 3)
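Decomposition followed by composition can be sketched the same way in numpy. Splitting an axis is a reshape, and the unknown factor is inferred from the axis length (here b2 = 6 // 2 = 3). The batch below is a small random stand-in, and the reshape/transpose sequence is an illustrative equivalent, not brainpy's internals.

```python
import numpy as np

ims = np.random.default_rng(0).random((6, 8, 10, 3))  # stand-in batch of 6 images
b1 = 2
b2 = ims.shape[0] // b1  # inferred, like einops does: 3
_, h, w, c = ims.shape

# '(b1 b2) h w c -> b1 b2 h w c': decomposition is just a reshape
grid = ims.reshape(b1, b2, h, w, c)

# '(b1 b2) h w c -> (b1 h) (b2 w) c': reorder to b1 h b2 w c, then merge pairs
tiled = grid.transpose(0, 2, 1, 3, 4).reshape(b1 * h, b2 * w, c)

print(grid.shape, tiled.shape)  # (2, 3, 8, 10, 3) (16, 30, 3)
```

Image number `i * b2 + j` ends up in tile row `i`, tile column `j` of the 2x3 grid.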

Order of axes matters

# compare with the next example
bm.ein_rearrange(ims, 'b h w c -> h (b w) c').shape
(96, 576, 3)
# order of axes in composition is different
# the rule is the same as for digits in a number: the leftmost digit is the most significant,
# while neighboring elements differ in the rightmost axis.

# you can also think of this as a lexicographic sort
bm.ein_rearrange(ims, 'b h w c -> h (w b) c').shape
(96, 576, 3)
# what if b1 and b2 are reordered before composing to width?
bm.ein_rearrange(ims, '(b1 b2) h w c -> h (b1 b2 w) c ', b1=2).shape 
(96, 576, 3)
bm.ein_rearrange(ims, '(b1 b2) h w c -> h (b2 b1 w) c ', b1=2).shape 
(96, 576, 3)
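The digit analogy is easiest to see on a tiny labeled array. In this numpy-only sketch, every element stores its own batch index, so the flattened result shows where each batch lands. The array sizes are arbitrary illustration values.

```python
import numpy as np

B, W = 3, 4
# Label each element by its batch index so we can see where it lands
x = np.repeat(np.arange(B)[:, None], W, axis=1)  # shape (B, W), x[b, w] = b

bw = x.reshape(B * W)    # '(b w)': b is the most significant "digit"
wb = x.T.reshape(W * B)  # '(w b)': w is the most significant "digit"

print(bw)  # [0 0 0 0 1 1 1 1 2 2 2 2]
print(wb)  # [0 1 2 0 1 2 0 1 2 0 1 2]
```

With `(b w)`, whole images are placed side by side. With `(w b)`, neighboring columns cycle through the images, so the images are interleaved column by column.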

Meet ein_reduce

In einops-land you don’t need to guess what happened

x.mean(-1)

Because you write what the operation does

bm.ein_reduce(x, 'b h w c -> b h w', 'mean')

If an axis is not present in the output, you guessed it: that axis was reduced.
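The rule "absent from the output means reduced" can be turned into axis indices mechanically. The sketch below uses a hypothetical helper, `reduced_axes`, that handles only simple patterns without parentheses or `()`. The real pattern parser in brainpy/einops is far more general. The helper exists here only to make the rule concrete.

```python
import numpy as np

def reduced_axes(pattern):
    """Hypothetical helper: for a pattern without parentheses, return the
    input axis indices whose names do not appear in the output."""
    lhs, rhs = (side.split() for side in pattern.split('->'))
    return tuple(i for i, name in enumerate(lhs) if name not in rhs)

x = np.random.default_rng(0).random((2, 4, 4, 3))  # stand-in 'b h w c' array

axes = reduced_axes('b h w c -> b h w')
print(axes)     # (3,)
y = x.mean(axis=axes)
print(y.shape)  # (2, 4, 4)
```

For 'b h w c -> b h w' this recovers `x.mean(-1)`. The pattern, however, also tells the reader which axis the -1 referred to.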

# average over batch
bm.ein_reduce(ims, 'b h w c -> h w c', 'mean').shape
(96, 96, 3)
# the previous is identical to the familiar numpy call,
# but is so much more readable
ims.mean(axis=0).shape
(96, 96, 3)
# Example of reducing several axes
# besides mean, there are also min, max, sum, prod
bm.ein_reduce(ims, 'b h w c -> h w', 'min').shape
(96, 96)
# this is mean-pooling with a 2x2 kernel
# the image is split into 2x2 patches, and each patch is averaged
bm.ein_reduce(ims, 'b (h h2) (w w2) c -> h (b w) c', 'mean', h2=2, w2=2).shape
(48, 288, 3)
# max-pooling is similar
# result is not as smooth as for mean-pooling
bm.ein_reduce(ims, 'b (h h2) (w w2) c -> h (b w) c', 'max', h2=2, w2=2).shape
(48, 288, 3)
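Pooling with these patterns amounts to "decompose, then reduce the within-patch axes". Here is a numpy-only sketch on a small random stand-in batch. The reshape/reduce/transpose sequence is an illustrative equivalent of the two patterns above.

```python
import numpy as np

ims = np.random.default_rng(0).random((6, 8, 10, 3))  # stand-in batch
b, H, W, c = ims.shape
h2 = w2 = 2

# 'b (h h2) (w w2) c': split each spatial axis into (coarse, within-patch)
patches = ims.reshape(b, H // h2, h2, W // w2, w2, c)

# h2 and w2 are absent from the output, so they are reduced
mean_pooled = patches.mean(axis=(2, 4))  # shape (b, h, w, c)
max_pooled = patches.max(axis=(2, 4))

# '-> h (b w) c': place the pooled images side by side
out = mean_pooled.transpose(1, 0, 2, 3).reshape(H // h2, b * (W // w2), c)
print(out.shape)  # (4, 30, 3)
```

Switching between mean- and max-pooling only changes the reduction, which is exactly the string argument in the `ein_reduce` calls.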
# yet another example. Can you compute the result shape?
bm.ein_reduce(ims, '(b1 b2) h w c -> (b2 h) (b1 w)', 'mean', b1=2).shape
(288, 192)

Stack and concatenate

# rearrange can also take care of lists of arrays with the same shape
x = list(ims)
print(type(x), 'with', len(x), 'tensors of shape', x[0].shape)
<class 'list'> with 6 tensors of shape (96, 96, 3)
# that's how we can stack inputs
# the "list axis" becomes first ("b" in this case), and we leave it there
res = bm.ein_rearrange(x, 'b h w c -> b h w c')

[r.shape for r in res]
[(96, 96, 3), (96, 96, 3), (96, 96, 3), (96, 96, 3), (96, 96, 3), (96, 96, 3)]
# but the new axis can appear in another place:
bm.ein_rearrange(x, 'b h w c -> h w c b').shape
(96, 96, 3, 6)
# that's equivalent to numpy stacking, but written more explicitly
# (the exact comparison prints False, most likely because brainpy.math
# computes in float32 by default while ims is float64)
numpy.array_equal(bm.ein_rearrange(x, 'b h w c -> h w c b'), numpy.stack(x, axis=3))
False
# ... or we can concatenate along axes
bm.ein_rearrange(x, 'b h w c -> h (b w) c').shape
(96, 576, 3)
# which is equivalent to concatenation (again up to float32 precision)
numpy.array_equal(bm.ein_rearrange(x, 'b h w c -> h (b w) c'), numpy.concatenate(x, axis=1))
False
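Here is a numpy-only sketch of both equivalences, using a small random float64 stand-in batch. The float32 cast at the end is an assumption. It mimics the precision gap that plausibly explains the `False` outputs of `numpy.array_equal`. An exact comparison fails after rounding to float32, while `numpy.allclose` succeeds.

```python
import numpy as np

ims = np.random.default_rng(0).random((6, 8, 10, 3))  # float64 stand-in batch
x = list(ims)                                          # 6 arrays of shape (8, 10, 3)

# 'b h w c -> h w c b' corresponds to stacking along a new last axis
stacked = np.stack(x, axis=3)
assert np.array_equal(stacked, np.array(x).transpose(1, 2, 3, 0))

# 'b h w c -> h (b w) c' corresponds to concatenating along width
concatenated = np.concatenate(x, axis=1)
assert np.array_equal(concatenated, np.array(x).transpose(1, 0, 2, 3).reshape(8, 60, 3))

# Rounding to float32 (standing in for a float32 result) breaks exact equality
as_f32 = concatenated.astype(np.float32)
print(np.array_equal(as_f32, concatenated))          # False
print(np.allclose(as_f32, concatenated, rtol=1e-6))  # True
```

When comparing results computed at different precisions, a tolerance-based check such as `allclose` is the fairer test.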

Addition or removal of axes

You can write 1 to create a new axis of length 1. Similarly, you can remove such an axis.

There is also a synonym, (), that you can use: it is a composition of zero axes, and it also has unit length.

x = bm.ein_rearrange(ims, 'b h w c -> b 1 h w 1 c') # functionality of numpy.expand_dims
print(x.shape)
print(bm.ein_rearrange(x, 'b 1 h w 1 c -> b h w c').shape) # functionality of numpy.squeeze
(6, 1, 96, 96, 1, 3)
(6, 96, 96, 3)
# compute the max in each image individually, then show the difference
x = bm.ein_reduce(ims, 'b h w c -> b () () c', 'max') - ims
bm.ein_rearrange(x, 'b h w c -> h (b w) c').shape
(96, 576, 3)
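The unit-axis patterns map onto familiar numpy calls. This numpy-only sketch uses a small random stand-in batch. It shows `expand_dims`/`squeeze` for the `1` axes, and `keepdims=True` for the `()` axes that keep the reduced result broadcastable against the original batch.

```python
import numpy as np

ims = np.random.default_rng(0).random((6, 8, 10, 3))  # stand-in batch

# 'b h w c -> b 1 h w 1 c' ~ numpy.expand_dims at output positions 1 and 4
x = np.expand_dims(ims, axis=(1, 4))
print(x.shape)  # (6, 1, 8, 10, 1, 3)

# 'b 1 h w 1 c -> b h w c' ~ numpy.squeeze of those unit axes
print(np.squeeze(x, axis=(1, 4)).shape)  # (6, 8, 10, 3)

# 'b h w c -> b () () c' with 'max' ~ max over h and w, keeping unit axes,
# so the result broadcasts against the (6, 8, 10, 3) batch
per_image_max = ims.max(axis=(1, 2), keepdims=True)
diff = per_image_max - ims
print(per_image_max.shape, diff.shape)  # (6, 1, 1, 3) (6, 8, 10, 3)
```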

Repeating elements

The third operation we introduce is ein_repeat.

# repeat along a new axis. New axis can be placed anywhere
bm.ein_repeat(ims[0], 'h w c -> h new_axis w c', new_axis=5).shape
(96, 5, 96, 3)
# shortcut
bm.ein_repeat(ims[0], 'h w c -> h 5 w c').shape
(96, 5, 96, 3)
# repeat along w (existing axis)
bm.ein_repeat(ims[0], 'h w c -> h (repeat w) c', repeat=3).shape
(96, 288, 3)
# repeat along two existing axes
bm.ein_repeat(ims[0], 'h w c -> (2 h) (2 w) c').shape
(192, 192, 3)
# order of axes matters as usual: you can repeat each element (pixel) 3 times
# by changing the order in parentheses
bm.ein_repeat(ims[0], 'h w c -> h (w repeat) c', repeat=3).shape
(96, 288, 3)
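The difference between `(repeat w)` and `(w repeat)` is the difference between `numpy.tile` and `numpy.repeat`. This numpy-only sketch shows it on a tiny integer "image" (h=2, w=3, c=1) chosen so the order is easy to read.

```python
import numpy as np

img = np.arange(2 * 3).reshape(2, 3, 1)  # tiny h=2, w=3, c=1 "image"

# 'h w c -> h new_axis w c' with new_axis=5: insert an axis, then repeat along it
new_axis = np.repeat(img[:, None], 5, axis=1)
print(new_axis.shape)  # (2, 5, 3, 1)

# 'h w c -> h (repeat w) c': repeat is the outer factor, so the whole row is tiled
tiled = np.tile(img, (1, 3, 1))

# 'h w c -> h (w repeat) c': repeat is the inner factor, so each pixel is repeated
repeated = np.repeat(img, 3, axis=1)

print(tiled[0, :, 0])     # [0 1 2 0 1 2 0 1 2]
print(repeated[0, :, 0])  # [0 0 0 1 1 1 2 2 2]
```

This is the same "most significant digit" rule as for composition: the factor written first changes slowest.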

Note: the repeat operation covers the functionality of numpy.repeat and numpy.tile, and more.

Reduce ⇆ repeat

reduce and repeat are opposites of each other: the first reduces the number of elements, the second increases it.

In the following example, each image is repeated first, then we reduce over the new axis to get back the original tensor. Notice that the two operation patterns are the “reverse” of each other.

repeated = bm.ein_repeat(ims, 'b h w c -> b h new_axis w c', new_axis=2)
reduced = bm.ein_reduce(repeated, 'b h new_axis w c -> b h w c', 'min')


assert bm.allclose(ims, reduced)
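The round trip can be reproduced with numpy alone on a small random stand-in batch. The copies along the new axis are identical, so reducing with min, max, or mean all recover the original. That is why the choice of 'min' above is arbitrary.

```python
import numpy as np

ims = np.random.default_rng(0).random((6, 8, 10, 3))  # stand-in batch

# 'b h w c -> b h new_axis w c' with new_axis=2: two identical copies
repeated = np.repeat(ims[:, :, None], 2, axis=2)
print(repeated.shape)  # (6, 8, 2, 10, 3)

# 'b h new_axis w c -> b h w c': reduce the new axis away again
for reduce_fn in (np.min, np.max, np.mean):
    reduced = reduce_fn(repeated, axis=2)
    assert np.allclose(ims, reduced)
```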

Fancy examples in random order

(a.k.a. mad designer gallery)

# interweaving pixels of different pictures
# all letters are observable
bm.ein_rearrange(ims, '(b1 b2) h w c -> (h b1) (w b2) c ', b1=2).shape
(192, 288, 3)
# interweaving along the vertical axis for pairs of images
bm.ein_rearrange(ims, '(b1 b2) h w c -> (h b1) (b2 w) c', b1=2).shape
(192, 288, 3)
# interweaving lines for pairs of images
# exercise: achieve the same result without einops in your favourite framework
bm.ein_reduce(ims, '(b1 b2) h w c -> h (b2 w) c', 'max', b1=2).shape
(96, 288, 3)
# color can also be composed into a dimension
# ... while the image is downsampled
bm.ein_reduce(ims, 'b (h 2) (w 2) c -> (c h) (b w)', 'mean').shape
(144, 288)
# disproportionate resize
bm.ein_reduce(ims, 'b (h 4) (w 3) c -> (h) (b w)', 'mean').shape
(24, 192)
# split each image into two halves and compute the mean of the two
bm.ein_reduce(ims, 'b (h1 h2) w c -> h2 (b w)', 'mean', h1=2).shape
(48, 576)
# split in small patches and transpose each patch
bm.ein_rearrange(ims, 'b (h1 h2) (w1 w2) c -> (h1 w2) (b w1 h2) c', h2=8, w2=8).shape
(96, 576, 3)
# stop me someone!
bm.ein_rearrange(ims, 'b (h1 h2 h3) (w1 w2 w3) c -> (h1 w2 h3) (b w1 h2 w3) c', h2=2, w2=2, w3=2, h3=2).shape
(96, 576, 3)
bm.ein_rearrange(ims, '(b1 b2) (h1 h2) (w1 w2) c -> (h1 b1 h2) (w1 b2 w2) c', h1=3, w1=3, b2=3).shape
(192, 288, 3)
# patterns can be arbitrarily complicated
bm.ein_reduce(ims, '(b1 b2) (h1 h2 h3) (w1 w2 w3) c -> (h1 w1 h3) (b1 w2 h2 w3 b2) c', 'mean', 
       h2=2, w1=2, w3=2, h3=2, b2=2).shape
(96, 576, 3)
# subtract the background in each image individually and normalize
# pay attention to (): this is a composition of 0 axes, a dummy axis with 1 element.
im2 = bm.ein_reduce(ims, 'b h w c -> b () () c', 'max') - ims
im2 /= bm.ein_reduce(im2, 'b h w c -> b () () c', 'max')
bm.ein_rearrange(im2, 'b h w c -> h (b w) c').shape
(96, 576, 3)
# pixelate: first downscale by averaging, then upscale back using the same pattern
averaged = bm.ein_reduce(ims, 'b (h h2) (w w2) c -> b h w c', 'mean', h2=6, w2=8)
bm.ein_repeat(averaged, 'b h w c -> (h h2) (b w w2) c', h2=6, w2=8).shape
(96, 576, 3)
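Pixelation combines the reduce and repeat ideas. Here is a numpy-only sketch with the same h2=6, w2=8 factors on a random 96x96 stand-in batch. Because h2 and w2 are the inner factors of the output pattern, each averaged value becomes an h2 x w2 block, which matches repeating along each spatial axis with `numpy.repeat`.

```python
import numpy as np

ims = np.random.default_rng(0).random((6, 96, 96, 3))  # stand-in batch
b, H, W, c = ims.shape
h2, w2 = 6, 8

# Downscale: 'b (h h2) (w w2) c -> b h w c' with 'mean'
averaged = ims.reshape(b, H // h2, h2, W // w2, w2, c).mean(axis=(2, 4))

# Upscale: 'b h w c -> (h h2) (b w w2) c'; h2 and w2 are inner factors,
# so every averaged pixel becomes an h2 x w2 block
blocks = np.repeat(np.repeat(averaged, h2, axis=1), w2, axis=2)  # (b, H, W, c)
pixelated = blocks.transpose(1, 0, 2, 3).reshape(H, b * W, c)
print(averaged.shape, pixelated.shape)  # (6, 16, 12, 3) (96, 576, 3)
```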
bm.ein_rearrange(ims, 'b h w c -> w (b h) c').shape
(96, 576, 3)
# let's bring the color dimension into the horizontal axis
# at the same time, the image is downsampled: h2=3 is averaged out,
# while w2=3 is moved from the width into the vertical axis
bm.ein_reduce(ims, 'b (h h2) (w w2) c -> (h w2) (b w c)', 'mean', h2=3, w2=3).shape
(96, 576)

Summary

  • rearrange doesn’t change the number of elements and covers different numpy functions (like transpose, reshape, stack, concatenate, squeeze and expand_dims)

  • reduce combines the same reordering syntax with reductions (mean, min, max, sum, prod, and any others)

  • repeat additionally covers repeating and tiling

  • composition and decomposition of axes are a cornerstone; they can and should be used together