The Machine Learning Pipeline for Image Caption Generation
Learn to create the pipeline for image caption generation.
Here, we’ll look at the image caption generation pipeline at a very high level and then discuss it piece by piece until we have the full model. The image caption generation framework consists of two main components:
A pretrained vision transformer model to produce an image representation.
A text-based decoder model that can decode the image representation to a series of token IDs. This uses a text tokenizer to convert tokens to token IDs and vice versa.
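To make the division of labor between the two components concrete, here is a minimal sketch of the decoding side under stated assumptions. `ToyTokenizer`, `greedy_caption`, and `stub_decoder_step` are hypothetical names. The "image representation" is a stand-in list of floats rather than real ViT output, and the stub decoder simply replays a memorized caption instead of running a trained model. What the sketch shows is the control flow: the decoder repeatedly produces the next token ID from the image representation and the IDs generated so far, and the tokenizer turns the final IDs back into text.

```python
from typing import Callable, Dict, List, Sequence


class ToyTokenizer:
    """Hypothetical word-level tokenizer that maps tokens to token IDs and back."""

    def __init__(self, vocab: Sequence[str]) -> None:
        self.token_to_id: Dict[str, int] = {tok: i for i, tok in enumerate(vocab)}
        self.id_to_token: Dict[int, str] = {i: tok for i, tok in enumerate(vocab)}

    def encode(self, text: str) -> List[int]:
        return [self.token_to_id[tok] for tok in text.split()]

    def decode(self, ids: Sequence[int]) -> str:
        return " ".join(self.id_to_token[i] for i in ids)


# A decoder step takes the image representation and the token IDs generated
# so far, and returns one score per vocabulary entry for the next token.
DecoderStep = Callable[[List[float], List[int]], List[float]]


def greedy_caption(image_repr: List[float], decoder_step: DecoderStep,
                   tokenizer: ToyTokenizer, start: str = "[START]",
                   end: str = "[END]", max_len: int = 20) -> str:
    ids = [tokenizer.token_to_id[start]]
    end_id = tokenizer.token_to_id[end]
    for _ in range(max_len):
        scores = decoder_step(image_repr, ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # greedy choice
        if next_id == end_id:
            break
        ids.append(next_id)
    return tokenizer.decode(ids[1:])  # drop the start token


# --- Usage with stand-ins (not a trained model) ---
vocab = ["[START]", "[END]", "a", "dog", "on", "grass"]
tokenizer = ToyTokenizer(vocab)
image_repr = [0.12, -0.53, 0.88]  # stand-in for the ViT's image representation
target_ids = tokenizer.encode("a dog on grass [END]")


def stub_decoder_step(image_repr: List[float], ids: List[int]) -> List[float]:
    # Ignores image_repr and "predicts" a memorized caption, one token per step.
    wanted = target_ids[len(ids) - 1]
    return [1.0 if i == wanted else 0.0 for i in range(len(vocab))]


print(greedy_caption(image_repr, stub_decoder_step, tokenizer))  # a dog on grass
```

Greedy selection (always taking the highest-scoring ID) is used here only because it is the simplest decoding strategy to show; a real captioning model may use other strategies.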
Although transformer models were initially developed for text-based NLP problems, they have since outgrown the text domain and are now used in other areas, such as image and audio data.
Here, we’ll be using one transformer model that can process image data and another that can process text data.
Vision transformer (ViT)
First, let’s look at the transformer that generates the encoded vector representations of images. We’ll use a pretrained vision transformer (ViT) for this. This model has been trained on the ImageNet dataset we discussed above. Let’s look at its architecture.
Originally, the ViT was proposed in the paper
The idea is to split an image into small 16 × 16 patches and treat each patch as a separate token. Each image patch is flattened into a 1D vector and linearly projected to the model's embedding dimension, and its position is encoded by a positional encoding mechanism similar to the one in the original transformer. But images are 2D structures; is 1D positional information enough, or is 2D positional information needed? The authors found that 1D positional encoding was adequate and that 2D positional encoding did not provide a significant boost. Once the image is broken into 16 × 16 patches and flattened, each image can be represented as a sequence of tokens, just like a textual input sequence.
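The patching step can be sketched in plain Python as follows. This is an illustration, not the ViT authors' code: it assumes the image is stored as nested lists indexed `image[row][col][channel]` and that its height and width are divisible by the patch size. `patchify` is a hypothetical helper. It returns the flattened patches in row-major order together with their 1D position indices; the learned linear projection and the actual positional encodings are left out.

```python
from typing import List, Tuple

Image = List[List[List[float]]]  # assumed layout: image[row][col][channel]


def patchify(image: Image, patch_size: int = 16) -> Tuple[List[List[float]], List[int]]:
    """Split an H x W x C image into flattened patches plus 1D position indices."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    patches: List[List[float]] = []
    for top in range(0, height, patch_size):       # walk the patch grid
        for left in range(0, width, patch_size):   # in row-major (raster) order
            flat = [
                value
                for row in image[top:top + patch_size]
                for pixel in row[left:left + patch_size]
                for value in pixel
            ]
            patches.append(flat)
    # 1D positions: the patch at grid cell (r, c) gets index r * (width // patch_size) + c
    positions = list(range(len(patches)))
    return patches, positions


size = 224  # an assumed input resolution for this example
image = [[[(r * size + c) / (size * size), 0.0, 1.0] for c in range(size)] for r in range(size)]
patches, positions = patchify(image)
print(len(patches), len(patches[0]), positions[:3])  # 196 768 [0, 1, 2]
```

For a 224 × 224 RGB input, this yields a 14 × 14 grid, so 196 tokens, each holding 16 × 16 × 3 = 768 values. The 2D grid coordinates are collapsed into a single index, which is exactly the 1D positional information the authors found sufficient.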
Then the model is pretrained on a large vision dataset called JFT-300M. Besides its main supervised pretraining, the paper explores an elegant way to pretrain the ViT in a self-supervised fashion using image data alone. Just as NLP problems represent a unit of text as a token, here a token is a patch of an image (i.e., a sequence of continuous values, where the values are normalized pixels). The ViT is then pretrained to predict the mean 3-bit RGB color of a given image patch. Each channel (i.e., red, green, and blue) is represented with 3 bits (each bit having a value of 0 or 1), which gives 2^3 = 8 levels per channel and 8 × 8 × 8 = 512 possible colors, or classes. In other words, for a given image, patches (treated like tokens in NLP) are masked randomly (using the same approach as BERT), and the model is asked to predict the mean 3-bit RGB color of each masked patch.
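The following sketch shows how a mean-color target for one patch could be turned into one of the 512 classes, and how patch positions could be chosen for masking. It rests on assumptions: `mean_color_class` and `choose_masked_patches` are hypothetical helpers, each channel's mean is quantized into 8 equal-width bins over [0, 1] (the paper's exact binning is not reproduced here), and the BERT-style rule for what replaces a masked patch is omitted. The `mask_ratio` of 0.5 in the usage lines is only an illustrative value.

```python
import random
from typing import List, Sequence, Tuple

Pixel = Tuple[float, float, float]  # normalized RGB values in [0, 1]


def mean_color_class(patch: Sequence[Pixel], bits: int = 3) -> int:
    """Map a patch to one of (2**bits)**3 classes via its quantized mean color."""
    levels = 2 ** bits  # 8 levels per channel for 3 bits
    class_id = 0
    for channel in range(3):
        mean = sum(pixel[channel] for pixel in patch) / len(patch)
        level = min(int(mean * levels), levels - 1)  # quantize [0, 1] -> 0..levels-1
        class_id = class_id * levels + level        # combine channels like base-8 digits
    return class_id


def choose_masked_patches(num_patches: int, mask_ratio: float, seed: int = 0) -> List[int]:
    """Randomly pick which patch positions to mask (hypothetical helper)."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(num_patches), int(num_patches * mask_ratio)))


print(mean_color_class([(0.0, 0.0, 0.0)]))                    # 0
print(mean_color_class([(1.0, 1.0, 1.0)]))                    # 511
print(mean_color_class([(0.5, 0.0, 1.0), (0.5, 0.0, 1.0)]))  # 263
masked = choose_masked_patches(num_patches=196, mask_ratio=0.5)
```

Treating the three quantized channel levels as digits of a base-8 number is what makes the class IDs cover exactly 0 to 511, matching the 512 classes described above.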
After pretraining, the model can be fine-tuned for a task-specific problem by adding a classification or regression head on top of the ViT, just like BERT. The ViT also has a [CLS] token at the beginning of the sequence, whose output representation is used as the input to downstream vision models plugged on top of the ViT.
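A small sketch of the [CLS] mechanism under stated assumptions: `prepend_cls` and `linear_head` are hypothetical helpers, the vectors are tiny toy embeddings, and the encoder itself is replaced by an identity stand-in, so only the data flow is shown. In a real ViT the [CLS] embedding is learned and the encoder transforms the whole sequence. The head reads only position 0, the [CLS] output, and maps it to one logit per class.

```python
from typing import List

Vector = List[float]


def prepend_cls(cls_embedding: Vector, patch_embeddings: List[Vector]) -> List[Vector]:
    """Place the [CLS] embedding at position 0 of the input sequence."""
    return [cls_embedding] + patch_embeddings


def linear_head(encoder_outputs: List[Vector], weights: List[Vector], bias: Vector) -> Vector:
    """Classification head: logits = W @ h_cls + b, using only the [CLS] output."""
    h_cls = encoder_outputs[0]
    return [sum(w * x for w, x in zip(row, h_cls)) + b for row, b in zip(weights, bias)]


patch_embeddings = [[1.0, 0.0], [0.0, 1.0]]  # two toy 2D patch embeddings
sequence = prepend_cls([0.1, 0.2], patch_embeddings)
encoder_outputs = sequence  # stand-in: a real ViT encoder would transform the sequence
logits = linear_head(encoder_outputs,
                     weights=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                     bias=[0.0, 0.0, 0.5])
predicted_class = max(range(len(logits)), key=logits.__getitem__)
```

Because self-attention lets the [CLS] position attend to every patch, its output can summarize the whole image, which is why a single vector at position 0 is enough for the head.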
The figure below illustrates the mechanics of the ViT: