Test-time training and RNNs

Apr 23, 2025

Background reading: Understanding RNNs

Comparing RNNs and attention.

RNNs and attention are both sequence modelling techniques.

Where N is the length of the sequence, RNNs cost O(N) and attention costs O(N²), since it multiplies Q by K^T to get one score per pair of tokens and uses those scores to look up into V (Q, K and V are all projections of x).
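To make the quadratic term concrete, here is a minimal pure-Python sketch of single-head softmax attention. It assumes Q, K and V are already given as lists of vectors (the projections of x are left out), and the function names are made up for illustration. The list comprehension inside the loop builds one score per (query, key) pair, which is where the N × N cost comes from.

```python
import math

def attention(Q, K, V):
    """Single-head softmax attention, no masking. Q, K, V: lists of N vectors."""
    d = len(Q[0])
    out = []
    for q in Q:  # N queries ...
        # ... each scored against all N keys: N * N dot products in total.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]  # numerically stable softmax
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted lookup into V.
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out

def num_scores(n):
    """Number of query-key scores for a sequence of length n."""
    return n * n

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Y = attention(X, X, X)  # using x itself as Q, K and V for brevity
```

An RNN over the same sequence would instead run N state updates, one per token.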

Through a compression lens, RNNs compress data differently to attention(?). Whereas attention keeps every past token around and can “shift back and forth” over them, an RNN only sees the current input and the state produced by the previous invocation.

An RNN layer’s ability to remember long context is limited by the amount of information its hidden state can store.


Test-time training.

Test-time training (paper, video paper) is a new technique, which can be explained as:

  • a formulation of an RNN where the hidden state h (the thing being recursed on) is the weights W of a neural network
  • a new type of sequence modelling layer:
    • a neural network inside a neural network
    • the layer’s forward function trains the inner network, and has its own learned learning rate and weight initialisation.
    • the layer’s output is one inference of this inner network.
    • the outer network’s training still backpropagates the loss as usual, which means differentiating through the TTT layer’s forward function, which is itself a learning step. Hence they refer to “gradients of gradients”.

With the definitions of an RNN:

  • initial state: s_0 = vector()
  • update rule: s_t = g(Whh * s_{t-1} + Whx * x_t)
  • output rule: z_t = Why * s_t + Wyx * x_t
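The three rules above can be sketched in a few lines of pure Python. This is only an illustration: it assumes g = tanh, a zero vector for the initial state, and that Wyx is a weight applied to the current input x_t in the output rule. The helper names are made up.

```python
import math

def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def rnn_forward(xs, W_hh, W_hx, W_hy, W_yx):
    """Run the RNN over xs, returning all outputs z_t and the final state."""
    s = [0.0] * len(W_hh)  # initial state s_0
    zs = []
    for x in xs:
        pre = [a + b for a, b in zip(matvec(W_hh, s), matvec(W_hx, x))]
        s = [math.tanh(p) for p in pre]  # update rule with g = tanh (assumed)
        z = [a + b for a, b in zip(matvec(W_hy, s), matvec(W_yx, x))]  # output rule
        zs.append(z)
    return zs, s

# One-dimensional toy example.
zs, s = rnn_forward([[0.5], [0.5]], W_hh=[[0.0]], W_hx=[[1.0]], W_hy=[[1.0]], W_yx=[[0.0]])
```

Note that s keeps the same size however long xs is, which is the fixed-size hidden state the notes refer to.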

For a TTT layer:

  • initial state: W_0 = f.params()
    • the initial state is the neural network’s weights
  • update rule: W_t = W_{t−1} − η ∇ℓ(W_{t−1}; x_t)
    • update the current weights to the previous weights minus the learning rate times the gradient
  • output rule: z_t = f(x_t; W_t)
    • output is one evaluation/inference of the network

The TTT layer is an RNN where the hidden units are weights of another neural network, and the output rule is an inference of that neural net.

The TTT layer performs the recurrence by running one training step on the weights.
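As a rough sketch of that recurrence, the loop below keeps the weights W as the hidden state and takes the loss gradient as a pluggable function. Two things are assumptions made only so the example runs: the inner model is linear, f(x; W) = W x, and the loss is a plain reconstruction loss ||f(x; W) − x||². All names are hypothetical.

```python
def f_linear(x, W):
    """Inner model f(x; W) = W x, assumed linear for this sketch."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def grad_reconstruction(W, x):
    """Gradient of the placeholder loss ||f(x; W) - x||^2 with respect to W."""
    residual = [p - xi for p, xi in zip(f_linear(x, W), x)]
    return [[2.0 * r * xj for xj in x] for r in residual]

def ttt_forward(xs, W0, f, grad_loss, eta):
    """Hidden state = weights W; recurrence = one gradient step per token."""
    W = W0  # initial state
    zs = []
    for x in xs:
        G = grad_loss(W, x)
        # Update rule: previous weights minus learning rate times gradient.
        W = [[w - eta * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(W, G)]
        zs.append(f(x, W))  # output rule: one inference with the updated weights
    return zs, W

W0 = [[0.0, 0.0], [0.0, 0.0]]
zs, W = ttt_forward([[1.0, 0.0]] * 3, W0, f_linear, grad_reconstruction, eta=0.25)
```

Feeding the same token three times, the first output component moves 0.5, 0.75, 0.875 towards the input value 1.0: the layer "remembers" the token by fitting it.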

The loss function so far is undefined - what is our learning objective for this TTT layer? It could be supervised, self-supervised, etc.

Basic objective function - reconstruction from noise.

In the TTT for video paper, they mention using an objective function that mirrors a reconstruction task like in denoising autoencoders:

ℓ(W; x_t) = ||f(x̂_t; W) − x_t||²

Translating this:

  • process x_t into a corrupted input x̂_t (e.g. blur)
  • loss = ||f(x̂_t; W) − x_t||²
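A small sketch of this objective, under stated assumptions: the inner model is linear, and the corruption is additive Gaussian noise rather than blur, chosen only because it is easy to write. The function names are illustrative and not taken from the paper.

```python
import random

def f_linear(x, W):
    """Inner model f(x; W) = W x (assumed linear)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def corrupt(x, rng, sigma=0.1):
    """Corrupt x with Gaussian noise (a stand-in for e.g. blur)."""
    return [xi + rng.gauss(0.0, sigma) for xi in x]

def denoising_loss(W, x_hat, x):
    """||f(x_hat; W) - x||^2: predict the clean x from the corrupted x_hat."""
    return sum((p - xi) ** 2 for p, xi in zip(f_linear(x_hat, W), x))

def denoising_grad(W, x_hat, x):
    """Gradient of the loss above with respect to W."""
    residual = [p - xi for p, xi in zip(f_linear(x_hat, W), x)]
    return [[2.0 * r * hj for hj in x_hat] for r in residual]

rng = random.Random(0)
x = [1.0, 2.0]
x_hat = corrupt(x, rng)
W = [[0.0, 0.0], [0.0, 0.0]]
before = denoising_loss(W, x_hat, x)
G = denoising_grad(W, x_hat, x)
W = [[w - 0.05 * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(W, G)]
after = denoising_loss(W, x_hat, x)  # lower than `before` for this small step size
```

The key detail is that the input to f is the corrupted x_hat while the target is the clean x, so the inner model cannot get away with copying its input.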

Objective function - low rank projection.

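Maybe not all of the information in x_t is worth remembering, so we can use low rank projections instead of the corrupted input x̂_t and the reconstruction label x_t. The projections are referred to by their learned matrices θ_K, θ_V and θ_Q: θ_K x_t replaces the corrupted input, θ_V x_t replaces the reconstruction label, and θ_Q x_t is the input to the output rule, so the loss becomes ℓ(W; x_t) = ||f(θ_K x_t; W) − θ_V x_t||².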

As low rank projections change the dimension, the output rule is modified too:

z_t = f(θ_Q x_t; W_t)
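Here is a hedged sketch of one token passing through a TTT layer with projected views. It assumes a linear inner model and the following roles, which are my reading of the setup: θ_K projects the input the inner model is trained on, θ_V projects the reconstruction target, and θ_Q projects the input of the output rule. The projections are fixed here, whereas in the real layer they are learned in the outer loop.

```python
def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def ttt_step(W, x, theta_K, theta_V, theta_Q, eta):
    """One token of a linear TTT layer with projected views (f(u; W) = W u)."""
    k = matvec(theta_K, x)  # training view: input to the inner model
    v = matvec(theta_V, x)  # label view: reconstruction target
    residual = [p - vi for p, vi in zip(matvec(W, k), v)]
    # Gradient step on ||W k - v||^2, whose gradient is 2 * residual * k^T.
    W = [[w - eta * 2.0 * r * kj for w, kj in zip(row, k)] for row, r in zip(W, residual)]
    q = matvec(theta_Q, x)  # test view: input to the output rule
    return W, matvec(W, q)

# Project 2-dimensional tokens down to 1 dimension.
theta_K, theta_V, theta_Q = [[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 0.0]]
W = [[0.0]]
W, z1 = ttt_step(W, [1.0, 2.0], theta_K, theta_V, theta_Q, eta=0.25)
W, z2 = ttt_step(W, [1.0, 2.0], theta_K, theta_V, theta_Q, eta=0.25)
```

The output z has the projected dimension (1 here) rather than the token dimension (2), which is why the output rule has to change along with the loss.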

There are a few more techniques they apply which are listed in the next section.

TTT-MLP.

In TTT-MLP, the inner network f is a multi-layer perceptron.


Techniques.

Learnable W0.

The TTT initialization W0 is shared between all sequences. Empirically, we observe that learning W0 significantly improves training stability.

Learnable η.

The learning rate is usually the most important hyper-parameter for gradient descent, so we experiment with learning the inner-loop learning rate η in Equation 6 as part of the outer loop. We make η a function of the input token (therefore different across time) for additional flexibility.

Concretely, we design η(x) = η_base σ(θ_lr · x), where the learnable vector θ_lr is an outer-loop parameter, σ is the sigmoid function, and the scalar η_base is the base learning rate, set to 1 for TTT-Linear and 0.1 for TTT-MLP.
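The formula is small enough to write out directly. In this sketch θ_lr is a fixed, made-up vector (in the paper it is learned in the outer loop), and the function name is hypothetical.

```python
import math

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

def inner_lr(x, theta_lr, eta_base):
    """Token-dependent inner-loop learning rate eta(x) = eta_base * sigmoid(theta_lr . x)."""
    return eta_base * sigmoid(sum(t * xi for t, xi in zip(theta_lr, x)))

theta_lr = [0.5, -1.0]  # would be learned in the outer loop; fixed here for illustration
rates = [inner_lr(x, theta_lr, eta_base=0.1) for x in ([0.0, 0.0], [4.0, 0.0], [0.0, 4.0])]
```

Because the sigmoid lies strictly between 0 and 1, each token gets a learning rate somewhere between 0 and η_base, so the layer can choose to learn a lot from some tokens and almost nothing from others.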

Inner loop mini batch.

Given a TTT layer is a recurrence relation, where computing h_i requires h_{i−1}, it cannot be naively parallelized across tokens in a sequence.

They take the standard mini-batching you find in neural network training and apply it here.

  1. Normal (fully sequential) update

    • You have weights W.
    • For each new data point x_t, you do
      W ← W − η ∇ℓ(W; x_t).
    • You can’t do two of these at once, because the second gradient needs the updated W.
  2. Write all updates at once

    • If you start from W_0 and do T steps, you get
      W_T = W_0 − η Σ_{s=1}^{T} G_s, where G_s = ∇ℓ(W_{s−1}; x_s).
    • Now instead of “update, compute, update, compute…”, you can “compute all G_s” then subtract their sum. This only parallelizes if the G_s do not need the intermediate weights, which is what the mini-batch trick arranges.
  3. Mini-batch trick

    • Pick a batch size b (e.g. 16).
    • Split your T points into blocks of b.
    • For block 1 (steps 1…b): use W_0 as the base. Compute in parallel
      G_1 = ∇ℓ(W_0; x_1), …, G_b = ∇ℓ(W_0; x_b).
    • Then apply them in order by doing a cumulative sum:
      W_1 = W_0 − η G_1, W_2 = W_1 − η G_2, …, W_b = W_{b−1} − η G_b.
    • For block 2 (steps b+1…2b): use W_b as base, compute those b gradients in parallel, then cumsum, and so on.
  4. Why it helps

    • You do b gradient computations at once instead of one by one.
    • Larger b: more parallel speed, but gradients get a bit “stale” (they all use the same base).
    • Smaller b: closer to true sequential updates but less speedup.
    • They found b = 16 a good balance.
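The mini-batch trick can be sketched as follows. This is an illustration under assumptions, not the paper's implementation: the inner model is linear, the loss is plain reconstruction ||W x − x||², and the "parallel" gradient computation is just a list comprehension whose items do not depend on each other. Setting b = 1 recovers the fully sequential update.

```python
def f_linear(x, W):
    """Inner model f(x; W) = W x (assumed linear)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def grad(W, x):
    """Gradient of ||W x - x||^2 with respect to W."""
    residual = [p - xi for p, xi in zip(f_linear(x, W), x)]
    return [[2.0 * r * xj for xj in x] for r in residual]

def step(W, G, eta):
    """Return W - eta * G without mutating W."""
    return [[w - eta * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(W, G)]

def ttt_minibatch(xs, W0, eta, b):
    """TTT recurrence where each block of b tokens shares one base weight."""
    W, zs = W0, []
    for start in range(0, len(xs), b):
        block = xs[start:start + b]
        base = W  # weights at the end of the previous block
        grads = [grad(base, x) for x in block]  # independent of each other: parallelizable
        for x, G in zip(block, grads):  # cumulative sum of the block's gradients
            W = step(W, G, eta)
            zs.append(f_linear(x, W))
    return zs, W

xs = [[1.0, 0.0], [1.0, 0.0]]
W0 = [[0.0, 0.0], [0.0, 0.0]]
zs_seq, _ = ttt_minibatch(xs, W0, eta=0.25, b=1)  # sequential: outputs 0.5 then 0.75
zs_mb, _ = ttt_minibatch(xs, W0, eta=0.25, b=2)   # one block: outputs 0.5 then 1.0
```

The second output differs between the two runs because with b = 2 the second gradient is still computed at W_0. That is the "stale" gradient effect described above.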

Gating.

Naively inserting TTT layers into a pre-trained network would dramatically worsen its predictions at the beginning of fine-tuning, when the TTT layers are randomly initialized. To avoid this degradation, we gate TTT with a learned vector α ∈ R^d following standard practice:

gate(TTT, X; α) = tanh(α) ⊗ TTT(X) + X, where ⊗ is element-wise multiplication.

We initialize all values in α to 0.1, so the values in tanh(α) are close to 0 (≈ 0.1) at the beginning of fine-tuning.
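A minimal sketch of the gate, treating X as a list of token vectors and α as one value per feature. The `noisy_layer` function is a hypothetical stand-in for a freshly initialized TTT layer that produces large, useless outputs.

```python
import math

def gate(layer, X, alpha):
    """tanh(alpha) * layer(X) + X, element-wise over the feature dimension."""
    Y = layer(X)
    return [[math.tanh(a) * y + x for a, y, x in zip(alpha, y_row, x_row)]
            for y_row, x_row in zip(Y, X)]

def noisy_layer(X):
    """Hypothetical stand-in for a randomly initialized TTT layer."""
    return [[10.0 for _ in row] for row in X]

X = [[1.0, 2.0], [3.0, 4.0]]
closed = gate(noisy_layer, X, alpha=[0.0, 0.0])  # tanh(0) = 0: the layer is ignored
opened = gate(noisy_layer, X, alpha=[0.1, 0.1])  # the initial value quoted above
```

With α = 0 the gated layer is exactly the identity, and with α = 0.1 only about a tenth of the layer's output leaks into the residual stream, so the pre-trained network starts fine-tuning almost unchanged.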

Bi-direction.

Diffusion models, including CogVideo-X, are non-causal, meaning that an output token z_t can condition on all of x_1, …, x_T instead of only the past tokens x_1, …, x_t.

Standard trick - bi-direction.

TTT′(X) = rev(TTT(rev(X))), where rev reverses X along the time dimension.
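The reversal trick works for any causal sequence layer, so this sketch uses a toy running sum in place of a TTT layer. The names `bidirectional` and `causal_sum` are made up for illustration.

```python
from itertools import accumulate

def rev(X):
    """Reverse a sequence along the time dimension."""
    return X[::-1]

def bidirectional(layer):
    """Wrap a causal sequence layer so it runs right-to-left: rev(layer(rev(X)))."""
    return lambda X: rev(layer(rev(X)))

def causal_sum(X):
    """Toy causal layer: output t is the sum of inputs 1..t."""
    return list(accumulate(X))

X = [1, 2, 3]
forward = causal_sum(X)                  # [1, 3, 6]: token t sees the past
backward = bidirectional(causal_sum)(X)  # [6, 5, 3]: token t sees the future
```

Applying the forward layer and then the reversed one lets every output depend on the whole sequence, which is what a non-causal model needs.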


Practice.

This is how they modify traditional transformer architectures.

Translated to code:

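# Normal backbone (pseudocode; LN is layer norm).
Xl = self_attn(LN(X))
Y = Xl + X  # residual connection

# Modified backbone: two gated TTT layers after self-attention.
Xl = self_attn(LN(X))
Z = gate(TTT, Xl, alpha)     # forward-direction TTT, gate vector alpha
Zl = gate(TTT_rev, Z, beta)  # TTT_rev is assumed to be the reversed TTT′, gate vector beta
Y = Zl + X  # residual connection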

Notes from the TTT paper.

  • The difficulty with long context is inherent to the very nature of RNN layers: Unlike self-attention, RNN layers have to compress context into a hidden state of fixed size.
  • Tokens later in a sequence should be easier to predict on average, since they condition on more information. This is indeed the case for Transformer, whose average perplexity at each token index decreases throughout its 32k context. In contrast, the same metric plateaus for Mamba after 16k.
  • On one hand, the main advantage of RNNs (vs. Transformers) is their linear (vs. quadratic) complexity. This asymptotic advantage is only realized in practice for long context, which according to Figure 12 is after 8k. On the other hand, once context is long enough, existing RNNs such as Mamba struggle to actually take advantage of the extra information being conditioned on.

Appendix