
Understanding RNNs

Apr 23, 2025

Recurrent neural networks (RNNs) are neural networks which apply recursion in their structure.

Recursion refers to a function which calls itself. The classic example is the Fibonacci sequence: fib(n) = fib(n-1) + fib(n-2), where fib(0) = 0 and fib(1) = 1.

Take fib(3) and let us unroll the recursion stack:

fib(3)
= (fib(2) + fib(1))
= ((fib(1) + fib(0)) + fib(1))
= ((1 + 0) + 1)
= 2

The definition of the RNN function is the following:

  • Output: y = h @ Why + by
  • Function: h = tanh(xt @ Wxh + h @ Whh + bh)
    • Note the recurrence is in h.
  • Weights
    • Wxh: hidden weights for x
    • Whh: hidden weights for h
    • Why: weights for y
  • Biases: bh and by

Expressing this in imperative code, it looks something like:

from torch import Tensor, zeros, tanh

def rnn(x: Tensor) -> Tensor:
    # Wxh, Whh, Why, bh, by and hidden_size are assumed defined elsewhere
    h = zeros(hidden_size)
    for xt in x:
        h = tanh((xt @ Wxh) + (h @ Whh) + bh)
    y = h @ Why + by
    return y
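
To actually run this sketch, the weights have to exist somewhere. A minimal, hypothetical setup (the shapes assume a scalar input and a small hidden state; none of these values come from a real model):

import torch
from torch import Tensor, zeros, tanh

hidden_size = 4

# hypothetical weights, just to make the sketch executable
Wxh = torch.randn(1, hidden_size) * 0.01            # input -> hidden
Whh = torch.randn(hidden_size, hidden_size) * 0.01  # hidden -> hidden
Why = torch.randn(hidden_size, 1) * 0.01            # hidden -> output
bh = torch.zeros(hidden_size)
by = torch.zeros(1)

x = torch.randn(5, 1)  # a sequence of five scalar inputs
print(rnn(x))          # a single output of shape (1,)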

The recurrence occurs in h - this is akin to calling the function recursively:

def fib(n):
    if n == 0 or n == 1:
        return n
    return fib(n-1) + fib(n-2)

# akin to:

def rnn(x):
    if len(x) == 0:
        # base case: the initial hidden state h_0 = 0
        return zeros(hidden_size)
    # pop the last input; the recursive call handles the prefix x[:-1]
    xt = x.pop()
    return tanh((xt @ Wxh) + (rnn(x) @ Whh) + bh)

def rnn_out(x):
    return rnn(x) @ Why + by
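
A quick way to convince yourself the two forms agree (a sketch: rnn_iterative is a hypothetical rename of the loop version above, since both functions are named rnn here; the input is passed as a list because the recursive version consumes it via pop()):

xs = torch.randn(5, 1)

# Both walk the sequence in the same order: the recursion bottoms out at
# h_0 = 0 and applies the cell to x[0], x[1], ... on the way back up.
assert torch.allclose(rnn_iterative(xs), rnn_out(list(xs)))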

Unrolling an RNN, the recursion begins to become clearer:

# The unrolled recurrence of a vanilla RNN cell.
# For a sequence x = [x0..xt].
#
h = f0(x) = torch.tanh(x0 @ Wxh + h0        @ Whh + bh)
h = f1(x) = torch.tanh(x1 @ Wxh + f0(x)     @ Whh + bh)
h = f2(x) = torch.tanh(x2 @ Wxh + f1(x)     @ Whh + bh)
...
h = ft(x) = torch.tanh(xt @ Wxh + f(t-1)(x) @ Whh + bh)
y = h @ Why + by

Distilling this:

  • h is the output of the hidden layer of the RNN.
  • y is the final output, weighted by the Why weights and the by bias.
  • Wxh weights the current input xt.
  • Whh weights the previous hidden state, h.
  • The definition for h translates as (see the worked example below):
    • ft(x) = tanh((x[t] * input_weights) + (previous_invocation * hidden_weights) + bias)
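
To make the translation concrete, here is the recurrence stepped by hand for a length-2 sequence, using hypothetical scalar weights:

import torch

# hypothetical scalar weights, for illustration only
Wxh, Whh, bh = 0.5, 0.8, 0.0
x = [1.0, 2.0]

h0 = torch.tensor(0.0)                       # initial hidden state
h1 = torch.tanh(x[0] * Wxh + h0 * Whh + bh)  # f0(x)
h2 = torch.tanh(x[1] * Wxh + h1 * Whh + bh)  # f1(x), built on f0(x)
print(h1.item(), h2.item())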

Appendix.

RNN in PyTorch

import torch
import torch.nn.functional as F
import torch.nn as nn

# Config
input_size = 1
hidden_size = 16
seq_len = 3
lr = 0.01
epochs = 200

# Data
prices = [100, 102, 105, 107, 110, 112, 115]
def create_sequences(prices, seq_len):
    xs, ys = [], []
    for i in range(len(prices) - seq_len):
        xs.append(prices[i:i+seq_len])
        ys.append(prices[i+seq_len])
    return torch.tensor(xs).float().unsqueeze(-1), torch.tensor(ys).float().unsqueeze(-1)

x, y = create_sequences(prices, seq_len)
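# For these prices and seq_len = 3, this yields 4 (input, target) pairs, e.g.
#   [100, 102, 105] -> 107 and [102, 105, 107] -> 110.
# x has shape (4, seq_len, 1) and y has shape (4, 1).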

# Parameters
# Note: requires_grad_() is applied after the 0.01 scaling so that each
# weight is a leaf tensor and actually accumulates a .grad. With
# torch.randn(..., requires_grad=True) * 0.01, the scaled result is a
# non-leaf tensor and the manual SGD step below would find no gradient.
Wxh = (torch.randn(input_size, hidden_size) * 0.01).requires_grad_()
Whh = (torch.randn(hidden_size, hidden_size) * 0.01).requires_grad_()
bh  = torch.zeros(hidden_size, requires_grad=True)

Why = (torch.randn(hidden_size, 1) * 0.01).requires_grad_()
by  = torch.zeros(1, requires_grad=True)

# RNN forward
def rnn_forward(x_seq):
    h = torch.zeros(hidden_size)
    for t in range(x_seq.size(0)):
        # get the stock price at time t
        xt = x_seq[t]
        # update the hidden state from the current price and the previous h
        h = torch.tanh(xt @ Wxh + h @ Whh + bh)
    y_pred = h @ Why + by
    return y_pred
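
# Sanity check (untrained): each x[i] is (seq_len, 1); the prediction is (1,)
print(rnn_forward(x[0]).shape)  # torch.Size([1])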

class CustomRNNLayer(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # The hidden state:
        # Wxh - weights the current input xt
        # Whh - weights the previous hidden state at t-1
        # bh  - bias for the hidden state
        # Why, by - readout weights and bias for the output
        self.Wxh = nn.Parameter(torch.randn(input_size, hidden_size) * 0.01)
        self.Whh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.bh = nn.Parameter(torch.zeros(hidden_size))
        self.Why = nn.Parameter(torch.randn(hidden_size, 1) * 0.01)
        self.by = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h = torch.zeros(batch_size, self.Whh.size(0), device=x.device)

        # The loop recursively applies f(h, xt) = tanh(xt @ Wxh + h @ Whh + bh)
        # to update the hidden state h at each time step t
        for t in range(seq_len):
            # Get the stock price at t
            xt = x[:, t, :]

            h = torch.tanh(xt @ self.Wxh + h @ self.Whh + self.bh)
        # The final output y is the readout g(h) = h @ Why + by
        y = h @ self.Why + self.by
        return y
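
# Sanity check: run the (untrained) module on the batched sequences above.
layer = CustomRNNLayer(input_size, hidden_size)
print(layer(x).shape)  # torch.Size([4, 1]): one prediction per sequence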

# Training loop
for epoch in range(epochs):
    total_loss = 0.0
    for i in range(len(x)):
        y_pred = rnn_forward(x[i])
        loss = F.mse_loss(y_pred, y[i])
        loss.backward()
        total_loss += loss.item()

        # Manual SGD: step each parameter, then reset its gradient
        for param in [Wxh, Whh, bh, Why, by]:
            param.data -= lr * param.grad
            param.grad.zero_()
    if epoch % 50 == 0:
        print(f"epoch {epoch}: loss {total_loss:.4f}")

# Predict
with torch.no_grad():
    test_seq = torch.tensor(prices[-seq_len:]).float().unsqueeze(-1)
    pred = rnn_forward(test_seq)
    print(pred.item())
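
For comparison, the same model can be written with PyTorch's built-in nn.RNN, which implements this exact tanh recurrence. A minimal sketch (not part of the original listing; nn.Linear plays the role of Why and by):

class BuiltinRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # nn.RNN computes h_t = tanh(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh)
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)  # the Why/by readout

    def forward(self, x):
        out, h = self.rnn(x)             # out: (batch, seq_len, hidden_size)
        return self.head(out[:, -1, :])  # read out from the final hidden state

model = BuiltinRNN(input_size, hidden_size)
print(model(x).shape)  # torch.Size([4, 1])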