LSTM Model

Learn how to build and train LSTM neural networks using JAX and Flax.

Define the LSTM model in Flax

We are now ready to define the LSTM model in Flax. Flax provides two building blocks for LSTMs: the LSTMCell and the OptimizedLSTMCell. The OptimizedLSTMCell is a faster, more efficient implementation of the LSTMCell.

The LSTMCell.initialize_carry function initializes the hidden state (the carry) of the LSTM cell. As the sketch after this list shows, it expects:

  • A pseudo-random number generator (PRNG) key.
  • The batch dimensions.
  • The number of units.
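
For example, the following sketch (which assumes the older Flax linen API used throughout this lesson, with an arbitrary batch size of 32) creates a zero-initialized carry:

import jax
from flax import linen as nn

# Build a zero-initialized (carry, hidden) pair for a batch of 32
# sequences with 128 hidden units each. The PRNG key is only consulted
# by the initializer; the default initializer returns zeros.
carry, hidden = nn.OptimizedLSTMCell.initialize_carry(
    jax.random.PRNGKey(0),  # PRNG key
    batch_dims=(32,),       # batch dimensions
    size=128)               # number of units

print(carry.shape, hidden.shape)  # (32, 128) (32, 128)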

Figure: LSTM architecture

Let’s use the setup method to define the LSTM model. The LSTM contains the following layers:

  • An Embedding layer whose vocabulary size (max_features) and feature dimension (max_len) match the values defined in the vectorization layer.
  • LSTM layers that pass data in one direction as specified by the reverse argument.
  • A couple of Dense layers.
  • A final Dense output layer.
import jax
from flax import linen as nn
class LSTMModel(nn.Module):
    def setup(self):
        self.embedding = nn.Embed(max_features, max_len)
        lstm_layer = nn.scan(nn.OptimizedLSTMCell,
                             variable_broadcast="params",
                             split_rngs={"params": False},
                             in_axes=1,
                             out_axes=1,
                             length=max_len,
                             reverse=False)
        self.lstm1 = lstm_layer()
        self.dense1 = nn.Dense(256)
        self.lstm2 = lstm_layer()
        self.dense2 = nn.Dense(128)
        self.lstm3 = lstm_layer()
        self.dense3 = nn.Dense(64)
        self.dense4 = nn.Dense(2)

    @nn.remat
    def __call__(self, x_batch):
        x = self.embedding(x_batch)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
        (carry, hidden), x = self.lstm1((carry, hidden), x)
        x = self.dense1(x)
        x = nn.relu(x)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=64)
        (carry, hidden), x = self.lstm2((carry, hidden), x)
        x = self.dense2(x)
        x = nn.relu(x)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=32)
        (carry, hidden), x = self.lstm3((carry, hidden), x)
        x = self.dense3(x)
        x = nn.relu(x)

        x = self.dense4(x[:, -1])
        return nn.log_softmax(x)

We import jax and the linen module from the flax library as nn to define the LSTM model. In the code above, we define the LSTMModel class using nn.Module. Inside this class:

  • Lines 4–19: We use the setup() function to declare the layers and components of the model. Inside this function:

    • Line 5: We define the Embedding layer to map discrete input tokens to continuous vectors.

    • Lines 6–12: We define lstm_layer by calling nn.scan(), which lifts OptimizedLSTMCell to iterate over the time axis (axis 1, per in_axes and out_axes) for max_len steps, sharing the cell’s parameters across all time steps (variable_broadcast="params") and scanning forward (reverse=False).

    • Lines 13–14: We instantiate lstm_layer() to create the first LSTM layer, followed by a Dense layer with 256 units.

    • Lines 15–19: Similarly, we define a second LSTM layer, a Dense layer with 128 units, a third LSTM layer, and a Dense layer with 64 units. Lastly, we define a Dense layer with two units for binary classification.

  • Line 21: We apply the @nn.remat decorator to the __call__() function for memory optimization. Rematerialization discards intermediate activations during the forward pass and recomputes them during the backward pass, which saves memory when running LSTMs over long sequences.

  • Lines 22–41: We define the __call__() function to implement the forward pass. Inside this function:

    • Line 23: We pass the given input through the first layer (the Embedding layer) of the model.

    • Lines 25–28: We call the nn.OptimizedLSTMCell.initialize_carry method to initialize the carry and hidden states and pass the output of the previous layer through the first LSTM layer, followed by the first Dense layer. Lastly, we apply the ReLU activation function.

    • Lines 30–38: Similarly, we pass the output of the previous layer through the subsequent LSTM and Dense layers and apply the ReLU activation function.

    • Lines 40–41: We pass the output of the previous layer through the last Dense layer. Lastly, we apply the LogSoftmax activation and return the output.
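
To make the walkthrough concrete, here is a minimal usage sketch for initializing and calling the model defined above. The values of max_features, max_len, and the batch size are placeholders chosen for illustration; in this lesson, max_features and max_len come from the vectorization layer:

import jax
import jax.numpy as jnp

max_features = 10000  # placeholder vocabulary size
max_len = 100         # placeholder sequence length

model = LSTMModel()
dummy_batch = jnp.ones((8, max_len), dtype=jnp.int32)  # a batch of 8 token sequences

# init() traces the forward pass once to build the parameter pytree.
params = model.init(jax.random.PRNGKey(42), dummy_batch)

# apply() runs the forward pass with the given parameters.
log_probs = model.apply(params, dummy_batch)
print(log_probs.shape)  # (8, 2): class log-probabilities per example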

We apply the ...