A Simple Neural Network Module for Relational Reasoning

Santoro, Raposo, Barrett, Malinowski, Pascanu, Battaglia, Lillicrap (DeepMind) · NIPS 2017 · arXiv 1706.01427

TL;DR

The Relation Network (RN) is a plug-in neural network module that explicitly reasons about pairwise relations between objects. Given a set of objects, it computes a learned relation embedding for every pair, sums them all, and passes the result through an MLP to produce an answer. This simple inductive bias (consider all pairs) is surprisingly powerful: RNs achieve 95.5% on CLEVR visual QA, surpassing human performance of 92.6%, and pass 18 of 20 bAbI language tasks.

1. What Is Relational Reasoning?

Many questions we ask about the world are inherently relational. "Is the red cube to the left of the blue sphere?" cannot be answered by inspecting either object in isolation; you must compare them. "Which of these two shapes is larger?" requires placing both objects in relation to each other. Classic deep learning models (MLPs, CNNs) do not have a built-in mechanism to reason about such pairwise relationships.

Relational reasoning is the ability to explicitly reason about relationships and properties of entities and their interactions. It underlies tasks like visual question answering, reading comprehension, knowledge graph reasoning, and physical simulation.

2. The Relation Network Module

The Relation Network is defined by a single equation. Given a set of objects $O = \{o_1, o_2, \ldots, o_n\}$, the RN computes:

$$\text{RN}(O) = f_\varphi\!\left(\sum_{i,j} g_\theta(o_i, o_j)\right)$$

There are only two learned functions:

  • $g_\theta$: a small MLP that takes a pair of objects as input and outputs a relation embedding, a vector summarizing whatever relationship might exist between those two objects. The same MLP is applied to every pair (weights are shared across all pairs).
  • $f_\varphi$: another MLP that takes the sum of all relation embeddings and produces the final output (e.g., an answer to a question). The sum aggregates evidence from all pairwise comparisons into a single vector.

The key design choices are: (1) sharing $g_\theta$ across all pairs (like a convolutional kernel shared across positions), and (2) using summation as the aggregation operator, which makes the module invariant to the ordering of objects and differentiable everywhere.
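The two-function structure can be sketched in plain NumPy. This is a minimal illustration, not the paper's exact architecture: the layer sizes, the random weights, and the choice of two-layer ReLU MLPs for both $g_\theta$ and $f_\varphi$ are assumptions.

```python
import numpy as np

def mlp(x, W1, b1, W2, b2):
    """Two-layer MLP with ReLU, applied row-wise."""
    h = np.maximum(0, x @ W1 + b1)
    return h @ W2 + b2

def relation_network(objects, g_params, f_params):
    """RN(O) = f( sum over all ordered pairs of g(o_i, o_j) )."""
    n, d = objects.shape
    # Build all n^2 ordered pairs, including self-pairs (i == j).
    pairs = np.concatenate(
        [np.repeat(objects, n, axis=0),   # o_i: each row repeated n times
         np.tile(objects, (n, 1))],       # o_j: all rows cycled n times
        axis=1)                           # shape (n*n, 2d)
    relations = mlp(pairs, *g_params)     # g_theta on every pair
    aggregate = relations.sum(axis=0)     # order-invariant sum
    return mlp(aggregate[None, :], *f_params)[0]  # f_phi on the aggregate

# Toy sizes: 4 objects of dimension 6, relation dim 8, output dim 3.
rng = np.random.default_rng(0)
O = rng.normal(size=(4, 6))
g_params = (rng.normal(size=(12, 16)), np.zeros(16),
            rng.normal(size=(16, 8)), np.zeros(8))
f_params = (rng.normal(size=(8, 16)), np.zeros(16),
            rng.normal(size=(16, 3)), np.zeros(3))
out = relation_network(O, g_params, f_params)
print(out.shape)  # (3,)
```

Because the set of ordered pairs is unchanged when the objects are permuted, the output is exactly permutation-invariant, which is the point of summing rather than concatenating.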

3. All Pairwise Relations

A crucial design decision is to consider every pair $(o_i, o_j)$, including self-pairs where $i = j$, and both orderings $(o_i, o_j)$ and $(o_j, o_i)$. With $n$ objects, this yields $n^2$ pairs in total, an $O(n^2)$ computation.

This exhaustive enumeration is intentional. The network does not need to know in advance which pairs are relevant: it considers all of them and lets $g_\theta$ learn to output a near-zero vector for irrelevant pairs and a meaningful embedding for relevant ones. The sum then reflects the relevant pairs because they dominate the signal.

| n (objects) | Pairs considered | Example |
|---|---|---|
| 3 | 9 | Small scene |
| 10 | 100 | CLEVR scene (avg. ~10 objects) |
| 25 | 625 | CNN feature map (5×5 spatial positions) |

For visual inputs, the objects are typically CNN feature map cells: a 5×5 feature map gives 25 objects and 625 pairs. This is manageable. For very large inputs the quadratic cost can become expensive, but the RN paper's tasks stay in this tractable regime.
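The counts in the table come from enumerating ordered pairs with self-pairs included, which is a one-liner:

```python
from itertools import product

for n in (3, 10, 25):
    # Ordered pairs over n objects, self-pairs (i, i) included.
    pairs = list(product(range(n), repeat=2))
    print(n, len(pairs))  # len(pairs) == n**2: 9, 100, 625
```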

4. Processing Visual Inputs

For visual question answering, the full pipeline is:

  1. CNN encoder: A convolutional network processes the input image and produces a feature map of shape $k \times k \times d$. Each of the $k^2$ spatial positions becomes one object: $o_i \in \mathbb{R}^d$. The 2D coordinates of each cell are appended to $o_i$ so the network knows where each object is in the image.
  2. LSTM question encoder: The question is encoded by an LSTM into a fixed-length question embedding $q \in \mathbb{R}^{d_q}$. This embedding captures the semantics of the question: "what color is...", "how many...", "is there a...", etc.
  3. Relation Network: The question embedding $q$ is appended to every pair input, so $g_\theta$ receives $\text{concat}(o_i, o_j, q)$. This conditions the relational reasoning on the question: for "what color?" the network learns to focus on color attributes; for "to the left of?" it focuses on spatial relations.

The full VQA formula is:

$$a = f_\varphi\!\left(\sum_{i,j} g_\theta(o_i, o_j, q)\right)$$

where the input to $g_\theta$ is a concatenated vector of dimension $2d + d_q$:

$$\text{concat}(o_i, o_j, q) \in \mathbb{R}^{2d + d_q}$$
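The shape bookkeeping for the three pipeline steps can be sketched as follows. The values of $k$, $d$, and $d_q$ are illustrative assumptions, and random arrays stand in for the CNN and LSTM outputs; note that appending the two coordinates makes the per-object dimension $d + 2$, so the pair input is $2(d+2) + d_q$ wide.

```python
import numpy as np

k, d, d_q = 5, 24, 128                  # feature map size, object dim, question dim
feature_map = np.random.randn(k, k, d)  # stand-in for the CNN output

# Step 1: flatten spatial cells to objects and append each cell's 2D coordinates.
coords = np.stack(np.meshgrid(np.arange(k), np.arange(k), indexing="ij"), axis=-1)
objects = np.concatenate([feature_map, coords], axis=-1).reshape(k * k, d + 2)

# Step 2: stand-in for the LSTM question embedding.
q = np.random.randn(d_q)

# Step 3: build all ordered pairs and append q to every pair input.
n = objects.shape[0]                    # n = k^2 objects
o_i = np.repeat(objects, n, axis=0)
o_j = np.tile(objects, (n, 1))
q_tiled = np.tile(q, (n * n, 1))
g_input = np.concatenate([o_i, o_j, q_tiled], axis=1)

print(g_input.shape)  # (625, 180): n^2 pairs, each 2*(d+2) + d_q wide
```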

5. CLEVR: A Hard Relational Reasoning Benchmark

CLEVR (Johnson et al., 2017) is a synthetic visual QA dataset designed to test compositional and relational reasoning. Scenes contain 3D-rendered objects with attributes like shape (cube/sphere/cylinder), color (8 colors), material (rubber/metal), and size (large/small). Questions are generated from programs and require multi-step reasoning.

Example questions:

  • "What color is the cube that is to the left of the cylinder?"
  • "Are there more red things than blue things?"
  • "What size is the object made of the same material as the large blue cube?"

These questions require comparing multiple objects, tracking references across clauses, and reasoning about spatial relationships β€” exactly the kind of task that benefits from explicit relational reasoning.

6. Worked Example: Three Objects, One Question

Let's trace through the Relation Network for a simple 3-object scene.

Scene: a large red cube (o₁), a small blue sphere (o₂), and a large metal cylinder (o₃).

Question: "What color is the cube that is to the left of the cylinder?"

Step 1: Extract objects. The CNN produces feature vectors for each spatial cell. Suppose $o_1, o_2, o_3 \in \mathbb{R}^d$ are extracted, each with appended 2D coordinates.

Step 2: Encode the question. The LSTM produces $q \in \mathbb{R}^{128}$, encoding "what color ... cube ... left of ... cylinder".

Step 3: Apply $g_\theta$ to all 9 pairs:

| Pair | Input to $g_\theta$ | Relevance |
|---|---|---|
| (o₁, o₁) | $[o_1, o_1, q]$ | Low (self-pair, cube vs. cube) |
| (o₁, o₂) | $[o_1, o_2, q]$ | Low (cube vs. sphere, wrong pair) |
| (o₁, o₃) | $[o_1, o_3, q]$ | High (cube to the left of the cylinder!) |
| … 6 more pairs … | … | Low (irrelevant combinations) |

Step 4: Sum and aggregate. The 9 embeddings are summed. The (o₁, o₃) pair dominates the sum because $g_\theta$ has learned to output a large signal for it. Then $f_\varphi$ maps this aggregate to the answer distribution, outputting high probability for "red".
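The intuition behind Step 4 can be made concrete with a toy simulation in which a trained $g_\theta$ is mimicked by hand: the eight irrelevant pairs emit near-zero noise and the one relevant pair emits a strong signal. All magnitudes here are illustrative assumptions, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-ins for g_theta's outputs on the 9 pairs of the worked example:
# 8 irrelevant pairs -> small noise, 1 relevant pair (o1, o3) -> strong signal.
irrelevant = rng.normal(scale=0.01, size=(8, 16))
relevant = np.full((1, 16), 5.0)
relations = np.concatenate([irrelevant, relevant], axis=0)

aggregate = relations.sum(axis=0)  # the vector f_phi would receive

# The aggregate is dominated by the relevant pair's embedding:
deviation = float(np.abs(aggregate - relevant[0]).max())
print(deviation)  # tiny: the 8 noise vectors barely move the sum
```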

7. Results

The Relation Network achieved state-of-the-art results across multiple benchmarks in 2017:

| Benchmark | RN Accuracy | Baseline / Human |
|---|---|---|
| CLEVR (overall) | 95.5% | Human: 92.6% |
| Sort-of-CLEVR (relational Qs) | 96% | CNN+MLP baseline: 63% |
| bAbI (tasks passed) | 18 / 20 | Prior LSTM: 12 / 20 |

The 33-percentage-point gap on Sort-of-CLEVR relational questions (96% vs 63%) is particularly striking. The task is designed so that non-relational questions (e.g., "What shape is the red object?") can be answered by inspecting one object, while relational questions (e.g., "What shape is the object to the left of the green object?") require comparing two objects. The CNN+MLP baseline, lacking explicit relational structure, performs poorly on the relational subset.

8. Connection to Attention and Transformers

The Relation Network has a deep conceptual connection to self-attention, the core of the Transformer architecture (Vaswani et al., 2017, published the same year).

Recall the self-attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Self-attention also considers all pairs of positions: the logit $Q_i K_j^T / \sqrt{d_k}$ measures the relationship between positions $i$ and $j$, and the output is a weighted sum over all positions. The key differences:

| Aspect | Relation Network | Self-Attention |
|---|---|---|
| Pair function | Nonlinear MLP | Dot product (linear) |
| Aggregation | Unweighted sum | Softmax-weighted sum (attention) |
| Conditioning on query | Explicit ($q$ appended to input) | Via learned Q projection |
| Computational complexity | $O(n^2)$ | $O(n^2)$ |

In a sense, the Relation Network can be viewed as a form of unweighted self-attention where the attention score is a nonlinear function of both objects, and the aggregated value is always the pair embedding itself (rather than a value vector). Self-attention can be seen as a more structured, computationally efficient version that uses dot-product similarity to weight contributions.

The Relation Network's contribution is its clarity as an inductive bias: it says explicitly that reasoning should be done by comparing all pairs of objects. This interpretable design principle, stated up front rather than left to emerge implicitly from gradient descent, is what makes the paper influential. The Transformer generalizes the idea to sequences, stacks the operation multiple times, and adds multi-head projections, but the core idea of all-pairs comparison is shared.