Two-dimensional Convolutions
This lesson will introduce two-dimensional convolution in JAX.
Note: The animations used here to explain convolution mechanics are included with special thanks to Vincent Dumoulin and Francesco Visin’s “A guide to convolution arithmetic for deep learning” [arXiv:1603.07285].
We restricted the last lesson to one-dimensional convolution. In computer vision applications, however, we need to operate in more than one dimension. In this and subsequent lessons, we’ll up the ante by upgrading to two-dimensional convolutions.
We can extend the convolution to two dimensions:
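One common way to write the discrete two-dimensional convolution is sketched below. It assumes the symbols used in the rest of this lesson (I for the input image, F for the kernel, O for the output) and index names i, j, m, n chosen only for illustration, so the notation may differ from the previous lesson's formula.

```latex
O(i, j) = \sum_{m} \sum_{n} I(i - m,\, j - n)\, F(m, n)
```

The sums run over all kernel positions (m, n), so each output pixel is a weighted sum of a small neighborhood of the input.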
Since two-dimensional convolution is used frequently in computer vision applications, we’ll invest more time explaining its mechanics.
Throughout the examples, we will assume the following settings:
- The input image I (shown in blue) has the dimensions .
- The convolving kernel/filter F has the dimensions (square kernels are the usual choice).
- The output image O (shown in green) has the dimensions .
In the last lesson, we used SciPy’s convolve() as our N-dimensional convolution routine. Here, we’ll build on a lower-level foundation.
JAX and its various neural network libraries provide a number of different convolution functions. Behind all of those functions (including JAX’s SciPy-compatible ones) is the fundamental implementation, jax.lax.conv_general_dilated().
This function takes four required parameters:
- Input matrix (lhs)
- Kernel/filter matrix (rhs)
- Stride (window_strides): we’ll use (1,1) by default
- Padding (padding): we’ll use [(0,0),(0,0)] by default
Note: Usually, 2D convolution requires a 4D volume due to channels and batch size, but we’ll keep it simple here by using single 2D matrices for I, F, and O.
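To make the four parameters concrete, here is a minimal sketch of calling jax.lax.conv_general_dilated() on plain 2D matrices. The helper name conv2d and the sample values are assumptions made for illustration, not part of the lesson. Because the function expects 4D arrays, the sketch adds singleton batch and channel axes before the call and removes them afterwards.

```python
import jax.numpy as jnp
from jax import lax


def conv2d(image, kernel, stride=(1, 1), padding=((0, 0), (0, 0))):
    """Hypothetical helper: run lax.conv_general_dilated on plain 2D arrays."""
    # (H, W) -> (1, 1, H, W): batch and channel axes for the default NCHW layout.
    lhs = image[jnp.newaxis, jnp.newaxis, :, :]
    # (kH, kW) -> (1, 1, kH, kW): output/input channel axes for the default OIHW layout.
    rhs = kernel[jnp.newaxis, jnp.newaxis, :, :]
    # Note: the kernel is applied without flipping (cross-correlation),
    # as is customary in deep learning libraries.
    out = lax.conv_general_dilated(
        lhs, rhs, window_strides=stride, padding=list(padding)
    )
    # Drop the singleton batch and channel axes again.
    return out[0, 0]


# Illustrative data: a 5x5 input and a 3x3 kernel of ones.
image = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
kernel = jnp.ones((3, 3), dtype=jnp.float32)

result = conv2d(image, kernel)
print(result.shape)  # (3, 3)
```

With a kernel of ones, each output value is simply the sum of the 3x3 input patch under the kernel, which makes the result easy to check by hand.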
Types of convolution
There are a few varieties of convolution, depending on whether or not we’re using a stride or padding. We’ll quickly review them.
In default mode, we convolve the filter/kernel over the input. The resulting image inevitably shrinks in size.
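The shrinkage can be checked with a short sketch. It assumes a 7x7 input of ones and a few square kernel sizes picked only for illustration. With unit stride and no padding, each spatial dimension of the output is the input size minus the kernel size plus one.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 7x7 input with singleton batch and channel axes: (1, 1, 7, 7).
image = jnp.ones((1, 1, 7, 7), dtype=jnp.float32)

shapes = {}
for k in (3, 5, 7):
    # Square k x k kernel, also with singleton channel axes: (1, 1, k, k).
    kernel = jnp.ones((1, 1, k, k), dtype=jnp.float32)
    out = lax.conv_general_dilated(
        image, kernel, window_strides=(1, 1), padding=[(0, 0), (0, 0)]
    )
    # Keep only the spatial part of the output shape.
    shapes[k] = tuple(out.shape[2:])
    print(f"kernel {k}x{k}: input 7x7 -> output {shapes[k][0]}x{shapes[k][1]}")
```

Larger kernels leave fewer valid positions to slide over, so the output gets smaller as the kernel grows.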