Understanding RNNs
Apr 23, 2025
Recurrent neural networks (RNNs) are neural networks which apply recursion, in the form of a recurrence, in their structure.

Recursion refers to a function which calls itself. The classic example is the Fibonacci sequence: `fib(n) = fib(n-1) + fib(n-2)`, where `fib(0) = 0` and `fib(1) = 1`.
Take `fib(3)` and let us unroll the recursion stack:

```
fib(3)
= fib(2) + fib(1)
= (fib(1) + fib(0)) + fib(1)
= (1 + 0) + 1
= 2
```
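To make this unrolling visible in running code, here is a minimal sketch that traces the calls; the `depth` argument and the prints are purely illustrative and not part of the definition above:

```python
def fib(n, depth=0):
    # Indent each call by its depth so the recursion stack is visible.
    print("  " * depth + f"fib({n})")
    if n == 0 or n == 1:
        return n
    return fib(n - 1, depth + 1) + fib(n - 2, depth + 1)

fib(3)
# Prints:
# fib(3)
#   fib(2)
#     fib(1)
#     fib(0)
#   fib(1)
# and returns 2.
```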
The definition of the RNN function is as follows:

- Output: `y = h * Why + by`
- Hidden state: `h = tanh(xt * Wxh + h * Whh + bh)`
  - Note the recurrence is in `h`.
- Weights:
  - `Wxh`: hidden weights for `x`
  - `Whh`: hidden weights for `h`
  - `Why`: weights for `y`
- Biases: `bh` and `by`
Expressing this in imperative code, it looks something like:
```python
from torch import Tensor, zeros, tanh

def rnn(x: Tensor) -> Tensor:
    h = zeros(hidden_size)
    for xt in x:
        h = tanh((xt @ Wxh) + (h @ Whh) + bh)
    y = h @ Why + by
    return y
```
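The sketch above assumes `hidden_size` and the weight and bias tensors already exist as globals. A hedged usage example, with sizes chosen arbitrarily here (they are not prescribed by the definition):

```python
import torch

# Assumed globals for the rnn() sketch above; the sizes are arbitrary.
input_size, hidden_size = 1, 16
Wxh = torch.randn(input_size, hidden_size) * 0.01   # input -> hidden
Whh = torch.randn(hidden_size, hidden_size) * 0.01  # hidden -> hidden
Why = torch.randn(hidden_size, 1) * 0.01            # hidden -> output
bh = torch.zeros(hidden_size)
by = torch.zeros(1)

x = torch.randn(5, input_size)  # a toy sequence of 5 timesteps
y = rnn(x)                      # a single output of shape (1,)
```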
The recurrence occurs in `h` - this is akin to calling the function recursively:
```python
def fib(n):
    if n == 0 or n == 1:
        return n
    return fib(n - 1) + fib(n - 2)

# akin to:
def rnn(x):
    if len(x) == 0:
        # h_initial = 0
        return zeros(hidden_size)
    xt = x.pop()
    return tanh((xt @ Wxh) + (rnn(x) @ Whh) + bh)

def rnn_out(x):
    return rnn(x) @ Why + by
```
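To check that claim, the two formulations can be run side by side. This is a small sketch that assumes the globals from the earlier usage example; `rnn_loop` and `rnn_recursive` are names introduced here for the two versions above:

```python
import torch

def rnn_loop(x):
    h = torch.zeros(hidden_size)
    for xt in x:
        h = torch.tanh((xt @ Wxh) + (h @ Whh) + bh)
    return h @ Why + by

def rnn_recursive(x):
    # x is a list of timesteps; the recursion consumes it from the back.
    if len(x) == 0:
        return torch.zeros(hidden_size)  # h_initial = 0
    xt = x.pop()
    return torch.tanh((xt @ Wxh) + (rnn_recursive(x) @ Whh) + bh)

seq = torch.randn(4, input_size)
out_loop = rnn_loop(seq)
out_recursive = rnn_recursive(list(seq)) @ Why + by  # pass a list so pop() doesn't touch seq
assert torch.allclose(out_loop, out_recursive)
```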
Unrolling the RNN, the recursion begins to become clearer:
```
# The unrolled recurrence of a vanilla RNN cell,
# for a sequence x = [x0 .. xt].
h = f0(x) = torch.tanh(x0 @ Wxh + h0 @ Whh + bh)   # h0 is the initial (zero) hidden state
h = f1(x) = torch.tanh(x1 @ Wxh + f0(x) @ Whh + bh)
h = f2(x) = torch.tanh(x2 @ Wxh + f1(x) @ Whh + bh)
...
h = ft(x) = torch.tanh(xt @ Wxh + f(t-1)(x) @ Whh + bh)
y = h @ Why + by
```
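Writing the same composition out explicitly for a length-3 sequence (again using the assumed globals from the earlier sketch) makes the nesting concrete:

```python
import torch

x0, x1, x2 = torch.randn(3, input_size)  # a toy length-3 sequence
h0 = torch.zeros(hidden_size)            # initial hidden state

# Step by step.
f0 = torch.tanh(x0 @ Wxh + h0 @ Whh + bh)
f1 = torch.tanh(x1 @ Wxh + f0 @ Whh + bh)
f2 = torch.tanh(x2 @ Wxh + f1 @ Whh + bh)
y = f2 @ Why + by

# The same thing as one nested expression: each step feeds the previous one.
y_nested = torch.tanh(
    x2 @ Wxh
    + torch.tanh(x1 @ Wxh + torch.tanh(x0 @ Wxh + h0 @ Whh + bh) @ Whh + bh) @ Whh
    + bh
) @ Why + by

assert torch.allclose(y, y_nested)
```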
Distilling this:
- `h` is the output of the hidden layer of the RNN.
- `y` is the final output, weighted by the `Why` weights and `by` bias.
- `Wxh` weights the current input $x_{t}$.
- `Whh` weights the previous output, `h`.
- The definition for `h` translates as `ft(x) = tanh((x[t] * input_weight) + (previous_invocation * previous_weights_learnt) + bias)` - see the step-function sketch below.
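One way to make that last line concrete is to pull the per-step update out into its own function. `rnn_step` is a name introduced for this sketch, using the same assumed globals as before:

```python
import torch

def rnn_step(xt, h_prev):
    # f_t: weight the current input, weight the previous hidden state, add the bias, squash.
    return torch.tanh(xt @ Wxh + h_prev @ Whh + bh)

h = torch.zeros(hidden_size)
for xt in torch.randn(5, input_size):  # a toy sequence
    h = rnn_step(xt, h)
y = h @ Why + by
```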
Appendix: RNN in PyTorch
```python
import torch
import torch.nn.functional as F
import torch.nn as nn

# Config
input_size = 1
hidden_size = 16
seq_len = 3
lr = 0.01
epochs = 200

# Data
prices = [100, 102, 105, 107, 110, 112, 115]

def create_sequences(prices, seq_len):
    # Slide a window of length seq_len over the prices; the next price is the target.
    xs, ys = [], []
    for i in range(len(prices) - seq_len):
        xs.append(prices[i:i+seq_len])
        ys.append(prices[i+seq_len])
    return torch.tensor(xs).float().unsqueeze(-1), torch.tensor(ys).float().unsqueeze(-1)

x, y = create_sequences(prices, seq_len)

# Parameters
# Scale first, then mark as requiring gradients, so each tensor is a leaf
# and accumulates .grad during backward().
Wxh = (torch.randn(input_size, hidden_size) * 0.01).requires_grad_()
Whh = (torch.randn(hidden_size, hidden_size) * 0.01).requires_grad_()
bh = torch.zeros(hidden_size, requires_grad=True)
Why = (torch.randn(hidden_size, 1) * 0.01).requires_grad_()
by = torch.zeros(1, requires_grad=True)

# RNN forward
def rnn_forward(x_seq):
    h = torch.zeros(hidden_size)
    for t in range(x_seq.size(0)):
        # Get the stock price at time t
        xt = x_seq[t]
        # New hidden state = tanh(weighted input + weighted previous hidden state + bias)
        h = torch.tanh(xt @ Wxh + h @ Whh + bh)
    y_pred = h @ Why + by
    return y_pred

class CustomRNNLayer(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # The hidden state:
        #   Wxh - weights the current input xt
        #   Whh - weights the previous hidden state (t-1)
        #   bh  - bias for the hidden state
        # The output:
        #   Why - weights the final hidden state
        #   by  - bias for the output
        self.Wxh = nn.Parameter(torch.randn(input_size, hidden_size) * 0.01)
        self.Whh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.bh = nn.Parameter(torch.zeros(hidden_size))
        self.Why = nn.Parameter(torch.randn(hidden_size, 1) * 0.01)
        self.by = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h = torch.zeros(batch_size, self.Whh.size(0), device=x.device)
        # The loop recursively applies f(h, xt) = tanh(xt @ Wxh + h @ Whh + bh)
        # to update the hidden state h at each time step t.
        for t in range(seq_len):
            # Get the stock price at t
            xt = x[:, t, :]
            h = torch.tanh(xt @ self.Wxh + h @ self.Whh + self.bh)
        # The final output y is obtained by applying g(h) = h @ Why + by
        y = h @ self.Why + self.by
        return y

# Training loop
for epoch in range(epochs):
    total_loss = 0
    for i in range(len(x)):
        y_pred = rnn_forward(x[i])
        loss = F.mse_loss(y_pred, y[i])
        loss.backward()
        total_loss += loss.item()
        # Manual SGD
        for param in [Wxh, Whh, bh, Why, by]:
            param.data -= lr * param.grad
            param.grad.zero_()
    if epoch % 50 == 0:
        print(f"epoch {epoch}: loss {total_loss:.4f}")

# Predict
with torch.no_grad():
    test_seq = torch.tensor(prices[-seq_len:]).float().unsqueeze(-1)
    pred = rnn_forward(test_seq)
    print(pred.item())
```
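For comparison, the same model can be expressed with PyTorch's built-in `nn.RNN` and `nn.Linear`. This is a sketch reusing the config, `x`, and `y` from the script above; it is not a drop-in replacement for the manual parameters, since `nn.RNN` keeps its own weights:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BuiltinRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # nn.RNN implements h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh),
        # the same recurrence as rnn_forward above (up to weight layout).
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, 1)  # plays the role of Why and by

    def forward(self, x):
        _, h_last = self.rnn(x)      # h_last: (num_layers, batch, hidden_size)
        return self.out(h_last[-1])  # (batch, 1)

model = BuiltinRNN(input_size, hidden_size)
opt = torch.optim.SGD(model.parameters(), lr=lr)

for epoch in range(epochs):
    for i in range(len(x)):
        opt.zero_grad()
        loss = F.mse_loss(model(x[i:i+1]), y[i:i+1])
        loss.backward()
        opt.step()

with torch.no_grad():
    test_seq = torch.tensor(prices[-seq_len:]).float().view(1, seq_len, 1)
    print(model(test_seq).item())
```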