...

Flax: Overview of Convolution and Sequence Models

Let’s look at the structures for building CNNs and RNNs.

Convolutions

The importance of convolutional neural networks (CNNs) in different applications of computer vision, and even NLP, is well-established.

Since we have already covered the mechanics of convolution in earlier chapters, it should be sufficient to show the respective Flax syntax. Remember that XLA only provides the main convolution primitives, while Flax provides the high-level wrappers. Flax's wrappers, demonstrated in the sketch after this list, are:

  • Conv, used for traditional convolution.
  • ConvTranspose, used for transposed convolution.
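
Below is a minimal sketch of the Conv wrapper in use; the feature count, kernel size, and input shape are illustrative choices, not values from the text.

import jax
import jax.numpy as jnp
from flax import linen as nn

# A single convolution layer: 16 output features, 3x3 kernel, SAME padding (illustrative values).
conv = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), padding="SAME")

x = jnp.ones((1, 28, 28, 1))                  # NHWC: batch, height, width, channels
params = conv.init(jax.random.PRNGKey(0), x)  # initialize the kernel and bias
y = conv.apply(params, x)
print(y.shape)                                # (1, 28, 28, 16) with SAME padding

ConvTranspose follows the same Module pattern (init, then apply).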

Pooling

A pooling layer is similar to a strided convolution in that it shrinks the output size by a factor of $s$.

The difference, however, lies in the learning. While a strided kernel’s weights are learned and updated through backpropagation, a pooling layer simply applies an average or max/min operator and doesn’t require any learning.

The output’s dimensions can be calculated using the formula:

$x = \frac{m - f}{s} + 1$

where $m$ is the input size, $f$ is the filter (window) size, and $s$ is the stride.

Flax provides support for both modes of pooling.

Max/min pooling

In max/min pooling, we slide the filter over the input and take the maximum or minimum of each $f \times f$ chunk, respectively.

For example, if we apply a max/min pooling filter of $2 \times 2$ on a $256 \times 256$ image, the resulting image will be:

$\frac{256 - 2}{2} + 1 = 128 \Rightarrow (128, 128)$

2x2 max filter operation on a 6x6 matrix
import jax.numpy as jnp
from flax import linen as nn

I = jnp.ones((1, 256, 256, 1))                            # NHWC: batch, height, width, channels
O = nn.max_pool(I, window_shape=(2, 2), strides=(2, 2))   # 2x2 max pooling with stride 2
print(O.shape)                                            # (1, 128, 128, 1), confirming the formula's result

Note: There is a subtle difference from the default lax function: here, we specified the batch size at the start of the input shape.

Average pooling

In average pooling, we take the average of each $f \times f$ chunk.

Since the formula is general for both types of pooling, we’ll get the same output dimensions in both cases.

$\frac{256 - 2}{2} + 1 = 128 \Rightarrow (128, 128)$
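
As a quick check, the same call with nn.avg_pool (mirroring the max-pooling snippet above) reproduces these dimensions:

import jax.numpy as jnp
from flax import linen as nn

I = jnp.ones((1, 256, 256, 1))                            # NHWC input
O = nn.avg_pool(I, window_shape=(2, 2), strides=(2, 2))   # 2x2 average pooling with stride 2
print(O.shape)                                            # (1, 128, 128, 1)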

Sequence models

In a number of areas, such as natural language processing, bioinformatics, speech processing, and digital signal processing (DSP), we need to employ sequence models.

LSTM

Long Short-Term Memory (LSTM) was proposed by Sepp Hochreiter and Jürgen Schmidhuber in 1997 to address the shortcomings of traditional RNNs.

The basic idea behind the LSTM is to augment the hidden state, $h_t$, with a memory cell, $c_t$.

The memory cell is controlled by three gates:

  • Input gate, $i_t$, which determines what goes into the memory cell.
  • Output gate, $o_t$
...
LSTM consists of input, output and forget gates
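
Flax exposes the LSTM as a Linen cell. Below is a minimal sketch of a single step with nn.LSTMCell, assuming a recent Flax release (the initialize_carry signature has changed across versions); the hidden size, batch size, and feature size are illustrative only.

import jax
import jax.numpy as jnp
from flax import linen as nn

hidden = 32                                    # illustrative hidden size
cell = nn.LSTMCell(features=hidden)

x = jnp.ones((4, 16))                          # (batch, input features), illustrative shapes
carry = cell.initialize_carry(jax.random.PRNGKey(0), x.shape)   # carry holds (c_t, h_t)
params = cell.init(jax.random.PRNGKey(1), carry, x)
(new_c, new_h), y = cell.apply(params, carry, x)                # one LSTM step
print(new_h.shape, y.shape)                    # (4, 32) (4, 32); y is the new hidden state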