...

Visualizing Attention Patterns

Learn to visualize attention patterns to gain deeper insight into how the attention mechanism works.

Recall that we defined a model called attention_visualizer specifically to output attention matrices. Now that the model is trained, we can inspect these attention patterns by feeding it data. Here’s how the model was defined:

attention_visualizer = tf.keras.models.Model(
    inputs=[encoder.inputs, decoder_input],
    outputs=[attn_weights, decoder_out]
)
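As a quick sanity check before writing the helper function, you could run a single batch through the visualizer and confirm the output shapes. This is a minimal sketch; the batch variables below (en_batch, de_batch) are hypothetical placeholders standing in for real encoder and decoder inputs:

# en_batch / de_batch are hypothetical placeholders shaped like the
# encoder and decoder inputs the model was trained on.
attn, dec_out = attention_visualizer.predict([en_batch, de_batch])

# attn holds one attention distribution over encoder positions per
# decoder step: (batch_size, decoder_timesteps, encoder_timesteps)
print(attn.shape)
# dec_out holds per-step probabilities over the target vocabulary:
# (batch_size, decoder_timesteps, vocab_size)
print(dec_out.shape)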

The get_attention_matrix_for_sampled_data() function

We’ll also define a function that returns the processed attention matrix, along with the corresponding input and predicted tokens, in a form we can use directly for visualization:

import numpy as np

def get_attention_matrix_for_sampled_data(attention_model, target_lookup_layer, test_xy, n_samples=5):
    test_x, test_y = test_xy
    # Pick n_samples random examples from the test set
    rand_ids = np.random.randint(0, len(test_x), size=(n_samples,))
    results = []
    for rid in rand_ids:
        en_input = test_x[rid:rid+1]
        de_input = test_y[rid:rid+1, :-1]
        # Run the visualizer model to get attention weights and predictions
        attn_weights, predictions = attention_model.predict([en_input, de_input])
        # Convert predicted word IDs back to words via the lookup layer
        predicted_word_ids = np.argmax(predictions, axis=-1).ravel()
        predicted_words = [
            target_lookup_layer.get_vocabulary()[wid] for wid in predicted_word_ids
        ]
        # Drop leading <pad> tokens from the encoder input and stop at </s>
        clean_en_input = []
        en_start_i = 0
        for i, w in enumerate(en_input.ravel()):
            if w == '<pad>':
                en_start_i = i + 1
                continue
            clean_en_input.append(w)
            if w == '</s>':
                break
        # Keep predicted words up to (and including) the </s> token
        clean_predicted_words = []
        for w in predicted_words:
            clean_predicted_words.append(w)
            if w == '</s>':
                break
        # Slice the attention matrix so it covers only the cleaned tokens
        results.append(
            {
                "attention_weights": attn_weights[
                    0, :len(clean_predicted_words),
                    en_start_i:en_start_i + len(clean_en_input)
                ],
                "input_words": clean_en_input,
                "predicted_words": clean_predicted_words,
            }
        )
    return results
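With the sampled results in hand, a heatmap is a natural way to inspect them. The sketch below uses matplotlib to plot one heatmap per sample, with input words on the x-axis and predicted words on the y-axis. It assumes attention_visualizer, target_lookup_layer, test_x, and test_y are available from the earlier steps; the variable names are otherwise illustrative:

import matplotlib.pyplot as plt

# Assumed to be defined earlier in the lesson: attention_visualizer,
# target_lookup_layer, and the test arrays test_x / test_y.
results = get_attention_matrix_for_sampled_data(
    attention_visualizer, target_lookup_layer, (test_x, test_y), n_samples=3
)

for res in results:
    fig, ax = plt.subplots(figsize=(8, 6))
    # Each row is a decoder step; each column is an encoder position
    ax.imshow(res["attention_weights"], cmap="viridis")
    ax.set_xticks(range(len(res["input_words"])))
    ax.set_xticklabels(res["input_words"], rotation=90)
    ax.set_yticks(range(len(res["predicted_words"])))
    ax.set_yticklabels(res["predicted_words"])
    ax.set_xlabel("Input words")
    ax.set_ylabel("Predicted words")
    plt.show()

Bright cells mark the input words the decoder attended to most when emitting each predicted word. For a well-trained translator, you would expect a roughly diagonal pattern, with deviations wherever word order differs between the two languages.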
...