Auto Vectorization
This lesson will provide an introduction to auto vectorization in JAX.
Introduction
Those who are familiar with stochastic gradient descent (SGD) will know that, in its pure form, it updates the model one sample at a time, which is computationally inefficient. Instead, we apply it to batches of samples, a technique usually known as minibatch gradient descent.
Batching is common practice throughout deep learning and applies to many operations, such as convolution and optimization.
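To make this concrete, here is a minimal sketch of one minibatch update in JAX. The names loss_fn and sgd_step, the linear model, and the learning rate are hypothetical choices for illustration, not part of this lesson:

import jax
import jax.numpy as jnp

def loss_fn(w, x_batch, y_batch):
    # Hypothetical squared-error loss for a linear model.
    return jnp.mean((x_batch @ w - y_batch) ** 2)

def sgd_step(w, x_batch, y_batch, lr=0.1):
    # The gradient is averaged over the whole minibatch
    # rather than computed one sample at a time.
    grads = jax.grad(loss_fn)(w, x_batch, y_batch)
    return w - lr * grads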
Let’s have a look at a convolution function for 1D vectors:
import jax.numpy as jnp

a = jnp.arange(5)        # input vector: [0 1 2 3 4]
b = jnp.arange(2, 5)     # filter of length 3: [2 3 4]

def Convolve(x, f):
    # Slide a window of size 3 over x and take the dot
    # product with the filter at each position.
    output = []
    for i in range(1, len(x) - 1):
        output.append(x[i-1:i+2] @ f)
    return jnp.array(output)

print(Convolve(a, b))    # [11 20 29]
Batching
So far, the above function works only on a single pair of vectors. To apply it to a batch of vectors, we can vectorize it automatically with jax.vmap.
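As a minimal sketch continuing from the snippet above (the batch xs and the name batched_convolve are made up for illustration), jax.vmap transforms Convolve into a function that maps over a batch dimension:

import jax

# A hypothetical batch: two input vectors stacked along axis 0.
xs = jnp.stack([a, a + 10])

# in_axes=(0, None): map over axis 0 of the first argument and
# broadcast the same filter to every vector in the batch.
batched_convolve = jax.vmap(Convolve, in_axes=(0, None))
print(batched_convolve(xs, b))    # result has shape (2, 3)

Here in_axes tells vmap which axis of each argument to map over; None means the filter b is shared across the whole batch rather than batched itself.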