
An Introduction to Recurrent Neural Networks & LSTMs

Peter Foy

In this introductory guide we'll discuss recurrent neural networks (RNNs) and long short-term memory (LSTM) networks.

Let's jump into it:

A recurrent neural network attempts to model time-based or sequence-based data.

Applications of recurrent neural networks include natural language processing (NLP), stock price prediction, and energy demand forecasting.

To get a bit more technical, recurrent neural networks are designed to learn from sequences of data by passing the hidden state from one step in the sequence to the next step and combining it with the input.

This brings us to long short-term memory networks:

An LSTM network is a type of RNN that uses special units as well as standard units.

So what are these special units?

LSTM units include a 'memory cell' that can keep information in memory for long periods of time.

LSTMs are particularly useful when the neural network needs to switch between remembering recent things and remembering things from a long time ago.

RNNs vs. LSTMs

Let's say we have a regular neural network that is used for image recognition.

If we pass an image of a dog to the neural network, it will ideally output a high probability of the image being a dog, a small probability of it being a wolf, and an even smaller probability of it being a cat.

But what if the image was actually a wolf? How would the neural network know?

Now let's say that we have a sequence of images, say on a nature TV show, and the previous images were a wolf, a bear, and a fox.

What we do is analyze each image with a copy of the same neural network, but we use the output of the network for one image as part of the input for the next one.

This can improve our results.

To do this mathematically, we combine the two vectors (the previous output and the current input) in a linear function and then pass the result through an activation function, which could be either sigmoid or hyperbolic tangent (tanh).
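The step just described can be sketched in NumPy as follows. This is a minimal illustration rather than a definitive implementation: the weight names `W_x`, `W_h`, the bias `b`, the layer sizes, and the random "image feature" vectors are all assumptions made for the example, and tanh is picked as the activation.

```python
import numpy as np

def rnn_step(x, h_prev, W_x, W_h, b):
    """One recurrent step: linear combination of the input and the
    previous output, passed through a tanh activation."""
    return np.tanh(W_x @ x + W_h @ h_prev + b)

rng = np.random.default_rng(0)
input_size, hidden_size = 4, 3                 # illustrative sizes (assumed)
W_x = rng.normal(size=(hidden_size, input_size))
W_h = rng.normal(size=(hidden_size, hidden_size))
b = np.zeros(hidden_size)

h = np.zeros(hidden_size)                      # nothing remembered yet
sequence = rng.normal(size=(5, input_size))    # stand-in for 5 images' features
for x in sequence:
    # The same weights are reused at every step; only h changes.
    h = rnn_step(x, h, W_x, W_h, b)
```

After the loop, `h` depends on every element of the sequence, which is how earlier images can influence the prediction for the current one.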

By doing it this way, the final neural network will know that the TV show is about wild animals that live in forests, and can use this information to predict the image is a wolf as opposed to a dog.

This is essentially how recurrent neural networks work.

The problem with recurrent neural networks is that the memory of the network is generally short-term memory.

For example, if in between the bear and the wolf we had a flower and a tree (which could be thought of as either domestic or wild), the network would have a hard time remembering the significance of the bear and the fox.

In short, RNNs have a hard time storing long-term memory.

This is where long short-term memory (LSTM) networks save the day.

To recap, this is how an RNN works:

  • Memory comes in and merges with a current event
  • The output comes out as a prediction of what the input is
  • The output is also part of the input of the next iteration of the neural network

In a similar way, an LSTM works as follows:

  • It keeps track not just of short term memory, but also of long term memory
  • In every step of the sequence, the long and short term memory in the step get merged
  • From this, we get a new long term memory, short term memory, and prediction

By doing it this way, the network can remember things from a long time ago.

Now let's look at the architecture of LSTMs.

Basics of LSTM Networks

In our example of the nature TV show, we have:

  • Long term memory about nature and forest animals
  • Short term memory about flowers and trees
  • An event - the image that could be a dog or wolf

We want to combine these things to create a prediction about what our image is.

The long term memory gives us a hint that the image should favor a wolf over a dog.

We also want all three variables to help us update the long term and short term memory of the network.

To define this, the architecture of the LSTM contains several gates:

  • A forget gate
  • A learn gate
  • A remember gate
  • And a use gate

Here's how they work together:

  • The long term memory goes to the forget gate, where it forgets everything that's not useful
  • The short term memory and the event are combined in the learn gate
  • The long term memory we haven't forgotten plus the new information we learned get joined in the remember gate - which outputs an updated long term memory
  • The use gate decides what information to use from what we already knew plus what we just learned to make a prediction
  • The output becomes both the prediction and the new short term memory

Now let's dive a bit deeper into the different gates.

The Forget Gate

  • Takes the long term memory and decides what part to keep and what to forget

How does this work mathematically?

  • The long-term memory (LTM) from time t-1 is multiplied by a forget factor f_t.
  • The forget factor is calculated from the short-term memory (STM) from time t-1 and the event information E_t.
  • To calculate the forget factor, we run a small one-layer neural network: a linear function followed by the sigmoid function.
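The three steps above can be sketched in NumPy. The weight matrix `W_f`, the bias `b_f`, and the vector sizes are hypothetical names and values chosen for illustration. The short-term memory and the event are assumed to be joined by concatenation before the linear function.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forget_gate(ltm_prev, stm_prev, event, W_f, b_f):
    """Scale each long-term memory entry by a forget factor in (0, 1)."""
    joined = np.concatenate([stm_prev, event])   # STM and event side by side
    f_t = sigmoid(W_f @ joined + b_f)            # one-layer network + sigmoid
    return ltm_prev * f_t                        # element-wise multiplication

ltm_prev = np.array([1.0, -2.0, 0.5])   # long-term memory from time t-1
stm_prev = np.zeros(3)                  # short-term memory from time t-1
event = np.ones(2)                      # current event E_t
W_f = np.zeros((3, 5))                  # untrained (all-zero) weights
b_f = np.zeros(3)

kept = forget_gate(ltm_prev, stm_prev, event, W_f, b_f)
```

With all-zero weights every forget factor is sigmoid(0) = 0.5, so `kept` is exactly half of `ltm_prev`. Training would push individual factors towards 0 (forget) or 1 (keep).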

The Remember Gate

  • The remember gate takes the outputs of the forget gate and the learn gate and adds them together to produce the updated long-term memory.
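As a rough sketch, the addition looks like this in NumPy. The article does not spell out the learn gate's internals, so the `learn_gate` function below is an assumption: it uses one common formulation, a tanh "candidate" scaled by a sigmoid "ignore factor". The names `W_n`, `b_n`, `W_i`, and `b_i` are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def learn_gate(stm_prev, event, W_n, b_n, W_i, b_i):
    """Assumed form: candidate information scaled by an ignore factor."""
    joined = np.concatenate([stm_prev, event])
    n_t = np.tanh(W_n @ joined + b_n)   # candidate new information
    i_t = sigmoid(W_i @ joined + b_i)   # how much of the candidate to keep
    return n_t * i_t

def remember_gate(ltm_kept, learned):
    """New long-term memory: what survived forgetting plus what was learned."""
    return ltm_kept + learned

stm_prev, event = np.zeros(3), np.ones(2)
learned = learn_gate(stm_prev, event,
                     W_n=np.zeros((3, 5)), b_n=np.ones(3),
                     W_i=np.zeros((3, 5)), b_i=np.zeros(3))
ltm_kept = np.array([0.5, -1.0, 0.25])   # stand-in for the forget gate's output
ltm_new = remember_gate(ltm_kept, learned)
```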

The Use Gate

  • Also called the output gate, this uses the long-term memory that came out of the forget gate, together with the short-term memory and the event, to come up with a new short-term memory and an output (these are the same thing).

Here's how we do this mathematically:

  • It applies a small neural network using the tanh activation function on the long-term memory
  • It applies another small neural network on short-term memory and the events using the sigmoid activation function
  • As a final step, it multiplies these two together to get the new output
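A minimal NumPy sketch of these three steps follows. The parameter names (`W_u`, `b_u`, `W_v`, `b_v`) and the sizes are assumptions made for illustration, and each "small neural network" is taken to be a single linear layer.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def use_gate(ltm_kept, stm_prev, event, W_u, b_u, W_v, b_v):
    """Return the new short-term memory, which is also the output."""
    u_t = np.tanh(W_u @ ltm_kept + b_u)          # tanh network on long-term memory
    joined = np.concatenate([stm_prev, event])
    v_t = sigmoid(W_v @ joined + b_v)            # sigmoid network on STM and event
    return u_t * v_t                             # element-wise product

ltm_kept = np.array([0.5, -1.0, 0.25])   # stand-in for the forget gate's output
stm_prev, event = np.zeros(3), np.ones(2)
stm_new = use_gate(ltm_kept, stm_prev, event,
                   W_u=np.eye(3), b_u=np.zeros(3),
                   W_v=np.zeros((3, 5)), b_v=np.zeros(3))
```

Because tanh outputs lie in [-1, 1] and sigmoid outputs lie in (0, 1), every entry of the new short-term memory stays between -1 and 1.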

For a more in-depth explanation of the mathematics of LSTM networks, I recommend this great video from Siraj Raval.

Character-wise RNNs

Character-wise RNNs are networks that learn text one character at a time, and generate new text one character at a time.

As an example, we may want to generate new news article headlines, as this paper from Stanford discusses.
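Before a network can learn text one character at a time, the characters have to be turned into numbers. The sketch below shows one common way to do that, integer ids followed by one-hot vectors. The sample string and all variable names are made up for the example, and pairing each character with the next one as its target is an assumption about how such a network would typically be trained.

```python
import numpy as np

text = "hello world"                        # toy corpus (assumed)
chars = sorted(set(text))                   # vocabulary of unique characters
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = dict(enumerate(chars))

encoded = np.array([char_to_int[c] for c in text])

def one_hot(ids, n_labels):
    """Turn integer ids into rows with a single 1 in the id's column."""
    out = np.zeros((len(ids), n_labels), dtype=np.float32)
    out[np.arange(len(ids)), ids] = 1.0
    return out

inputs = one_hot(encoded[:-1], len(chars))  # each character ...
targets = encoded[1:]                       # ... is paired with the next one
```

Generating new text runs the same mapping in reverse: the network's predicted id at each step is looked up in `int_to_char` and fed back in as the next input.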

Sequence Batching

One of the hardest parts of building recurrent neural networks can be getting the batches right.

Let's walk through how batching works for RNNs:

  • With RNNs we're training on sequences of data, such as text, audio, or stock prices
  • By splitting the sequences into smaller sequences we can use matrix operations to improve the efficiency of training
  • For example, if we have the sequence of numbers from 0 to 10, we can either pass it in as one sequence or split it into 2 sequences, i.e. [0-4] and [5-10]
  • The batch size corresponds to the number of sequences, so here the batch size would be 2

We also choose the length of the sequences we feed to the network at each step. For example, with a sequence length of 3, the first batch would contain the first 3 numbers of each sequence.

We can retain the hidden state from one batch and use it for the next one, so that sequence information is carried across batches for each mini-sequence.
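Here is a small sketch of this batching scheme in NumPy. The helper name `get_batches` is hypothetical. One assumption differs slightly from the example above: so that the rows form a rectangular matrix, the sketch trims the data to equal-length rows, which drops the trailing 10 and gives rows [0-4] and [5-9].

```python
import numpy as np

def get_batches(arr, batch_size, seq_length):
    """Yield windows of shape (batch_size, <= seq_length) from a 1-D array."""
    n_per_row = len(arr) // batch_size
    # Drop leftover elements so the data fits batch_size equal-length rows.
    rows = arr[: batch_size * n_per_row].reshape(batch_size, n_per_row)
    for start in range(0, n_per_row, seq_length):
        yield rows[:, start:start + seq_length]

data = np.arange(11)                      # the numbers 0 to 10
batches = list(get_batches(data, batch_size=2, seq_length=3))
# batches[0] is [[0, 1, 2], [5, 6, 7]]
# batches[1] is [[3, 4], [8, 9]]
```

Row i of each batch continues row i of the previous batch, which is why the hidden state for that row can be carried over from one batch to the next.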

Summary: Recurrent Neural Networks & LSTMs

  • RNNs are designed to learn from sequences of data by passing the hidden state from one step in the sequence to the next step and combining this with the input.
  • LSTM networks are a type of RNN that use special units as well as standard units.
  • LSTM units include a "memory cell" that can keep information in memory for long periods of time.
  • LSTMs are particularly useful when our neural network needs to switch between remembering recent features, and features from a long time ago.
