Implementing and Training the Model with TensorFlow
Learn to implement and train the transformer model with TensorFlow.
We’ll now implement the model we just studied. First, let’s import a few things:
import tensorflow_hub as hub
import tensorflow as tf
import tensorflow.keras.backend as K
Implementing the ViT model
Next, we’re going to download the pretrained ViT model from TensorFlow Hub. We’ll be using a model submitted by Sayak Paul. You can see other ViT models here.
image_encoder = hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_s16_fe/1", trainable=False)
We then define an input layer to take in images and pass it to the image_encoder to get the final feature vector for each image:
image_input = tf.keras.layers.Input(shape=(224, 224, 3))
image_features = image_encoder(image_input)
We can look at the size of the final image representation by running:
print(f"Final representation shape: {image_features.shape}")
This will output:
Final representation shape: (None, 384)
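As a quick sanity check, we could also wrap the input layer and the encoder in a temporary tf.keras.Model and run a random batch through it. This is just an illustrative sketch; the random batch and the vit_check model below are not part of the final captioning model:

import numpy as np

# Temporary model mapping raw images to their ViT feature vectors
vit_check = tf.keras.Model(inputs=image_input, outputs=image_features)

# A random batch of four 224x224 RGB "images" with pixel values in [0, 1]
dummy_images = np.random.uniform(0.0, 1.0, size=(4, 224, 224, 3)).astype("float32")

# Prints (4, 384): one 384-dimensional feature vector per image
print(vit_check(dummy_images).shape)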
Next, we’ll look at the details of how to implement the text-based transformer model, which will take in the image representation to generate the image caption.
Implementing the text-based decoder
Here, we’ll implement a transformer decoder model from the ground up. This is different from how we used transformer models before, where we downloaded pretrained models and used them as they were.
Before we implement the model itself, we’re going to implement two custom Keras layers: one for the self-attention mechanism and the other one to capture the functionality of a single layer in the transformer model. Let’s start with the self-attention layer.
Defining the self-attention layer
Here, we define the self-attention layer using the Keras subclassing API:
class SelfAttentionLayer(tf.keras.layers.Layer):
    """ Defines the computations in the self-attention layer """

    def __init__(self, d):
        super(SelfAttentionLayer, self).__init__()
        # Feature dimensionality of the output
        self.d = d

    def build(self, input_shape):
        # Query weight matrix
        self.Wq = self.add_weight(
            shape=(input_shape[-1], self.d), initializer='glorot_uniform',
            trainable=True, dtype='float32'
        )
        # Key weight matrix
        self.Wk = self.add_weight(
            shape=(input_shape[-1], self.d), initializer='glorot_uniform',
            trainable=True, dtype='float32'
        )
        # Value weight matrix
        self.Wv = self.add_weight(
            shape=(input_shape[-1], self.d), initializer='glorot_uniform',
            trainable=True, dtype='float32'
        )

    def call(self, q_x, k_x, v_x, mask=None):
        q = tf.matmul(q_x, self.Wq)  # [None, t, d]
        k = tf.matmul(k_x, self.Wk)  # [None, t, d]
        v = tf.matmul(v_x, self.Wv)  # [None, t, d]

        # Computing the final output
        h = tf.keras.layers.Attention(causal=True)(
            [
                q,  # q
                v,  # v
                k,  # k
            ],
            mask=[None, mask]
        )
        # [None, t, t] . [None, t, d] => [None, t, d]
        return h
Here, we have to populate the logic for three functions:

• __init__() and build(): Define various hyperparameters and layer-initialization-specific logic.
• call(): Defines the computations that need to happen when the layer is called.
We define the dimensionality of the attention output, d, as an argument to the __init__() method. Next, in the build() method, we define three weight matrices: Wq, Wk, and Wv. These represent the weights of the query, key, and value, respectively. Finally, the call() method holds the logic. It takes four inputs: the query, key, and value inputs, and an optional mask for the values. We then compute the latent q, k, and v by multiplying the inputs with the corresponding weight matrices Wq, Wk, and Wv. To compute attention, we’ll be using the out-of-the-box tf.keras.layers.Attention layer. The tf.keras.layers.Attention layer has several arguments; the one we care about here is setting causal=True.
By doing this, we’re instructing the layer to mask out the tokens that appear after the current token, so that each position can only attend to itself and the positions before it.
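To see the layer in action, here’s a minimal sketch that runs it on random inputs. The batch size, sequence length, input feature size, and the d value below are placeholder values chosen purely for illustration:

# Hypothetical dummy input: a batch of 2 sequences, each with 5 time steps and 16 features
x = tf.random.uniform(shape=(2, 5, 16))

# Project the inputs to a 12-dimensional attention output
self_attn = SelfAttentionLayer(d=12)

# Self-attention: the same tensor acts as query, key, and value
h = self_attn(x, x, x)

print(h.shape)  # (2, 5, 12)

Because causal=True is set inside the layer, the output at a given time step depends only on that time step and the ones before it.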