Writing a Custom LSTM Cell in PyTorch

Implement an LSTM cell from scratch in PyTorch.


Creating an LSTM network in PyTorch is pretty straightforward.

```python
import torch.nn as nn

# input_size -> N in the equations (the dimension of x_t)
# hidden_size -> H in the equations (the dimension of h_t and c_t)
layer = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
```
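To see what such a layer returns, here is a short usage sketch. It assumes PyTorch's default `batch_first=False` layout, where inputs have shape `(seq_len, batch, input_size)`; the sequence length of 5 and batch size of 3 are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

layer = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# Input of shape (seq_len, batch, input_size), since batch_first=False by default
x = torch.randn(5, 3, 10)

# When no initial state is passed, h_0 and c_0 default to zeros
output, (h_n, c_n) = layer(x)

print(output.shape)  # torch.Size([5, 3, 20]): h_t of the last layer at every time step
print(h_n.shape)     # torch.Size([2, 3, 20]): final h_t of each of the 2 layers
print(c_n.shape)     # torch.Size([2, 3, 20]): final c_t of each of the 2 layers
```

Note that `output[-1]` and `h_n[-1]` hold the same values: both are the last layer's hidden state at the final time step.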

Note that `num_layers` is the number of LSTM cells stacked on top of each other, so this network will have two LSTM cells connected together. We will see how in the next lesson. For now, we will focus on a single LSTM cell based on the equations.
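One way to see the stacking is to list the layer's parameters. In PyTorch, layer `k` stores its weights as `weight_ih_l{k}` and `weight_hh_l{k}`, with the four gate matrices (input, forget, cell, output) stacked along the first dimension, which is why every leading dimension is `4 * hidden_size = 80`.

```python
import torch.nn as nn

layer = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# weight_ih_l0 has shape (80, 10): layer 0 reads the 10-dimensional input x_t.
# weight_ih_l1 has shape (80, 20): layer 1 reads the 20-dimensional hidden
# state produced by layer 0, which is what "stacked" means here.
for name, param in layer.named_parameters():
    print(name, tuple(param.shape))
```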

Ideally, you would build an LSTM cell entirely from scratch. We already have the equations for each gate, so all we have to do is translate them into code and connect them. To get you started, a code template and the input gate are provided; you will implement the rest.

The originally proposed equations that we described are:

$$i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1} + b_i) \quad\quad (1)$$

$$f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1} + b_f) \quad\quad (2)$$

$$c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \quad\quad (3)$$

$$o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o) \quad\quad (4)$$

$$h_t = o_t \odot \tanh(c_t) \quad\quad (5)$$
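To make the original formulation concrete, here is a minimal functional sketch of one time step. It assumes a row-vector (batch-first) convention: `x_t` is `(batch, N)`, `h_prev` and `c_prev` are `(batch, H)`, and each weight in the dictionary `p` has shape `(in_features, H)`, so a product like $W_{xi} x_t$ becomes `x_t @ p["W_xi"]`. The function name `peephole_lstm_step` and the dictionary keys are hypothetical, and the peephole weights $W_{ci}$, $W_{cf}$, $W_{co}$ are treated as full $H \times H$ matrices, as the equations are written.

```python
import torch


def peephole_lstm_step(x_t, h_prev, c_prev, p):
    """One time step of the original LSTM, Equations (1)-(5).

    Row-vector convention: every W_* in p has shape (in_features, H),
    so the product W x in the equations becomes x @ W.
    """
    # (1) input gate, including the peephole term on c_{t-1}
    i_t = torch.sigmoid(x_t @ p["W_xi"] + h_prev @ p["W_hi"] + c_prev @ p["W_ci"] + p["b_i"])
    # (2) forget gate, including the peephole term on c_{t-1}
    f_t = torch.sigmoid(x_t @ p["W_xf"] + h_prev @ p["W_hf"] + c_prev @ p["W_cf"] + p["b_f"])
    # (3) new cell state: keep part of c_{t-1}, add the gated candidate
    c_t = f_t * c_prev + i_t * torch.tanh(x_t @ p["W_xc"] + h_prev @ p["W_hc"] + p["b_c"])
    # (4) output gate, which looks at the *new* cell state c_t
    o_t = torch.sigmoid(x_t @ p["W_xo"] + h_prev @ p["W_ho"] + c_t @ p["W_co"] + p["b_o"])
    # (5) new hidden state
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t


# Hypothetical sizes, matching the nn.LSTM example above: N = 10, H = 20
N, H, batch = 10, 20, 3
torch.manual_seed(0)
p = {}
for g in "ifco":
    p[f"W_x{g}"] = torch.randn(N, H) * 0.1
    p[f"W_h{g}"] = torch.randn(H, H) * 0.1
    p[f"b_{g}"] = torch.zeros(H)
for g in "ifo":  # peephole weights exist only for the i, f and o gates
    p[f"W_c{g}"] = torch.randn(H, H) * 0.1

x_t = torch.randn(batch, N)
h0 = torch.zeros(batch, H)
c0 = torch.zeros(batch, H)
h1, c1 = peephole_lstm_step(x_t, h0, c0, p)
print(h1.shape, c1.shape)  # torch.Size([3, 20]) torch.Size([3, 20])
```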

Simplification of LSTM equations

However, modern deep learning frameworks use a slightly simpler version of the LSTM: they drop the $c_{t-1}$ terms from Equations (1) and (2), and we will do the same. Removing $W_{ci}$ and $W_{cf}$ results in a less complex model, with fewer parameters, that is easier to optimize. Thus, we will implement the following equations in this exercise:
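Without $c_{t-1}$, the input and forget gates depend only on $x_t$ and $h_{t-1}$. As a sketch of how a single gate might be translated into PyTorch modules, here is the simplified input gate. The class name `InputGate` is hypothetical and is not the exercise's template. $W_{xi} x_t + b_i$ is an `nn.Linear` with a bias and $W_{hi} h_{t-1}$ is an `nn.Linear` without one, since a single $b_i$ is enough.

```python
import torch
import torch.nn as nn


class InputGate(nn.Module):
    """Hypothetical module for the simplified input gate:
    i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + b_i)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # W_xi x_t + b_i: the bias lives here
        self.x2h = nn.Linear(input_size, hidden_size, bias=True)
        # W_hi h_{t-1}: no bias, to avoid a redundant second b_i
        self.h2h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x_t, h_prev):
        # No c_{t-1} term: this is exactly what the simplification removes
        return torch.sigmoid(self.x2h(x_t) + self.h2h(h_prev))


gate = InputGate(input_size=10, hidden_size=20)
i_t = gate(torch.randn(3, 10), torch.zeros(3, 20))
print(i_t.shape)  # torch.Size([3, 20])
```

The forget gate has the same structure with its own weights. PyTorch's built-in `nn.LSTM` keeps two bias vectors per gate (`bias_ih` and `bias_hh`); they are mathematically redundant, and one is enough for these equations.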

$$i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \quad\quad (1)$$

$$f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \quad\quad (2)$$

$$c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \quad\quad (3)$$

$$o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_t + b_o) \quad\quad (4)$$

$$h_t = o_t \odot \tanh(c_t) \quad\quad (5)$$
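The simplified cell as a whole might be sketched as follows. `SimpleLSTMCell` is a hypothetical name, not the exercise's template, and its design is one of several valid choices: it fuses the four input-side matrices $W_{xi}, W_{xf}, W_{xc}, W_{xo}$ into one `nn.Linear` (and likewise the four $W_{h\cdot}$ matrices), then splits the result with `chunk`. It keeps the $W_{co} c_t$ term of Equation (4). PyTorch's own `nn.LSTM` and `nn.LSTMCell` omit that term as well, so zeroing $W_{co}$ and copying the built-in weights (which PyTorch stacks in the order input, forget, cell, output) gives a quick correctness check.

```python
import torch
import torch.nn as nn


class SimpleLSTMCell(nn.Module):
    """Hypothetical LSTM cell for the simplified Equations (1)-(5)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # W_xi, W_xf, W_xc, W_xo and b_i, b_f, b_c, b_o fused into one layer
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=True)
        # W_hi, W_hf, W_hc, W_ho fused; no bias, since x2h already holds b_*
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        # W_co: the one peephole kept in Equation (4), applied to the new c_t
        self.c2o = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x_t, state=None):
        if state is None:
            # Start from h_0 = c_0 = 0, like PyTorch's built-in cells
            zeros = x_t.new_zeros(x_t.size(0), self.hidden_size)
            state = (zeros, zeros)
        h_prev, c_prev = state
        # Pre-activations of all four blocks, split into H-sized chunks
        z_i, z_f, z_c, z_o = (self.x2h(x_t) + self.h2h(h_prev)).chunk(4, dim=1)
        i_t = torch.sigmoid(z_i)                    # (1) input gate
        f_t = torch.sigmoid(z_f)                    # (2) forget gate
        c_t = f_t * c_prev + i_t * torch.tanh(z_c)  # (3) cell state
        o_t = torch.sigmoid(z_o + self.c2o(c_t))    # (4) output gate
        h_t = o_t * torch.tanh(c_t)                 # (5) hidden state
        return h_t, c_t


# Sanity check against nn.LSTMCell with the W_co term switched off
torch.manual_seed(0)
cell = SimpleLSTMCell(input_size=10, hidden_size=20)
ref = nn.LSTMCell(input_size=10, hidden_size=20)
with torch.no_grad():
    cell.x2h.weight.copy_(ref.weight_ih)            # gate order: i, f, c, o
    cell.x2h.bias.copy_(ref.bias_ih + ref.bias_hh)  # PyTorch keeps two redundant biases
    cell.h2h.weight.copy_(ref.weight_hh)
    cell.c2o.weight.zero_()                         # nn.LSTMCell has no W_co term

x = torch.randn(3, 10)
h0, c0 = torch.randn(3, 20), torch.randn(3, 20)
h, c = cell(x, (h0, c0))
h_ref, c_ref = ref(x, (h0, c0))
print(torch.allclose(h, h_ref, atol=1e-6), torch.allclose(c, c_ref, atol=1e-6))
```

Fusing the gates into two large matrix multiplications instead of eight small ones is a common efficiency choice; keeping separate `nn.Linear` layers per gate is equally correct and may be easier to read while learning.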

If this exercise feels too difficult, don’t be discouraged: it is difficult. It may look like a direct translation of a few equations into code, but it is not. Feel free to give it a shot, and also to move on if you get stuck.

Note that the code below will produce an error the first time you run it. Don’t be alarmed; just continue with the exercise as you normally would.
