Inference with NMT

Learn to do inference with the trained NMT model.

How inference differs from training

Inference in NMT differs slightly from training. Because no target sentence is available at inference time, we need a way to start the decoder once the encoding phase is finished. This is straightforward because the groundwork is already in our data: the target sentences are marked with the <s> and </s> tokens. We kick off the decoder by feeding <s> as its first input. Then we call the decoder repeatedly, using the word predicted at each time step as the input for the next time step. We continue until the model outputs </s> as the predicted token or the translation reaches a predefined maximum length.
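The decoding loop just described can be sketched in plain Python as greedy decoding. This is an illustration, not the lesson's own code. `decoder_step` is a hypothetical stand-in for one call of the trained decoder: it takes the previous token and state and returns the predicted token and the new state. The lookup table `NEXT` is a toy replacement for a real model so that the loop can run.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical signature: (previous token, state) -> (predicted token, new state)
StepFn = Callable[[str, Any], Tuple[str, Any]]


def greedy_decode(decoder_step: StepFn, initial_state: Any, max_len: int = 10,
                  start: str = "<s>", end: str = "</s>") -> List[str]:
    """Start from <s>, then feed each predicted token back in as the next input."""
    output: List[str] = []
    token, state = start, initial_state
    for _ in range(max_len):                       # hard cap on the translation length
        token, state = decoder_step(token, state)  # one decoder time step
        if token == end:                           # the model signalled end of sentence
            break
        output.append(token)
    return output


# Hypothetical stand-in for a trained decoder: a fixed next-token table.
NEXT = {"<s>": "das", "das": "ist", "ist": "gut", "gut": "</s>"}


def toy_step(prev_token: str, state: int) -> Tuple[str, int]:
    # The "state" here just counts steps; a real decoder would return its hidden state.
    return NEXT.get(prev_token, "</s>"), state + 1


print(greedy_decode(toy_step, initial_state=0))  # ['das', 'ist', 'gut']
```

Both stopping conditions are visible in the loop: the `break` on </s> and the `max_len` bound. The bound matters because an imperfect model may never emit </s>.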

How to work with inference models

To do this, we define a new model that reuses the weights of the trained model. A new model is needed because the trained model is designed to consume the whole sequence of decoder inputs at once, whereas at inference time we must call the decoder one time step at a time and feed each prediction back in. Here's how we can define the inference model:

  • Define an encoder model that outputs the encoder’s hidden state sequence and the last encoder state.

  • Define a new decoder that takes two inputs: a decoder input with a time dimension of 1, and the decoder's previous hidden state (initialized with the encoder's last state).
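These two models work because they share weights with the trained model. Calling the decoder one step at a time and passing its previous state back in produces exactly the same states as running it over the whole target sequence, which is what happens during training. The pure-Python sketch below demonstrates this with a minimal GRU cell. It omits biases and uses the update `h_new = z * h + (1 - z) * h_cand` (elementwise), which is one common GRU convention. The names `encoder_model`, `decoder_model`, and `run_gru`, and the random weights, are hypothetical. In a real framework you would rebuild these models from the trained layers. Both runs are fed the same ground-truth target vectors so that the equivalence can be checked.

```python
import math
import random
from typing import Dict, List, Tuple

Vec = List[float]
Mat = List[List[float]]


def matvec(m: Mat, v: Vec) -> Vec:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def init_gru(in_dim: int, hid: int, rng: random.Random) -> Dict[str, Mat]:
    # W* matrices act on the input, U* matrices on the previous state (biases omitted).
    return {
        k: [[rng.uniform(-0.5, 0.5) for _ in range(in_dim if k[0] == "W" else hid)]
            for _ in range(hid)]
        for k in ("Wz", "Uz", "Wr", "Ur", "Wh", "Uh")
    }


def gru_step(p: Dict[str, Mat], x: Vec, h: Vec) -> Vec:
    z = [sigmoid(a + b) for a, b in zip(matvec(p["Wz"], x), matvec(p["Uz"], h))]  # update gate
    r = [sigmoid(a + b) for a, b in zip(matvec(p["Wr"], x), matvec(p["Ur"], h))]  # reset gate
    rh = [ri * hi for ri, hi in zip(r, h)]
    cand = [math.tanh(a + b) for a, b in zip(matvec(p["Wh"], x), matvec(p["Uh"], rh))]
    return [zi * hi + (1 - zi) * ci for zi, hi, ci in zip(z, h, cand)]


def run_gru(p: Dict[str, Mat], xs: List[Vec], h0: Vec) -> Tuple[List[Vec], Vec]:
    """Training-style call: consume a whole sequence, return every state and the last one."""
    states, h = [], h0
    for x in xs:
        h = gru_step(p, x, h)
        states.append(h)
    return states, h


def encoder_model(p: Dict[str, Mat], xs: List[Vec], hid: int) -> Tuple[List[Vec], Vec]:
    """Inference encoder: the hidden-state sequence plus the last state."""
    return run_gru(p, xs, [0.0] * hid)


def decoder_model(p: Dict[str, Mat], x_t: Vec, h_prev: Vec) -> Vec:
    """Inference decoder: a single time step plus the previous decoder state."""
    return gru_step(p, x_t, h_prev)


rng = random.Random(0)
IN_DIM, HID = 3, 4
enc_p, dec_p = init_gru(IN_DIM, HID, rng), init_gru(IN_DIM, HID, rng)  # stand-ins for trained weights
src = [[rng.uniform(-1, 1) for _ in range(IN_DIM)] for _ in range(5)]
tgt = [[rng.uniform(-1, 1) for _ in range(IN_DIM)] for _ in range(4)]

enc_states, h = encoder_model(enc_p, src, HID)
train_states, _ = run_gru(dec_p, tgt, h)      # whole target sequence at once, as in training

step_states, state = [], h                    # decoder state starts from the encoder's last state
for x_t in tgt:
    state = decoder_model(dec_p, x_t, state)  # one time step per call
    step_states.append(state)

print(step_states == train_states)  # True: same weights, same computation
```

The comparison prints `True` because both paths perform the identical sequence of `gru_step` calls with the same weights. This is why the inference decoder can be assembled from the trained decoder's weights without retraining.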

With that, we can start feeding data to generate predictions as follows:

  • Preprocess $x_s$ as in the data processing step.

  • Feed $x_s$ into $\text{GRU}_{enc}$ and compute the encoder's state sequence and the last state $h$ conditioned on $x_s$.

  • Initialize $\text{GRU}_{dec}$ with $h$.

  • For the initial prediction step, predict $\hat{Y}_T^2$ by conditioning the prediction on $\hat{Y}_T^1 = <s>$ as the first word and on $h$.

  • For each subsequent time step $m$, while $\hat{Y}_T^m \neq </s>$ and the predictions haven't reached a predefined length threshold, predict $\hat{Y}_T^{m+1}$ by conditioning the prediction on $\left\{\hat{Y}_T^{m}, \hat{Y}_T^{m-1}, \ldots, <s> \right\}$ and $h$.

Together, these steps produce the translation of an input sequence of text.
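The whole procedure can be sketched end to end in plain Python. Everything here is a hypothetical stand-in. The toy vocabularies, the assumed preprocessing (lowercasing, whitespace splitting, unknown words mapped to <unk>), and the randomly initialized embeddings, GRU weights, and output layer `W_out` take the place of the trained model, so the "translation" it produces is meaningless. The sketch uses greedy decoding (an argmax at each step) and no attention. Each prediction depends on the earlier words only through the decoder state, which carries them forward.

```python
import math
import random
from typing import List

rng = random.Random(42)
HID, EMB = 8, 4
SRC_VOCAB = ["<unk>", "i", "like", "cats"]                # hypothetical toy vocabularies
TGT_VOCAB = ["<unk>", "<s>", "</s>", "ich", "mag", "katzen"]


def rand_mat(rows: int, cols: int) -> List[List[float]]:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gru_params():
    # W* act on the EMB-sized input, U* on the HID-sized state (biases omitted).
    return {k: rand_mat(HID, EMB if k[0] == "W" else HID)
            for k in ("Wz", "Uz", "Wr", "Ur", "Wh", "Uh")}


def gru_step(p, x, h):
    z = [sigmoid(a + b) for a, b in zip(matvec(p["Wz"], x), matvec(p["Uz"], h))]
    r = [sigmoid(a + b) for a, b in zip(matvec(p["Wr"], x), matvec(p["Ur"], h))]
    c = [math.tanh(a + b) for a, b in
         zip(matvec(p["Wh"], x), matvec(p["Uh"], [ri * hi for ri, hi in zip(r, h)]))]
    return [zi * hi + (1 - zi) * ci for zi, hi, ci in zip(z, h, c)]


# Random stand-ins for the trained weights (untrained, so the output is gibberish).
src_emb = rand_mat(len(SRC_VOCAB), EMB)
tgt_emb = rand_mat(len(TGT_VOCAB), EMB)
enc_p, dec_p = gru_params(), gru_params()
W_out = rand_mat(len(TGT_VOCAB), HID)  # one score per target word


def preprocess(sentence: str) -> List[int]:
    # Assumed preprocessing: lowercase, whitespace split, unknown words -> <unk> (id 0).
    return [SRC_VOCAB.index(w) if w in SRC_VOCAB else 0 for w in sentence.lower().split()]


def translate(sentence: str, max_len: int = 10) -> List[str]:
    # Encode x_s: collect the state sequence and the last state h.
    h, enc_states = [0.0] * HID, []
    for i in preprocess(sentence):
        h = gru_step(enc_p, src_emb[i], h)
        enc_states.append(h)  # unused by this attention-free sketch
    # Initialize the decoder with h; the first input word is <s>.
    state, token = h, TGT_VOCAB.index("<s>")
    out: List[str] = []
    # Feed each prediction back in until </s> or the length threshold.
    for _ in range(max_len):
        state = gru_step(dec_p, tgt_emb[token], state)
        scores = matvec(W_out, state)
        token = max(range(len(scores)), key=scores.__getitem__)  # greedy argmax
        if TGT_VOCAB[token] == "</s>":
            break
        out.append(TGT_VOCAB[token])
    return out


print(translate("I like cats"))
```

`translate("I like cats")` returns a list of at most `max_len` target words. With trained weights, the same loop would yield an actual translation. Taking the argmax at each step is the simplest decoding choice. Beam search is a common alternative that keeps several candidate translations at every step.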
