Order Matters: Sequence to Sequence for Sets

Vinyals, Bengio, Kudlur · ICLR 2016 · arXiv 1511.06391

TL;DR

Seq2seq models are sensitive to the order of both input and output elements, even when the underlying data is a set with no natural order. This paper studies how different orderings affect perplexity on tasks like sorting, proposes attention-based (Pointer Network) architectures to handle unordered inputs, and introduces a Read-Process-Write memory architecture for truly order-invariant set encoding. The key lesson: model structure should match problem structure.

1. Order Sensitivity in Seq2Seq

Standard seq2seq models (an encoder RNN plus a decoder RNN) factorize the joint distribution over an output sequence as a product of conditionals:

$$p(y \mid x) = \prod_{t=1}^{T} p(y_t \mid y_1, \ldots, y_{t-1},\, x)$$

Each conditional depends on all previous outputs and the entire input encoding. This is well suited to language, where there is a natural left-to-right order. But what happens when the input or output is a set, that is, a collection of elements with no inherent ordering?
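Mathematically, the chain rule is exact for any ordering of the output variables; orderings differ only in how hard each conditional is to learn. A minimal Python check, using a made-up joint distribution over three binary variables (the weights are arbitrary), confirms that every visiting order recovers the same joint probability:

```python
import itertools

# Made-up joint distribution over three binary variables (y1, y2, y3).
OUTCOMES = list(itertools.product([0, 1], repeat=3))
WEIGHTS = [1, 2, 3, 4, 5, 6, 7, 8]
JOINT = {o: w / sum(WEIGHTS) for o, w in zip(OUTCOMES, WEIGHTS)}


def marginal(assignment):
    """Probability that the variables in `assignment` (index -> value) take those values."""
    return sum(p for o, p in JOINT.items()
               if all(o[i] == v for i, v in assignment.items()))


def chain_rule(outcome, order):
    """Compute p(outcome) as a product of conditionals, visiting variables in `order`."""
    prob, seen = 1.0, {}
    for k in order:
        prev = marginal(seen)          # p(variables seen so far); 1.0 when none seen yet
        seen[k] = outcome[k]
        prob *= marginal(seen) / prev  # p(y_k | variables seen so far)
    return prob
```

For every outcome and all six visiting orders, `chain_rule` reproduces `JOINT[outcome]` up to floating-point rounding, so any practical difference between orderings comes from learning, not from the factorization itself.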

The paper demonstrates that the choice of ordering is not neutral: two equivalent representations of the same set (e.g., [3, 1, 4, 2] vs. [1, 2, 3, 4]) can yield dramatically different training perplexities, converge at different speeds, and generalize differently. The model architecture encodes strong inductive biases about sequence structure, which conflict with the set's symmetry.

2. Input Order Matters: The Sorting Experiment

The authors design a controlled experiment: given a set of numbers, train a seq2seq model to output them in sorted order. The input is the same set, but presented in different orderings to the encoder. The decoder always produces numbers in ascending order.

They test several input orderings and measure perplexity on held-out sets:

  • Random order: elements shuffled arbitrarily; high perplexity, slow convergence.
  • Sorted ascending: input already sorted, so the model essentially learns the identity map; low perplexity.
  • Sorted descending: input sorted in reverse; also low perplexity, since a consistent pattern still helps.
  • By frequency: rarest items first; a useful heuristic for discrete vocabularies.
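A hypothetical data-generation helper for this experiment might look like the sketch below: the decoder target is always the ascending sequence, and only the order in which the encoder sees the elements changes. The scheme names and the tie-breaking rule for frequency ordering are assumptions for illustration, not details from the paper.

```python
import random
from collections import Counter


def order_input(values, scheme, rng=None):
    """Return the same multiset of values, presented in the requested order."""
    if scheme == "random":
        shuffled = list(values)
        (rng or random.Random(0)).shuffle(shuffled)
        return shuffled
    if scheme == "ascending":
        return sorted(values)
    if scheme == "descending":
        return sorted(values, reverse=True)
    if scheme == "frequency":
        counts = Counter(values)
        # Rarest values first; ties broken by value so the result is deterministic.
        return sorted(values, key=lambda v: (counts[v], v))
    raise ValueError(f"unknown ordering scheme: {scheme}")


def make_example(values, scheme):
    """Pair an ordered encoder input with the decoder target (always ascending)."""
    return order_input(values, scheme), sorted(values)
```

For example, `order_input([3, 1, 4, 1, 5], "frequency")` puts the singletons 3, 4, 5 before the repeated 1s.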

The key finding: when the input order is correlated with the output order (or is a consistent function of it), the model learns faster and achieves lower perplexity. The encoder can build a more useful hidden state, and each decoder step faces a less ambiguous conditional to learn.

3. Output Order Matters

The same sensitivity applies to the output side. When the decoder generates a set (e.g., all detected objects in an image, all answers to a math problem), there are multiple valid orderings. Training with inconsistent orderings creates conflicting targets: the model must produce element A before B on some examples and B before A on others, so at each step it can only spread its probability across the alternatives.
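To make the conflict concrete: if the same set appears in the training data under every possible ordering, the best a model can do at the first decoding step is to split its probability evenly over all elements. The small sketch below (the three-element toy set is an assumption) measures the entropy of the first-step target under both regimes:

```python
import itertools
import math
from collections import Counter


def first_step_entropy(target_sequences):
    """Entropy (in nats) of the empirical distribution over the first output token."""
    counts = Counter(seq[0] for seq in target_sequences)
    n = sum(counts.values())
    return -sum((c / n) * math.log(c / n) for c in counts.values())


output_set = {"A", "B", "C"}
# Inconsistent targets: every permutation of the same set appears once.
inconsistent = list(itertools.permutations(output_set))
# Consistent targets: the set is always serialized in one canonical (sorted) order.
consistent = [tuple(sorted(output_set))] * len(inconsistent)
```

With three elements, the inconsistent regime has first-step entropy ln 3 ≈ 1.10 nats, while the canonical ordering makes the first step fully determined (entropy 0).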

A simple remedy is a canonical output ordering: a deterministic function of the outputs themselves, not the inputs. For numerical outputs, sorting by value is natural. For object detection, sorting by spatial position (e.g., top to bottom, then left to right) is a reasonable choice. The key constraint is consistency: given any set of outputs, the ordering rule must always produce the same sequence.
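A canonical spatial ordering for detection targets can be sketched in a few lines. The `(x, y, w, h)` box format and the top-to-bottom, then left-to-right rule are assumptions chosen for illustration:

```python
def canonical_order(boxes):
    """Serialize a set of (x, y, w, h) boxes: top-to-bottom, then left-to-right.

    The key depends only on the boxes themselves, so any shuffle of the same
    set yields the same target sequence. Width and height break remaining ties.
    """
    return sorted(boxes, key=lambda b: (b[1], b[0], b[2], b[3]))
```

Because the rule reads only box coordinates, `canonical_order` returns an identical list no matter how the annotations were stored, which is exactly the consistency requirement described above.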

4. Read-Process-Write Architecture

For truly unordered sets, the paper proposes the Read-Process-Write (RPW) architecture: a memory-augmented model with three components that interact with a set of memory cells:

Read Phase: Encode the Set

Each input element x_i is embedded into its own memory cell m_i by a small learned network, with the same network applied to every element:

$$m_i = f_{\text{read}}(x_i)$$

Since elements are embedded independently, the memory is permutation-equivariant: permuting the input set only permutes the memory cells, so the same set always produces the same memory state up to a permutation of cells.
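The equivariance property is easy to check with a toy per-element encoder. In this sketch the "small network" is a hypothetical fixed 1-to-3 layer with tanh; its weights are arbitrary constants, whereas a real model would learn them:

```python
import math

# Hypothetical encoder weights (arbitrary constants, not learned).
W_EMBED = [0.5, -1.0, 2.0]
B_EMBED = [0.1, 0.0, -0.3]


def embed_element(x):
    """Map one scalar element to a 3-dimensional memory vector."""
    return [math.tanh(w * x + b) for w, b in zip(W_EMBED, B_EMBED)]


def encode_memory(elements):
    """Fill one memory cell per element; each cell ignores every other element."""
    return [embed_element(x) for x in elements]
```

Encoding `[4, 3, 5, 1, 2]` gives exactly the cells of `encode_memory([3, 1, 4, 2, 5])` reindexed by the same permutation, since each cell is computed from its element alone.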

Process Phase: Attend and Aggregate

An LSTM with no external inputs runs for P steps, reading from memory via attention at each step. At process step t, the query q_t is the LSTM hidden state, and the read vector r_t is an attention-weighted sum over the memory cells:

$$a_{ti} = \text{softmax}_i\!\left(\text{score}(q_t,\, m_i)\right)$$
$$r_t = \sum_{i} a_{ti}\, m_i$$

The LSTM then updates its hidden state using r_t (in the paper, the concatenation of q_t and r_t is the input to the next LSTM step). Because r_t is a weighted sum over cells, it does not depend on the order in which the cells are stored, so after P steps the hidden state is an order-invariant summary of the entire input set.
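A minimal sketch of the process loop, assuming a dot-product score and replacing the LSTM with a toy update `q <- tanh(q + read)` to keep the example short (both are simplifications, not the paper's exact choices). It shows that shuffling the memory cells leaves the final state unchanged:

```python
import math


def softmax(scores):
    peak = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, memory):
    """Content-based read: dot-product scores, softmax weights, weighted sum of cells."""
    weights = softmax([sum(q * m for q, m in zip(query, cell)) for cell in memory])
    dim = len(memory[0])
    return [sum(w * cell[d] for w, cell in zip(weights, memory)) for d in range(dim)]


def process(memory, steps):
    """Run `steps` rounds of attention over memory with a toy recurrent update.

    ASSUMPTION: a simple tanh update stands in for the paper's LSTM.
    """
    q = [0.0] * len(memory[0])
    for _ in range(steps):
        read = attend(q, memory)                          # the read vector for this step
        q = [math.tanh(a + b) for a, b in zip(q, read)]   # stand-in for the LSTM update
    return q
```

Calling `process` on the same cells in any order returns the same state up to floating-point rounding, because every step only sums over the cells.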

Write Phase: Pointer Network Decoding

At decode step t, instead of sampling from a fixed vocabulary, the model uses a pointer: it points directly to one of the input memory cells. The pointer distribution is:

$$p(y_t = i \mid y_{<t}, x) = \text{softmax}_i\!\left(\text{score}(s_t,\, m_i)\right)$$

where s_t is the decoder LSTM state and m_i are the memory cells. The output vocabulary is therefore the input set itself, so its size always matches the number of input elements and the same model can be applied to sets of different sizes.
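The pointer distribution can be sketched as a softmax over one score per memory cell. A dot-product score is assumed here for brevity; Pointer Networks themselves use a learned additive score:

```python
import math


def pointer_distribution(state, memory):
    """Softmax over dot-product scores between the decoder state and each memory cell.

    The result has exactly one entry per memory cell, so the "vocabulary"
    grows or shrinks with the input set.
    """
    scores = [sum(s * m for s, m in zip(state, cell)) for cell in memory]
    peak = max(scores)                          # stabilize the exponentials
    exps = [math.exp(x - peak) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

With state `[1.0, 0.0]` and cells `[[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]`, the scores are 0, 2, 1, so most of the mass lands on the second cell; adding more cells simply lengthens the distribution.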

5. Handling Truly Unordered Sets

For tasks where there is genuinely no meaningful ordering, such as counting, detecting all objects, or computing aggregate statistics, any fixed ordering introduces a spurious inductive bias. The paper distinguishes three strategies for dealing with unordered inputs:

  1. Heuristic ordering: Sort inputs by value, magnitude, frequency, or spatial location before feeding to a standard RNN encoder. Works surprisingly well in practice when there is a natural ordering correlated with the task structure.
  2. Attention encoder (Pointer Networks): Run a bidirectional RNN over the input, then use attention at each decode step. The attention mechanism aggregates over all input positions, partially mitigating order sensitivity: the decoder can attend to any element regardless of its position.
  3. Read-Process-Write: Full order invariance via independent per-element encoding followed by iterative attention-based processing. The only truly permutation-invariant approach among those studied, and the best performer on tasks where set structure is fundamental.

The attention-based encoder with seq2seq attention (middle approach) is related to the Pointer Networks paper (Vinyals et al. 2015). The standard attention formula used at decode step t is:

$$c_t = \sum_{i} \alpha_{ti}\, h_i, \quad \alpha_{ti} = \text{softmax}_i\!\left(\text{score}(s_t,\, h_i)\right)$$

where h_i are encoder hidden states, s_t is the decoder state, and c_t is the context vector used to condition the output distribution. The weighted sum itself does not depend on the order of the h_i, but each h_i is produced by an RNN that reads the input sequentially, so the h_i (and hence c_t) still depend on input order. This is a weaker form of order invariance than RPW provides.
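The distinction can be demonstrated with a toy scalar RNN encoder and a scalar dot-product score (both assumptions for illustration). Attention over a fixed set of states is order-free, but feeding the inputs to the RNN in a different order changes the states themselves:

```python
import math


def rnn_states(xs, a=0.5, b=1.0):
    """Toy scalar RNN: h_i = tanh(a * h_{i-1} + b * x_i), so each state depends on the prefix."""
    h, states = 0.0, []
    for x in xs:
        h = math.tanh(a * h + b * x)
        states.append(h)
    return states


def context(s, states):
    """c_t = sum_i alpha_ti * h_i, with alpha_ti = softmax_i(s * h_i)."""
    scores = [s * h for h in states]
    peak = max(scores)
    exps = [math.exp(v - peak) for v in scores]
    total = sum(exps)
    return sum(e / total * h for e, h in zip(exps, states))
```

Reordering the already-computed states leaves `context` unchanged, while encoding `[0.9, -0.5, 0.3]` instead of `[0.3, -0.5, 0.9]` produces different states and therefore a different context vector.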

6. Worked Example: Sorting 5 Numbers

Let's trace through how Read-Process-Write handles the task of sorting the set {3, 1, 4, 2, 5}. The output should be the sequence 1, 2, 3, 4, 5.

Step 1: Read

Five memory cells are filled, each by embedding one number independently: m_1 ← f(3), m_2 ← f(1), m_3 ← f(4), m_4 ← f(2), m_5 ← f(5). The presentation order is irrelevant: presenting 1 before 3 simply swaps the contents of m_1 and m_2, which is the same memory state up to cell permutation.

Step 2: Process (P = 3 steps)

The process LSTM attends over all 5 cells at each step. It can learn to build a holistic picture of the set, for example by attending to the minimum and maximum to capture the range, or to the median to locate the midpoint. After 3 attention steps, the LSTM hidden state is meant to encode the relevant structure of the full set.

Step 3: Write (decode with pointers)

At decode step 1, the decoder attends over the memory cells and points to m_2 (value 1), the minimum. At step 2 it points to m_4 (value 2), at step 3 to m_1 (value 3), at step 4 to m_3 (value 4), and at step 5 to m_5 (value 5). Because the model outputs elements by pointing to memory, it can never generate values outside the input set.
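The decoding trace above can be reproduced with a hand-crafted (not learned) pointer. The scores are an assumption: each cell scores `-value`, and cells that were already pointed to are masked with `-inf`, a common pointer-decoding practice that a trained model would approximate with learned scores:

```python
def point_in_sorted_order(values):
    """Greedy pointer decoding with hand-set scores; returns 0-based cell indices."""
    chosen = []
    for _ in values:
        # ASSUMPTION: score(cell i) = -value_i, so the smallest remaining value wins;
        # cells already pointed to are masked out with -inf.
        scores = [float("-inf") if i in chosen else -v for i, v in enumerate(values)]
        # Softmax is monotonic, so greedy decoding only needs the argmax of the scores.
        chosen.append(max(range(len(values)), key=lambda i: scores[i]))
    return chosen
```

For `[3, 1, 4, 2, 5]` this returns `[1, 3, 0, 2, 4]`, which in the 1-based cell names used above is m_2, m_4, m_1, m_3, m_5.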

7. Legacy: Influence on Set Transformer and Modern Architectures

"Order Matters" appeared shortly before the Transformer made attention the dominant building block. Its ideas seeded a rich line of work on set-structured learning:

  • Set Transformer (Lee et al., 2019): Replaces the RPW LSTM with self-attention layers, using a small set of learned inducing points as an attention bottleneck. With m inducing points, attention over a set of n elements costs O(nm) rather than O(n²). Directly generalizes the process phase of RPW using Transformer building blocks.
  • Deep Sets (Zaheer et al., 2017): Formalizes the theory of permutation-invariant functions: for countable input domains, any such function can be decomposed as ρ(Σ_i φ(x_i)) for suitable φ and ρ. RPW's per-element encoding followed by attention-weighted sums over memory has a closely related form, with learned attention weights in place of a plain sum.
  • Object detection ordering: Modern detection models like DETR (2020) use a Transformer decoder that attends to all encoder features simultaneously. The output is a set of object predictions, and DETR uses Hungarian matching (a minimum-cost one-to-one assignment) between predictions and ground truth, a direct application of the insight that set outputs need set-appropriate losses.
  • Combinatorial optimization: Pointer Networks (the base decoder architecture here) became the foundation for learned combinatorial solvers: TSP, VRP, scheduling. The RPW extension showed that the encoder side also needs to be set-aware for truly unordered inputs.
  • The core lesson persists: without positional encodings, Transformer self-attention is permutation-equivariant (and becomes permutation-invariant after a symmetric pooling step), but positional encodings are typically added. The question of when to use positional information and when to be set-invariant remains active in architecture design for molecules, point clouds, graphs, and multimodal inputs.
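A set-appropriate loss of the kind DETR uses can be sketched by brute force: try every one-to-one assignment of predictions to targets and keep the cheapest. This O(n!) search is a stand-in for the Hungarian algorithm (in practice one would use an implementation such as SciPy's `linear_sum_assignment`); the absolute-difference cost and the equal-size assumption are illustrative simplifications:

```python
import itertools


def matching_loss(predictions, targets, cost):
    """Minimum total cost over all one-to-one assignments of predictions to targets.

    Brute force over permutations, so only practical for tiny sets.
    ASSUMPTION: len(predictions) == len(targets) (no "no object" padding).
    """
    best = None
    for perm in itertools.permutations(range(len(targets))):
        total = sum(cost(p, targets[j]) for p, j in zip(predictions, perm))
        if best is None or total < best:
            best = total
    return best


def abs_cost(p, t):
    """Toy matching cost between a scalar prediction and a scalar target."""
    return abs(p - t)
```

For predictions `[0.9, 2.1, 3.0]` and targets `[3.0, 1.0, 2.0]`, the best assignment costs about 0.2, and listing the predictions in any other order gives the same loss, so the model is never penalized for the order in which it emits set elements.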

Further Reading