LSTM Back-Propagation — the Math Behind the Scenes
LSTM (Long Short-Term Memory) is an RNN-based network architecture mainly used for sequence analysis in the domain of Deep Learning. It can be used to build language models, analyse the sentiment of a sentence, and learn patterns in text, audio clips and other sequences.
This post assumes that you have a basic idea of how an LSTM works, and is aimed at those curious about the actual math behind the back-propagation used in LSTMs. Nowadays, libraries like TensorFlow and PyTorch have made things convenient and simple: you just design the forward propagation, sit back and admire the model training itself, without ever needing to get your hands dirty and define the back-propagation algorithm.
The notation used here might differ from the one used in standard academic literature, but it follows the one used in Andrew Ng’s course and his assignment. So this should be a “Eureka moment,” especially for those of you who went on to decipher the calculus behind those magical equations in the assignment but just couldn’t get your derivation to look exactly the way it was written there.
In all the equations, x denotes matrix multiplication and * denotes element-wise (Hadamard) multiplication.
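To make the distinction concrete, here is a minimal NumPy sketch (the arrays are made-up toy values, not from the assignment) contrasting the two operations:

```python
import numpy as np

W = np.array([[1., 2.],
              [3., 4.]])          # a toy 2x2 weight matrix
h = np.array([[1.],
              [1.]])              # a toy column vector
g = np.array([[0.5],
              [2.0]])             # e.g. gate activations

matmul = W @ h    # "x" in the equations: matrix multiplication -> [[3.], [7.]]
elemwise = g * h  # "*" in the equations: element-wise product  -> [[0.5], [2.]]
```

In the LSTM equations, the gates always combine with the cell state element-wise, while the weight matrices always multiply their inputs in the matrix sense.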
Nonetheless, as a precursor, I’ll start by showing the diagram of an LSTM unit similar to the one mentioned in the assignment.
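For reference, the forward-propagation equations for a single LSTM cell — written here in the course’s notation as I remember it, with Γ denoting gates and ⟨t⟩ the timestep — are:

```latex
\Gamma_f^{\langle t \rangle} &= \sigma\!\left(W_f\,[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_f\right) && \text{(forget gate)} \\
\Gamma_u^{\langle t \rangle} &= \sigma\!\left(W_u\,[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u\right) && \text{(update gate)} \\
\tilde{c}^{\langle t \rangle} &= \tanh\!\left(W_c\,[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_c\right) && \text{(candidate cell state)} \\
c^{\langle t \rangle} &= \Gamma_u^{\langle t \rangle} * \tilde{c}^{\langle t \rangle} + \Gamma_f^{\langle t \rangle} * c^{\langle t-1 \rangle} \\
\Gamma_o^{\langle t \rangle} &= \sigma\!\left(W_o\,[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_o\right) && \text{(output gate)} \\
a^{\langle t \rangle} &= \Gamma_o^{\langle t \rangle} * \tanh\!\left(c^{\langle t \rangle}\right)
```

Back-propagation works backwards through exactly these equations, which is why having them in front of you helps when following the derivation below.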
Derivation of Back-Propagation Equations
If you have taken the Sequence Models Course by DeepLearning.ai, you will realise that these equations are exactly the same as mentioned in the assignment notebook for LSTM.
Note that, in reality, there is still just one LSTM cell, applied repeatedly across the timesteps of a sequence. So, at each timestep of back-propagation, the gradients of the shared weight and bias parameters are accumulated (added), and only once back-propagation through all the timesteps is complete are the parameters themselves updated.
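The accumulate-then-update pattern can be sketched as follows. This is a hedged illustration, not the assignment’s actual code: the shapes `n_a`, `n_x`, `T` and the per-step gradients are made-up placeholders standing in for what the cell-level backward pass would return.

```python
import numpy as np

n_a, n_x, T = 4, 3, 5                  # hypothetical hidden size, input size, timesteps
rng = np.random.default_rng(0)

dWf = np.zeros((n_a, n_a + n_x))       # gradient accumulator for forget-gate weights
dbf = np.zeros((n_a, 1))               # gradient accumulator for forget-gate bias

for t in reversed(range(T)):           # back-propagation through time
    # In the real algorithm these come from the cell backward pass at step t;
    # random placeholders are used here purely for illustration.
    dWf_t = rng.standard_normal((n_a, n_a + n_x))
    dbf_t = rng.standard_normal((n_a, 1))
    dWf += dWf_t                       # accumulate — do NOT overwrite
    dbf += dbf_t

# Only after all timesteps are processed is the shared parameter updated once:
learning_rate = 0.01
Wf = rng.standard_normal((n_a, n_a + n_x))
Wf -= learning_rate * dWf
```

The key point is that the `+=` accumulation reflects the fact that the same `Wf` is used at every timestep, so its total gradient is the sum of its per-timestep gradients.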
The diagram is mainly inspired by the course, but I derived all the equations myself in my Jupyter Notebook. I know some of these images are not very clear; I had to use images instead of text because Medium articles do not support LaTeX expressions. I have tried to cover all the steps of the derivation with snippets captured from my Notebook.