LSTM Model
Learn how to build and train LSTM neural networks with JAX and Flax.
Define LSTM model in Flax
We are now ready to define the LSTM model in Flax. To build LSTMs in Flax, we use the LSTMCell or the OptimizedLSTMCell. The OptimizedLSTMCell is a more efficient implementation of LSTMCell with the same interface.
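Because the two cells share the same call signature, they can be swapped without changing the rest of the model. A minimal sketch of constructing each one (following the construction style used in the model code later in this lesson; the exact API can vary across Flax versions):

```python
from flax import linen as nn

# Both cells map (carry, inputs) -> (new_carry, output) for a single time step.
cell = nn.LSTMCell()                 # reference LSTM cell
fast_cell = nn.OptimizedLSTMCell()   # same interface, more efficient implementation
```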
The LSTMCell.initialize_carry function is used to initialize the carry and hidden states of the LSTM cell. It expects:
- A random number generator key (PRNG key).
- The batch dimensions.
- The number of units.
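For example, the state for a batch of 32 sequences and an LSTM with 128 units could be initialized as follows (a minimal sketch with illustrative values, using the same calling convention as the model code below):

```python
import jax
from flax import linen as nn

# Initialize the (carry, hidden) state for a batch of 32 sequences and 128 units.
carry, hidden = nn.OptimizedLSTMCell.initialize_carry(
    jax.random.PRNGKey(0),   # random number generator key
    batch_dims=(32,),        # batch dimensions
    size=128                 # number of units
)
# carry and hidden both have shape (32, 128).
```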
Let’s use the setup method to define the LSTM model. The LSTM contains the following layers:
- An Embedding layer with the same number of features and length as defined in the vectorization layer.
- LSTM layers that pass data in one direction, as specified by the reverse argument.
- A couple of Dense layers.
- A final Dense output layer.
```python
from flax import linen as nn

class LSTMModel(nn.Module):
    def setup(self):
        self.embedding = nn.Embed(max_features, max_len)
        lstm_layer = nn.scan(nn.OptimizedLSTMCell,
                             variable_broadcast="params",
                             split_rngs={"params": False},
                             in_axes=1,
                             out_axes=1,
                             length=max_len,
                             reverse=False)
        self.lstm1 = lstm_layer()
        self.dense1 = nn.Dense(256)
        self.lstm2 = lstm_layer()
        self.dense2 = nn.Dense(128)
        self.lstm3 = lstm_layer()
        self.dense3 = nn.Dense(64)
        self.dense4 = nn.Dense(2)

    @nn.remat
    def __call__(self, x_batch):
        x = self.embedding(x_batch)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
        (carry, hidden), x = self.lstm1((carry, hidden), x)
        x = self.dense1(x)
        x = nn.relu(x)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=64)
        (carry, hidden), x = self.lstm2((carry, hidden), x)
        x = self.dense2(x)
        x = nn.relu(x)

        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=32)
        (carry, hidden), x = self.lstm3((carry, hidden), x)
        x = self.dense3(x)
        x = nn.relu(x)

        x = self.dense4(x[:, -1])
        return nn.log_softmax(x)
```
We import the linen module from the flax library as nn to define the LSTM model. In the code above, we define the LSTMModel class using nn.Module. Inside this class:
- Lines 4–19: We define the setup() function to define the layers and components of the model. Inside this function:
  - Line 5: We define the Embedding layer to map the discrete input tokens to continuous vectors.
  - Lines 6–12: We define the LSTM layer as lstm_layer, where we call the scan() method of nn to set up the LSTM layer with the given configuration.
  - Lines 13–14: We call the defined lstm_layer() to create the first LSTM layer, followed by a Dense layer with 256 units.
  - Lines 15–19: Similarly, we define another LSTM layer, a Dense layer with 128 units, a third LSTM layer, and a Dense layer with 64 units. Lastly, we define a Dense layer with two units for the binary classification.
- Line 21: We apply the @nn.remat decorator to the __call__() function for memory optimization. The @nn.remat decorator recomputes intermediate values during the backward pass instead of storing them, which saves memory when using LSTMs on long sequences.
- Lines 22–41: We define the __call__() function to implement the forward pass. Inside this function:
  - Line 23: We pass the given input through the first layer (the Embedding layer) of the model.
  - Lines 25–28: We call the nn.OptimizedLSTMCell.initialize_carry method to initialize the carry and hidden states, pass the output of the previous layer through the first LSTM layer, followed by the first Dense layer, and apply the ReLU activation function.
  - Lines 30–38: Similarly, we pass the output of the previous layer through the subsequent LSTM and Dense layers and apply the ReLU activation function.
  - Lines 40–41: We pass the output at the last time step through the final Dense layer. Lastly, we apply the LogSoftmax activation and return the output.
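With the model defined, a quick sanity check is to initialize its parameters and run a forward pass on a dummy batch. The sketch below assumes max_features and max_len carry the values set in the earlier vectorization step (the numbers shown are placeholders), uses random integer token IDs in place of real data, and relies on a Flax version that matches the initialize_carry signature used above:

```python
import jax

# Assumed to be defined in the earlier vectorization step; placeholder values here.
max_features, max_len = 20000, 100

model = LSTMModel()

# A dummy batch of 8 token-ID sequences, each max_len tokens long.
x_batch = jax.random.randint(jax.random.PRNGKey(1), (8, max_len), 0, max_features)

# Initialize the parameters and run a forward pass.
variables = model.init(jax.random.PRNGKey(0), x_batch)
log_probs = model.apply(variables, x_batch)
print(log_probs.shape)  # (8, 2): log-probabilities for the two classes
```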