Improving LSTMs

Learn about the techniques for improving the performance of LSTMs.

Having a model backed by solid foundations doesn’t always guarantee practical success in the real world. Natural language is quite complex; even seasoned writers sometimes struggle to produce quality content, so we can’t expect LSTMs to magically output meaningful, well-written text. A sophisticated design that allows for better modeling of long-term dependencies in the data does help, but we need additional techniques at inference time to produce better text. Therefore, numerous extensions have been developed to help LSTMs perform better at the prediction stage. Here, we’ll discuss several such improvements: greedy sampling, beam search, using word vectors instead of a one-hot encoded representation of words, and using bidirectional LSTMs. It’s important to note that these optimization techniques are not specific to LSTMs; rather, any sequential model can benefit from them.

Greedy sampling

If we try to always predict the word with the highest probability, the LSTM will tend to produce very monotonic results. For example, due to the frequent occurrence of stop words (e.g., “the”), it may repeat them many times before switching to another word.

One way to get around this is to use greedy sampling, where we take the n best predictions and sample from that set. This helps to break the monotonic nature of the predictions.
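The following is a minimal NumPy sketch of this idea (not code from the lesson): given the model’s softmax output over the vocabulary, we keep only the n most probable words, renormalize their probabilities, and sample from that reduced distribution. The vocabulary and probability values here are hypothetical placeholders.

```python
import numpy as np

def sample_top_n(word_probs, vocabulary, n=3):
    """Sample the next word from the n most probable candidates."""
    # Indices of the n highest-probability words
    # (word_probs is assumed to be the model's softmax output over the vocabulary)
    top_n_ids = np.argsort(word_probs)[-n:]
    # Renormalize so the selected probabilities sum to 1
    top_n_probs = word_probs[top_n_ids] / np.sum(word_probs[top_n_ids])
    # Sample one word ID from this reduced distribution
    chosen_id = np.random.choice(top_n_ids, p=top_n_probs)
    return vocabulary[chosen_id]

# Made-up softmax output over a tiny, hypothetical vocabulary
vocabulary = ["John", "gave", "Mary", "a", "puppy"]
word_probs = np.array([0.05, 0.40, 0.25, 0.20, 0.10])
print(sample_top_n(word_probs, vocabulary))  # prints "gave", "Mary", or "a"
```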

Let’s consider the first sentence of the previous example:

John gave Mary a puppy.

Say we start with the first word and want to predict the next four words:

John _______ _______ ________ ________.

If we attempt to choose samples deterministically, the LSTM might output something like the following:

John gave Mary gave John.

However, by sampling the next word from a subset of the vocabulary (the most probable words), the LSTM is forced to vary its predictions and might output the following:

John gave Mary a puppy.

Alternatively, it might give the following output:

John gave puppy a puppy.

However, even though greedy sampling helps add diversity to the generated text, this method doesn’t guarantee that the output will always be realistic, especially when generating longer sequences of text. Next, we’ll see a better search technique that actually looks several steps ahead before making a prediction.

Beam search

Beam search is a way of improving the quality of the predictions produced by the LSTM. Here, the predictions are found by solving a search problem. In particular, we predict several steps ahead for multiple candidates at each step, which gives rise to a tree-like structure of candidate word sequences. The crucial idea of beam search is to produce the $b$ outputs (that is, $y_t, y_{t+1}, \ldots, y_{t+b}$) at once instead of a single output $y_t$. Here, $b$ is known as the length of the beam, and the $b$ outputs produced are known as the beam. More technically, we pick the beam that has the highest joint probability $P(y_t, y_{t+1}, \ldots, y_{t+b} \mid x_t)$ instead of picking the output with the highest probability $P(y_t \mid x_t)$. We’re looking further into the future before making a prediction, which usually leads to better results.
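As an illustrative sketch only (not the lesson’s implementation), the snippet below implements this idea in plain Python/NumPy. It assumes a hypothetical `predict_probs(sequence)` function that returns the model’s softmax distribution over the vocabulary for the next word; in practice, this would come from the LSTM at each step. Summing log probabilities keeps the joint probability of a beam numerically stable.

```python
import numpy as np

def beam_search(seed, predict_probs, vocabulary, beam_length=2, n_best=3):
    """Look beam_length steps ahead, keeping the n_best candidate sequences
    with the highest joint (log) probability at every step."""
    # Each candidate (beam) is a pair: (word sequence, cumulative log probability)
    beams = [(list(seed), 0.0)]
    for _ in range(beam_length):
        candidates = []
        for words, log_prob in beams:
            probs = predict_probs(words)  # distribution over the next word
            # Expand this beam with its n_best most probable next words
            for word_id in np.argsort(probs)[-n_best:]:
                candidates.append((words + [vocabulary[word_id]],
                                   log_prob + np.log(probs[word_id])))
        # Keep only the n_best candidates overall
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:n_best]
    # Return the sequence with the highest joint probability
    return max(beams, key=lambda c: c[1])[0]

# Toy usage with a made-up, context-independent distribution
vocabulary = ["John", "gave", "Mary", "a", "puppy"]
toy_probs = lambda words: np.array([0.05, 0.35, 0.30, 0.20, 0.10])
print(beam_search(["John"], toy_probs, vocabulary, beam_length=2, n_best=3))
```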

Let’s look closer at beam search through the previous example:

John gave Mary a puppy.

Say we are predicting word by word, and initially we have the following:

John ________ ________ ________ ________.

Let’s assume hypothetically that our LSTM produces the example sentence using beam search. Then, the probabilities for each word might look like what we see in the figure below. Let’s assume a beam length of $b = 2$, and we’ll consider the $n = 3$ best candidates at each stage of the search.

The search tree would look like the following figure:
