Training the NMT

Learn to train the NMT model with TensorFlow.

Now that we have defined the NMT architecture and preprocessed the training data, it’s quite straightforward to train the model. Here, we’ll define and illustrate the exact process used for training:

The training procedure for NMT

The prepare_data() function

For the model training, we’re going to define a custom training loop because there’s a special metric we’d like to track. Unfortunately, this metric is not a readily available TensorFlow metric. But before that, there are several utility functions we need to define:

def prepare_data(de_lookup_layer, train_xy, valid_xy, test_xy):
    """ Create a data dictionary from the DataFrames containing data """
    data_dict = {}
    for label, data_xy in zip(['train', 'valid', 'test'], [train_xy, valid_xy, test_xy]):
        data_x, data_y = data_xy
        # The encoder consumes the full English sequence
        en_inputs = data_x
        # Decoder inputs: all German tokens except the last one
        de_inputs = data_y[:, :-1]
        # Decoder labels: German token IDs shifted one step left (all tokens except the first)
        de_labels = de_lookup_layer(data_y[:, 1:]).numpy()
        data_dict[label] = {
            'encoder_inputs': en_inputs,
            'decoder_inputs': de_inputs,
            'decoder_labels': de_labels
        }
    return data_dict
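
For illustration, here's a minimal sketch of how prepare_data() might be called. The toy arrays, the [START]/[END]/[PAD] token conventions, and the StringLookup vocabulary below are assumptions made for the sake of a self-contained example, not values taken from this lesson:

import numpy as np
import tensorflow as tf

# Hypothetical toy data: two tokenized, padded sentence pairs
# (shapes and contents are illustrative only)
train_en = np.array([["i", "like", "cats", "[PAD]"],
                     ["she", "reads", "books", "[PAD]"]])
train_de = np.array([["[START]", "ich", "mag", "katzen", "[END]"],
                     ["[START]", "sie", "liest", "bücher", "[END]"]])

# Assumed StringLookup layer for German, built from the target vocabulary
de_lookup_layer = tf.keras.layers.StringLookup(
    vocabulary=["[START]", "[END]", "ich", "mag", "katzen", "sie", "liest", "bücher"]
)

# Reusing the same toy pairs for validation and test, purely for illustration
data_dict = prepare_data(
    de_lookup_layer,
    train_xy=(train_en, train_de),
    valid_xy=(train_en, train_de),
    test_xy=(train_en, train_de),
)

print(data_dict['train']['decoder_inputs'].shape)  # (2, 4) -> tokens [START] ... second-to-last
print(data_dict['train']['decoder_labels'].shape)  # (2, 4) -> token IDs shifted one step left

Note how the decoder inputs remain token strings (presumably converted to IDs inside the model), while the decoder labels are converted to integer IDs by the StringLookup layer so they can be compared directly against the decoder's predictions during training.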

The prepare_data() function takes the source and target sentence pairs and generates the encoder inputs, decoder inputs, and decoder labels. Let’s look at the arguments:

  • de_lookup_layer: The StringLookup layer for the German (target) language.

  • train_xy: A tuple containing the tokenized English sentences and the tokenized German sentences from the training set, in that order.

  • valid_xy: Similar to train_xy but for validation data.

  • test_xy: Similar to train_xy but for test data.

For each of the training, validation, and test datasets, this function generates the following: ...