Relational Recurrent Neural Networks

Santoro, Faulkner, Raposo, Rae, Chrzanowski, Weber, Wierstra, Vinyals, Pascanu, Lillicrap (DeepMind) · NeurIPS 2018 · arXiv 1806.01822

TL;DR

Standard LSTMs compress everything into a single hidden vector, which makes it difficult to reason about relationships between distinct stored facts. Relational Memory Core (RMC) maintains multiple memory slots and lets them communicate via multi-head self-attention at every step, explicitly modeling which memories interact. On tasks that require tracking multiple independent variables (program evaluation, spatial navigation, language modeling), RMC consistently outperforms LSTM and DNC.

1. Beyond Single-Vector Memory

A standard LSTM at time step t maintains a hidden state h_t and a cell state c_t, both single vectors. Every piece of information the model needs to remember is entangled inside these vectors. There is no mechanism to separately track, say, "variable A holds value 3" and "variable B holds value 7" and then reason about their relationship without conflating them.

This is a fundamental limitation for relational reasoning: tasks where the answer depends on comparing or combining multiple distinct facts held in memory simultaneously. The LSTM must implicitly learn to encode and disentangle this structure from scratch, with no architectural support.

2. Relational Memory Core Architecture

RMC replaces the single hidden vector with a memory matrix M. At each time step, M holds h independent memory slots, each a d-dimensional vector:

$$M \in \mathbb{R}^{h \times d}$$

Here h is the number of memory slots (a hyperparameter, typically 4–8) and d is the slot dimension. The total memory capacity is h × d, but crucially the structure is explicit: slot i and slot j are distinct, and the model can attend to them independently.

Before attention, the input x_t (mapped to the slot dimension d) is appended to the memory matrix as one extra row. This augmented memory lets the current input interact with all existing memories during the attention step:

$$M_0^t = \text{concat}(M^{t-1},\, x_t) \in \mathbb{R}^{(h+1) \times d}$$

This augmented matrix is then processed by multi-head self-attention, and the result is passed through an LSTM to produce the updated memory M^t.
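The input-augmentation step can be sketched in plain Python. This is a minimal illustration rather than the paper's code: it assumes x_t is mapped to the slot dimension d by a hypothetical linear projection `W_in`, and the helper names and toy sizes are arbitrary choices.

```python
import random

def project_input(x, W_in):
    """Map a raw input vector x (length n) to slot dimension d using W_in (n x d)."""
    d = len(W_in[0])
    return [sum(x[k] * W_in[k][j] for k in range(len(x))) for j in range(d)]

def augment_memory(M_prev, x, W_in):
    """Row-wise concatenation: the projected input becomes one extra row.

    M_prev has h rows of length d; the result has h + 1 rows of length d.
    """
    return [row[:] for row in M_prev] + [project_input(x, W_in)]

# Toy sizes, chosen arbitrarily: h = 4 slots of dimension d = 3, raw input size 5.
random.seed(0)
h, d, n_in = 4, 3, 5
M_prev = [[random.gauss(0, 1) for _ in range(d)] for _ in range(h)]
x_t = [random.gauss(0, 1) for _ in range(n_in)]
W_in = [[random.gauss(0, 0.1) for _ in range(d)] for _ in range(n_in)]  # hypothetical projection

M0 = augment_memory(M_prev, x_t, W_in)
print(len(M0), len(M0[0]))  # 5 3
```

Keeping x_t as its own row, rather than gluing it onto every slot, is what lets the attention step treat the input as one more item to compare memories against.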

3. Multi-Head Self-Attention Over Memories

The core relational computation is standard scaled dot-product attention applied to the rows of M_0^t. Each memory slot acts as both a query and a key-value pair, so every slot can attend to every other slot (and to the current input):

$$Q = M_0^t W_Q, \quad K = M_0^t W_K, \quad V = M_0^t W_V$$
$$\tilde{A} = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

where W_Q, W_K, W_V are learned projection matrices and d_k is the key dimension. Dividing by √d_k keeps the dot products from growing with d_k and saturating the softmax, which would otherwise shrink its gradients. Only the h rows belonging to memory slots are kept (the row produced by the input's own query is dropped), so Ã has one row per slot. This is the same attention formula as in the Transformer; the key difference is that it is applied to memory slots over time, not to token positions across a sequence.
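A single attention head over the augmented memory can be sketched with plain lists. This is an illustrative sketch, not the paper's implementation: the weights are random stand-ins for learned parameters, the helper names are hypothetical, and it keeps only the first h output rows (the memory slots) so the result lines up with the h slots.

```python
import math
import random

def matmul(A, B):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def slot_attention(M0, W_Q, W_K, W_V, h):
    """Single-head scaled dot-product attention over the rows of M0 (h + 1 rows).

    Returns (A_tilde, weights) restricted to the first h rows, i.e. the memory
    slots; the row produced by the appended input's own query is dropped.
    """
    Q, K, V = matmul(M0, W_Q), matmul(M0, W_K), matmul(M0, W_V)
    d_k = len(W_K[0])
    scores = matmul(Q, transpose(K))  # (h + 1) x (h + 1) row-to-row scores
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    A = matmul(weights, V)
    return A[:h], weights[:h]

random.seed(0)
h, d, d_k = 4, 3, 2
M0 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(h + 1)]
W_Q = [[random.gauss(0, 0.5) for _ in range(d_k)] for _ in range(d)]
W_K = [[random.gauss(0, 0.5) for _ in range(d_k)] for _ in range(d)]
W_V = [[random.gauss(0, 0.5) for _ in range(d)] for _ in range(d)]

A_tilde, weights = slot_attention(M0, W_Q, W_K, W_V, h)
```

Each row of `weights` is a probability distribution over all h + 1 rows, so every slot sees every other slot plus the current input.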

Multi-head attention runs this in parallel for several heads with independent projection matrices, then concatenates the results. This allows different heads to capture different types of relational interactions simultaneously; for example, one head might track spatial relationships while another tracks value comparisons.
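The multi-head variant can be sketched by running several independently parameterized heads and concatenating their outputs feature-wise. This is a hedged sketch: the head width `d_head = d // n_heads` is an assumption chosen so the concatenation has width d, the learned output projection that many Transformer implementations add after concatenation is omitted, and all names are hypothetical.

```python
import math
import random

def matmul(A, B):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def attend(M0, W_Q, W_K, W_V):
    """One scaled dot-product attention head over all rows of M0."""
    Q, K, V = matmul(M0, W_Q), matmul(M0, W_K), matmul(M0, W_V)
    d_k = len(W_K[0])
    weights = [softmax([sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k)
                        for k_row in K]) for q_row in Q]
    return matmul(weights, V)

def multi_head_attention(M0, heads, h):
    """Run each head with its own (W_Q, W_K, W_V), concatenate the per-head
    outputs feature-wise, and keep only the h memory-slot rows."""
    outputs = [attend(M0, *params) for params in heads]
    concat = [sum((out[i] for out in outputs), []) for i in range(len(M0))]
    return concat[:h]

def random_matrix(rows, cols):
    return [[random.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
h, d, n_heads = 4, 4, 2
d_head = d // n_heads  # per-head value width, so the concatenation has width d
M0 = random_matrix(h + 1, d)
heads = [(random_matrix(d, d_head), random_matrix(d, d_head), random_matrix(d, d_head))
         for _ in range(n_heads)]
A_tilde = multi_head_attention(M0, heads, h)
```

Because each head has its own projections, the first `d_head` features of every output row come only from head 0, the next `d_head` only from head 1, and so on.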

4. Memory Read and Write

After multi-head attention produces Ã, each memory slot is updated using an LSTM. The LSTM acts as the write mechanism: it decides how much of the attended information to incorporate into each slot, using its gating structure to control forgetting and integration. (In the paper, Ã first passes through a row-wise MLP with a residual connection and layer normalization before this gated update.)

$$M^t = \text{LSTM}(M^{t-1},\, \tilde{A})$$

Concretely, the LSTM is applied row-wise: each memory slot m_i^{t-1} is updated using the corresponding attended representation ã_i as input. The LSTM's forget gate learns when to discard old slot content, the input gate learns when to write new information, and the output gate controls what is exposed for reading at this step.
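A simplified sketch of the row-wise gated write, in plain Python. It assumes gates of the form sigmoid(W ã_i + U tanh(m_i) + b) and uses ã_i itself as the candidate content; both are simplifying assumptions (a standard LSTM would add a separate tanh candidate, and the paper's MLP is omitted). All function and parameter names are hypothetical.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def linear(W, v):
    """Matrix-vector product with W stored as a list of rows (out x in)."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def gate(P, name, a, tanh_m):
    """sigmoid(W a + U tanh(m) + b) for the gate called `name`."""
    pre = [wa + um + b for wa, um, b in
           zip(linear(P["W_" + name], a), linear(P["U_" + name], tanh_m), P["b_" + name])]
    return [sigmoid(z) for z in pre]

def gated_slot_update(m_prev, a, P):
    """LSTM-style update of one slot: the slot is the cell state, a is its input."""
    tanh_m = [math.tanh(x) for x in m_prev]
    f, i, o = [gate(P, n, a, tanh_m) for n in ("f", "i", "o")]
    m_new = [fj * mj + ij * aj for fj, mj, ij, aj in zip(f, m_prev, i, a)]
    out = [oj * math.tanh(mj) for oj, mj in zip(o, m_new)]  # what the output gate exposes
    return m_new, out

def update_memory(M_prev, A_tilde, P):
    """Apply the same parameters P to every slot independently (row-wise)."""
    pairs = [gated_slot_update(m, a, P) for m, a in zip(M_prev, A_tilde)]
    return [p[0] for p in pairs], [p[1] for p in pairs]

def make_params(d, b_f, b_i, b_o=0.0, scale=0.0):
    """Hypothetical parameter set; scale=0 gives zero weight matrices."""
    rnd = lambda: random.gauss(0, 1) * scale
    P = {}
    for n, b in (("f", b_f), ("i", b_i), ("o", b_o)):
        P["W_" + n] = [[rnd() for _ in range(d)] for _ in range(d)]
        P["U_" + n] = [[rnd() for _ in range(d)] for _ in range(d)]
        P["b_" + n] = [b] * d
    return P

random.seed(0)
d = 3
M_prev = [[random.gauss(0, 1) for _ in range(d)] for _ in range(4)]
A_tilde = [[random.gauss(0, 1) for _ in range(d)] for _ in range(4)]

# Forget gate saturated open, input gate closed: slots keep their old content.
keep, _ = update_memory(M_prev, A_tilde, make_params(d, b_f=10.0, b_i=-10.0))
# Forget gate closed, input gate open: slots are overwritten by the attended values.
overwrite, _ = update_memory(M_prev, A_tilde, make_params(d, b_f=-10.0, b_i=10.0))
```

The two extreme settings show the gates' role as a write policy: the same attended content can be ignored or fully written, slot by slot.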

Reading from memory happens implicitly: the output of the entire RMC at time step t is typically the concatenation of all h memory slots, or a weighted combination. Downstream modules (policy networks, language model heads, etc.) receive this full memory state and can selectively use different slots.
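The two read-out options mentioned above can be illustrated in a few lines. The function names are hypothetical, and how the combination weights would be produced (for example by a downstream module) is left open here, as it is in the text.

```python
def read_out(M):
    """Concatenate all h slots into one flat vector of length h * d."""
    return [x for slot in M for x in slot]

def weighted_read(M, weights):
    """Alternative read-out: a weighted combination of slots (weights sum to 1)."""
    d = len(M[0])
    return [sum(w * slot[j] for w, slot in zip(weights, M)) for j in range(d)]

M = [[1.0, 2.0], [3.0, 4.0]]
print(read_out(M))                   # [1.0, 2.0, 3.0, 4.0]
print(weighted_read(M, [0.5, 0.5]))  # [2.0, 3.0]
```

Concatenation preserves which value came from which slot, at the cost of an output that grows with h; a weighted combination stays at width d but blends the slots.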

5. Program Evaluation Task: Tracking Multiple Variables

The program evaluation task is a synthetic benchmark designed to stress-test relational memory. A model is shown a sequence of variable assignments and operations (e.g., "A = 3; B = 7; C = A + B; output C") and must output the correct value. This requires:

  • Storing each variable's value in a distinct location (a slot, not a blended vector)
  • Retrieving and combining specific variables when an operation refers to them
  • Updating only the target variable without corrupting others

RMC can, in principle, dedicate one memory slot to each variable. When the model reads "C = A + B", attention lets the slot for C retrieve the contents of the slots for A and B, and the subsequent update can write the combined result into slot C, all within a single attention + LSTM step. An LSTM has no such structure and must learn this implicitly, leading to degraded performance as the number of variables grows.
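A toy illustration of the slot-per-variable idea, with hand-set (not learned) slot contents: each slot holds a one-hot "name" key plus a value, and slot C issues a query that matches the keys of A and B. The slot layout, the `sharpness` factor, and the final rescaling by 2 are all assumptions made for illustration; attention alone returns a weighted average, so the "sum" step stands in for what a learned update would have to supply.

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

# Hand-built slots (not learned): a one-hot name key for A, B, C plus one value entry.
slots = {
    "A": [1.0, 0.0, 0.0, 3.0],
    "B": [0.0, 1.0, 0.0, 7.0],
    "C": [0.0, 0.0, 1.0, 0.0],
}

def retrieve(query_key, slots, sharpness=20.0):
    """Attend over slots by matching query_key against each slot's key part."""
    names = list(slots)
    scores = [sharpness * sum(q * k for q, k in zip(query_key, slots[n][:3])) for n in names]
    weights = softmax(scores)
    value = sum(w * slots[n][3] for w, n in zip(weights, names))
    return dict(zip(names, weights)), value

# While reading "C = A + B", slot C issues a query matching the keys of A and B.
weights, mean_value = retrieve([1.0, 1.0, 0.0], slots)
result = 2.0 * mean_value  # attention averages; a learned update would supply this rescaling
slots["C"][3] = result     # only C's slot is written; A and B are untouched
print({n: round(w, 3) for n, w in weights.items()}, round(result, 3))
```

The point of the toy is the access pattern: A and B are read without being modified, and only C's slot changes, which is hard to guarantee when all variables share one vector.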

In the paper's experiments, RMC achieves near-perfect accuracy on program evaluation with up to 8 variables, while LSTM accuracy degrades sharply beyond 4. The gap widens further when programs are longer or require nested operations.

6. Language Modeling Results

RMC is evaluated on two standard language modeling benchmarks measured in bits per word (BPW); lower is better.

Model   Penn Treebank (BPW)   WikiText-103 (BPW)
LSTM    1.30                  1.38
DNC     1.27                  1.35
RMC     1.22                  1.29

RMC achieves the best BPW on both datasets. Language modeling benefits from relational memory because coherent text involves tracking entities across long spans: pronouns reference earlier nouns, verb agreement depends on the subject introduced many tokens ago, and discourse coherence requires remembering what has been said. Each entity can occupy a distinct memory slot, and attention can retrieve the right slot when generating context-dependent words.
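For readers less familiar with the metric, bits per word can be computed from the probabilities a model assigns to each word of held-out text. This sketch assumes the standard definition (mean negative log2-probability per word) and hypothetical function names; the corresponding word-level perplexity is 2 raised to the BPW.

```python
import math

def bits_per_word(word_log_probs):
    """Mean negative log-likelihood in bits, given the natural-log
    probabilities a model assigned to each word of a held-out text."""
    nll_nats = -sum(word_log_probs) / len(word_log_probs)
    return nll_nats / math.log(2)

def perplexity_from_bpw(bpw):
    """Word-level perplexity is 2 raised to the bits-per-word."""
    return 2.0 ** bpw

# A model that always gives the correct word probability 1/8 needs 3 bits per word.
bpw = bits_per_word([math.log(1 / 8)] * 10)
print(round(bpw, 6), round(perplexity_from_bpw(bpw), 6))  # 3.0 8.0
```

Because the metric is logarithmic, a fixed BPW reduction corresponds to a multiplicative reduction in perplexity.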

7. Comparison with LSTM and DNC

The three models represent a spectrum of memory architectures:

Property                      LSTM                 DNC                            RMC
Memory structure              1 vector             External matrix + controller   M ∈ R^{h×d} slots
Read mechanism                Implicit via gating  Content / location addressing  Multi-head self-attention
Memory-to-memory interaction  None                 None (reads are independent)   Full pairwise attention
Write mechanism               LSTM gates           Erase + add vectors            LSTM gates per slot
Relational task accuracy      Low                  Medium                         High

DNC (Differentiable Neural Computer, Graves et al. 2016) also uses an external memory matrix, but memory locations are read and written independently using content-based and location-based addressing. There is no mechanism for memories to interact with each other: slot i cannot directly query slot j. RMC closes this gap by running full self-attention across all slots before the write step.
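The contrast can be made concrete with a small sketch. It covers only DNC-style content-based addressing (softmax of a strength-scaled cosine similarity between a controller-emitted key and each row); location-based addressing is omitted, and the RMC side uses raw dot products without learned projections for brevity. Function names are hypothetical.

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def cosine(u, v):
    nu = math.sqrt(sum(x * x for x in u))
    nv = math.sqrt(sum(x * x for x in v))
    return sum(a * b for a, b in zip(u, v)) / (nu * nv + 1e-8)

def dnc_content_read(M, key, beta):
    """DNC-style content-based read: the controller emits `key` and strength `beta`;
    each row is scored against that key alone, never against other rows."""
    w = softmax([beta * cosine(key, row) for row in M])
    d = len(M[0])
    return w, [sum(wi * row[j] for wi, row in zip(w, M)) for j in range(d)]

def slot_to_slot_weights(M):
    """RMC-style interaction: every row queries every row (raw dot products,
    no learned projections, for brevity), giving an h x h weight matrix."""
    d = len(M[0])
    return [softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in M]) for q in M]

M = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_dnc, r = dnc_content_read(M, key=[1.0, 0.0], beta=10.0)
W_rmc = slot_to_slot_weights(M)
```

The DNC read yields one weight vector per external key, while the RMC computation yields a full h × h matrix in which each slot's row describes how it attends to every other slot.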

8. Connection to Transformers

RMC and the Transformer share the same multi-head self-attention formula. The key architectural difference is the axis of attention:

  • Transformer: Attention runs across token positions in the sequence. At each layer, every token position attends to all other positions. The sequence length is fixed within a forward pass.
  • RMC: Attention runs across memory slots at each time step. The number of slots h is fixed, but the model processes one time step at a time (like an RNN). Memory persists across time steps via the LSTM update.

This makes RMC a temporal model with relational capacity: it can handle variable-length sequences online (one token at a time), while maintaining structured relational memory that is updated and refined at each step. The Transformer, by contrast, processes entire sequences in parallel but lacks persistent state across calls.
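The difference in the axis of attention can be seen by running a heavily simplified RMC step over a sequence one token at a time. This is an illustrative sketch, not the paper's cell: it assumes inputs already have width d, uses no learned projections, and replaces the LSTM gates with a fixed blend factor `gate`. Queries come only from the h memory rows, which gives the same result as computing all h + 1 query rows and discarding the input's.

```python
import math
import random

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def rmc_step(M, x, gate=0.5):
    """Simplified RMC step: append x as a row, let each memory row attend over
    all h + 1 rows, then blend old and attended content with a fixed gate."""
    rows = M + [x]
    d = len(x)
    weights = [softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in rows])
               for q in M]
    attended = [[sum(w * r[j] for w, r in zip(wrow, rows)) for j in range(d)]
                for wrow in weights]
    M_new = [[(1 - gate) * m + gate * a for m, a in zip(mrow, arow)]
             for mrow, arow in zip(M, attended)]
    return M_new, (len(weights), len(weights[0]))

random.seed(1)
h, d = 4, 3
M = [[random.gauss(0, 1) for _ in range(d)] for _ in range(h)]
sequence = [[random.gauss(0, 1) for _ in range(d)] for _ in range(7)]  # any length works

attn_shapes = []
for x_t in sequence:  # one token at a time; memory is carried forward between steps
    M, shape = rmc_step(M, x_t)
    attn_shapes.append(shape)

T = len(sequence)
transformer_attn_shape = (T, T)  # one layer of full self-attention over the whole sequence
```

Every step produces an h × (h + 1) attention matrix regardless of how long the sequence is, whereas a full-sequence self-attention layer produces a T × T matrix that grows with T.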

In retrospect, RMC can be seen as a precursor to memory-augmented Transformers (like Transformer-XL and Memorizing Transformers). The shared insight: attention over a small, structured set of memory vectors is a powerful primitive for relational computation, whether those vectors are token representations or persistent memory slots.