Visualizing Attention Patterns
Learn to visualize attention patterns to gain deeper insights into how they work.
Remember that we specifically defined a model called attention_visualizer to generate attention matrices? With the model trained, we can now look at these attention patterns by feeding data to the model. Here's how the model was defined:
attention_visualizer = tf.keras.models.Model(
    inputs=[encoder.inputs, decoder_input],
    outputs=[attn_weights, decoder_out]
)
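For context, a model of this shape can be wired up roughly as follows. This is a minimal sketch, not the course's actual architecture: the vocabulary sizes, layer dimensions, and the use of Keras's built-in `Attention` layer (called with `return_attention_scores=True`) are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes; the real model's values will differ.
src_vocab, tgt_vocab, units = 100, 120, 32

# Encoder: embed source token IDs and run them through an LSTM.
encoder_input = tf.keras.Input(shape=(None,), name="encoder_input")
enc_emb = tf.keras.layers.Embedding(src_vocab, units)(encoder_input)
enc_out, enc_h, enc_c = tf.keras.layers.LSTM(
    units, return_sequences=True, return_state=True)(enc_emb)

# Decoder: embed target token IDs, initialize from the encoder state.
decoder_input = tf.keras.Input(shape=(None,), name="decoder_input")
dec_emb = tf.keras.layers.Embedding(tgt_vocab, units)(decoder_input)
dec_out = tf.keras.layers.LSTM(units, return_sequences=True)(
    dec_emb, initial_state=[enc_h, enc_c])

# Attention over encoder outputs; expose the score matrix as a model output.
context, attn_weights = tf.keras.layers.Attention()(
    [dec_out, enc_out], return_attention_scores=True)
concat = tf.keras.layers.Concatenate()([dec_out, context])
decoder_out = tf.keras.layers.Dense(tgt_vocab, activation="softmax")(concat)

attention_visualizer = tf.keras.models.Model(
    inputs=[encoder_input, decoder_input],
    outputs=[attn_weights, decoder_out])
```

Calling `predict()` on this model yields both the attention matrix (shaped `[batch, target_len, source_len]`) and the token probabilities, which is exactly what the visualization function below consumes.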
The get_attention_matrix_for_sampled_data() function
We’ll also define a function to get the processed attention matrix along with label data that we can use directly for visualization purposes:
def get_attention_matrix_for_sampled_data(attention_model, target_lookup_layer, test_xy, n_samples=5):
    test_x, test_y = test_xy
    # Pick n_samples random test examples
    rand_ids = np.random.randint(0, len(test_xy[0]), size=(n_samples,))
    results = []
    for rid in rand_ids:
        en_input = test_x[rid:rid+1]
        de_input = test_y[rid:rid+1, :-1]
        # Get the attention matrix and the predicted token probabilities
        attn_weights, predictions = attention_model.predict([en_input, de_input])
        predicted_word_ids = np.argmax(predictions, axis=-1).ravel()
        predicted_words = [
            target_lookup_layer.get_vocabulary()[wid] for wid in predicted_word_ids
        ]
        # Strip <pad> tokens from the encoder input and truncate at </s>
        clean_en_input = []
        en_start_i = 0
        for i, w in enumerate(en_input.ravel()):
            if w == '<pad>':
                en_start_i = i + 1
                continue
            clean_en_input.append(w)
            if w == '</s>':
                break
        # Truncate the predicted sequence at the first </s>
        clean_predicted_words = []
        for w in predicted_words:
            clean_predicted_words.append(w)
            if w == '</s>':
                break
        # Keep only the attention weights covering the cleaned-up tokens
        results.append({
            "attention_weights": attn_weights[
                0, :len(clean_predicted_words), en_start_i:en_start_i+len(clean_en_input)
            ],
            "input_words": clean_en_input,
            "predicted_words": clean_predicted_words
        })
    return results
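With these dictionaries in hand, each attention matrix can be rendered as a heatmap, with source words on one axis and predicted words on the other. The helper below is a sketch of one way to do this with matplotlib; the function name plot_attention_matrices and the styling choices are our own, not part of the course code.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_attention_matrices(results):
    """Draw one heatmap per sampled sentence from the result dictionaries
    produced by get_attention_matrix_for_sampled_data()."""
    fig, axes = plt.subplots(1, len(results), figsize=(5 * len(results), 5))
    for ax, res in zip(np.atleast_1d(axes), results):
        # Rows are predicted (target) words, columns are input (source) words.
        ax.imshow(res["attention_weights"], cmap="viridis")
        ax.set_xticks(range(len(res["input_words"])))
        ax.set_xticklabels(res["input_words"], rotation=90)
        ax.set_yticks(range(len(res["predicted_words"])))
        ax.set_yticklabels(res["predicted_words"])
    fig.tight_layout()
    return fig
```

Bright cells in each heatmap indicate which source words the model attended to most strongly when producing a given target word.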
...