The argument for a Transformer
Five posts into this series, the station model's ceiling had settled at around 70% greedy and 81% beam exact match, and the stratified breakdown by journey length told the whole story: 90%+ on short routes, 48% on long, 30% on very long journeys. The wall from autoregressive error compounding was structural and well-characterised.
However, the compounding analysis pointed at the decoder, not the graph encoder.
The autoregressive failure mode can be expressed simply: if the model picks the correct next station with probability p, then the probability of getting an entire route of length L correct is p^L. Even good per-step accuracy degrades fast when raised to the power of 20–30 decisions.
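The compounding arithmetic is easy to verify. A minimal sketch (the probabilities are illustrative, not measured values from this project):

```python
def route_success(p: float, length: int) -> float:
    """Probability an entire route is correct, assuming each of `length`
    next-station decisions is independently correct with probability p."""
    return p ** length

print(f"{route_success(0.95, 10):.2f}")   # a short route survives reasonably well
print(f"{route_success(0.95, 30):.2f}")   # 30 decisions erode even 95% per-step accuracy
print(f"{route_success(0.72, 20):.4f}")   # at 72% per-step, a 20-step route is nearly hopeless
```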
The GATv2 encoder had already shown that it understood the network topology.
The adjacency masking experiments from part 4 were the clearest evidence: the mask is built directly from the topology the encoder produces (line_adj flattened across lines), and applying it jumped exact match from 44% to 69% greedy.
That only works if the encoder has already learned the correct adjacency structure.
Additionally, every model from part 2 onward achieved 100% topological validity — the encoder never produces embeddings that lead to invalid station IDs.
The weakness appeared when the model had to decode those embeddings into a sequential route.
The GRU decoder maintains a single hidden state vector — 128 dimensions at dev scale, 512 at full. At step 30 of a 30-station journey, everything the model knows about the origin, the destination, the line it's on, the direction of travel, which interchanges it has already passed, and what the remaining plan is, must be compressed into that one vector. There is no direct access to the origin, the destination, or the route history — all of it is propagated through 30 recurrent updates.
Even with the adjacency mask reducing each step to just 3–4 candidate stations, the GRU was only selecting the correct neighbour about 72% of the time (measured on the validation set in the part 4 full-profile run). That sounds decent until you consider that this is accuracy over 3–4 candidates, not over the full vocabulary of 272 — and that it means roughly one wrong turn in every three forks.
A Transformer decoder removes this compression bottleneck. Instead of propagating information through a single recurrent state, every position in the output sequence can directly attend to every other position through self-attention. At step 30 the model does not need to remember what happened at step 1 — it can simply attend to it.
Cross-attention to the encoder output provides a second advantage. The decoder can query the full set of 272 station embeddings at every step, meaning the destination and the global network structure are always directly accessible rather than being compressed into the recurrent state. In effect the decoder can recompute the route plan at each step using the full context of the journey so far.
With a vocabulary of only 272 stations, sequences of at most 50 tokens, and a branching factor of just 3–4 neighbours after adjacency masking, the routing task is tiny by modern sequence-model standards. Transformer decoders routinely solve problems orders of magnitude larger with near-perfect token accuracy.
The Transformer decoder
The implementation was a drop-in replacement: nn.TransformerDecoder with 4 layers (a reasonable default for a problem this small — no ablation on layer count), cross-attention to the GATv2 encoder output, and norm_first=True (pre-norm, where layer normalisation is applied before the self-attention and cross-attention sublayers rather than after, giving more stable gradients through the residual path during training from scratch).
The input at each decoding position is the sum of three components: a station embedding for the current token, a positional embedding for the step index, and a destination projection.
The destination projection works as a constant bias: the destination station's encoder embedding passes through a learned linear layer (dest_proj, no bias term) and the result is added to every position's input.
So the full input at position t is station_emb(token_t) + pos_emb(t) + dest_proj(H[dest]).
The destination signal is identical at every position, telling the decoder where it's trying to get to.
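The three-component sum can be sketched in a few lines. Module names station_emb, pos_emb, and dest_proj follow the text; the batch size, sequence length, and wiring around them are illustrative, not the project's actual code:

```python
import torch
import torch.nn as nn

n_stations, max_len, d_model = 272, 50, 128
station_emb = nn.Embedding(n_stations, d_model)
pos_emb = nn.Embedding(max_len, d_model)
dest_proj = nn.Linear(d_model, d_model, bias=False)  # no bias term, as in the text

tokens = torch.randint(0, n_stations, (4, 20))   # (batch, seq) decoder input tokens
H = torch.randn(4, n_stations, d_model)          # encoder output, one row per station
dest = torch.randint(0, n_stations, (4,))        # destination station index per example

positions = torch.arange(tokens.size(1))
dest_embedding = H[torch.arange(4), dest]        # (batch, d_model)
dec_input = (
    station_emb(tokens)                          # token identity
    + pos_emb(positions)                         # step index
    + dest_proj(dest_embedding).unsqueeze(1)     # same destination bias at every position
)
assert dec_input.shape == (4, 20, d_model)
```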
The adjacency mask applies identically to the output logits.
The GRU decoder had used a pointer mechanism to make its station predictions: at each step, the hidden state passed through a learned query projection (query_proj, a linear layer), and all encoder station embeddings passed through a learned key projection (key_proj, another linear layer).
The logits were the dot product q @ keys.T — one score per station, 272 in total.
This is essentially single-head attention with no value projection and no feedback into the decoder state.
It is purely a scoring mechanism: "how much does my current state match each station?"
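The scoring mechanism above, as a sketch (layer sizes are illustrative):

```python
import torch
import torch.nn as nn

d_model, n_stations = 128, 272
query_proj = nn.Linear(d_model, d_model)
key_proj = nn.Linear(d_model, d_model)

h = torch.randn(4, d_model)            # GRU hidden state, batch of 4
H = torch.randn(n_stations, d_model)   # encoder station embeddings

q = query_proj(h)                      # (4, d_model)
keys = key_proj(H)                     # (272, d_model)
logits = q @ keys.T                    # (4, 272): one score per station
# No value projection and no residual: the scores never feed
# information back into the decoder's representation.
assert logits.shape == (4, n_stations)
```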
In the Transformer version this indirection is unnecessary. Cross-attention already exposes the entire encoder representation to the decoder at every step — not just as a scoring mechanism, but with proper value projections so the decoder actually reads information from the encoder and feeds it back into the representation via a residual connection. A simple linear projection from the decoder state to the 272-station vocabulary is sufficient.
The first run hit a CUDA assertion from torch.compile in reduce-overhead mode: the positional embedding only had max_len entries but some label sequences were longer.
The GRU naturally stopped at max_len steps; the Transformer tried to embed every label position.
Truncating labels to max_len and switching torch.compile to default mode fixed both issues.
Dev result: promising
At d_model=128 with 20 epochs — the same dev profile as every previous experiment:
| | Greedy | Beam | Per-token |
|---|---|---|---|
| GRU | 18.2% | 28.2% | 72% |
| Transformer | 43.0% | 43.0% | 95.0% |
The per-token accuracy jumped from 72% to 95%. The greedy exact match more than doubled. The Transformer was doing exactly what was predicted: with direct access to the full route history and destination at every step, it made far better per-step decisions.
But the beam column told a different story. 43.0% greedy, 43.0% beam. Zero gap.
The distribution collapse
The GRU's 10-point beam gap (18.2% → 28.2%) meant it had learned multiple valid routes per OD pair and spread probability across them. The Transformer concentrated all probability mass on a single path. Beam search found nothing to explore. The model had become extremely confident about a single route, even when several were equally valid.
This was a direct consequence of how the two decoders are trained.
The GRU uses scheduled sampling at p = 0.5: half the time during training, it sees its own predictions rather than the ground truth. This forces it to learn recovery from errors and, as a side effect, prevents it from collapsing onto a single route per OD pair. To see why, consider an OD pair like Camden Town → Victoria, which has three valid routes. In epoch 1 the sampled label might be "Northern southbound to Embankment, District westbound." In epoch 2 it might be "Northern southbound to Kennington, change to somewhere else." The GRU's hidden state is updated by backprop in both epochs, but with scheduled sampling, half the training steps saw the model's own noisy predictions rather than the clean label. The hidden state can't fully memorise route A because it was trained on corrupted versions of route A. When route B arrives next epoch, the state doesn't cleanly overwrite — it blurs. The result is a distribution that assigns nonzero probability to both routes. This is exactly what beam search exploits.
The Transformer uses parallel teacher forcing: the full ground-truth label sequence is shifted right by one position (prepend the origin, drop the last token) to create the decoder input, and a causal mask ensures position t can only attend to positions ≤ t. All positions are computed in a single forward pass — one matrix multiply through each sublayer per layer. The model sees the complete, clean label at every position. It memorises route A perfectly. When route B appears in the next epoch, it memorises that instead. The distribution is always peaked on whichever route was seen most recently. There's no blurring because there's no noise and no compression bottleneck.
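The shift-right and causal-mask setup can be sketched directly (the station ids are illustrative):

```python
import torch

labels = torch.tensor([[10, 11, 12, 13]])   # ground-truth route as station ids
origin = torch.tensor([[9]])

# Shift right: prepend the origin, drop the last token.
dec_input = torch.cat([origin, labels[:, :-1]], dim=1)   # [9, 10, 11, 12]
targets = labels                                          # [10, 11, 12, 13]

seq_len = dec_input.size(1)
# True above the diagonal = masked: position t attends only to positions <= t.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

assert dec_input.tolist() == [[9, 10, 11, 12]]
assert causal_mask[0].tolist() == [False, True, True, True]
```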
At inference, the first wrong prediction sends the model into territory it has never seen. Say the correct route is King's Cross → Euston → Warren Street → Oxford Circus (southbound Northern, then Victoria line). The model predicts King's Cross → Euston correctly (easy — only one southbound neighbour). At step 3 it predicts Mornington Crescent instead of Warren Street — both are adjacent to Euston, both valid under the adjacency mask. Now the input sequence is [King's Cross, Euston, Mornington Crescent]. This exact prefix never appeared in any training label. The self-attention at step 4 attends to Mornington Crescent and has no learned representation for recovering from this mistake. The next prediction is essentially random among Mornington Crescent's neighbours, and the route diverges permanently.
The GRU in the same situation: it predicts Mornington Crescent, the hidden state is updated, and because it trained with scheduled sampling it has seen similar "wrong station in the hidden state" situations thousands of times. It learned to get back on track — maybe it predicts Camden Town next (Mornington Crescent's other neighbour), then continues north. The route is wrong but coherent. Meanwhile beam search may have a parallel hypothesis that took Warren Street at step 3 and is still on the correct route.
95% per-token accuracy with zero diversity is worse than 72% per-token accuracy with a healthy beam gap. The GRU at 81% beam (its full-profile result) was beating the Transformer at 43% because beam search was doing real work — and beam search needs entropy to work.
Four failed fixes
Scheduled sampling (step-by-step)
The obvious port: at each step, with probability 0.5, feed the model's own prediction instead of the teacher token.
This required abandoning the Transformer's parallel forward pass. Under teacher forcing the entire sequence can be processed simultaneously because every input token is known in advance. Scheduled sampling breaks that assumption: the next input token depends on the model's prediction at the previous step.
The decoder therefore had to run step-by-step. At each step the entire self-attention stack was recomputed over the growing prefix. For a sequence of length L, this turns a single forward pass into L passes over prefixes of length 1 through L, giving a total cost proportional to L². For L = 40 that's roughly 800 attention computations versus 40 in the parallel case — and the step-by-step loop also prevents any GPU parallelism across positions, so the constant factor is much worse too. Training time increased by roughly an order of magnitude.
The model learned essentially nothing: 2.3% exact match after 20 epochs.
The failure mode was architectural. In a GRU, the hidden state is the only representation of the sequence history. If a corrupted token is fed in at step 15, the hidden state at step 16 is directly computed from that corruption, and the corrupted state propagates forward through every subsequent step. The model has no choice but to learn to cope.
In a Transformer, each position can attend to all previous positions independently. If one token in the sequence is corrupted, later positions can attend more strongly to the uncorrupted tokens and largely ignore the noise. Step 16 can attend directly to the clean tokens at steps 1–14 and downweight the corrupted step 15. The model never needed to learn recovery, so it didn't.
Token corruption (parallel)
An attempt to get the "exposure to errors" benefit without losing parallelism. During teacher-forced training, randomly replace input tokens with adjacent stations (plausible errors). The forward pass stays parallel.
The first attempt had a masking bug: the adjacency mask was being computed from the corrupted token positions (self._apply_adj_mask(logits, dec_input) where dec_input contained the corrupted tokens).
If the clean teacher token at position t is Warren Street but it got corrupted to Mornington Crescent, the adjacency mask at position t shows Mornington Crescent's neighbours.
Warren Street might not be adjacent to Mornington Crescent, so the correct label for position t gets masked out with logit −∞.
Cross-entropy on a target with logit −∞ produces infinite (in practice, enormous) loss.
The fix: keep a dec_input_clean tensor and always compute the adjacency mask from that, so the model sees corrupted inputs but the loss is always computed against reachable targets.
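The fix can be sketched as follows. The dec_input_clean name comes from the text; the toy adjacency matrix, tensors, and corruption rate are illustrative:

```python
import torch

n_stations = 272
adj = torch.zeros(n_stations, n_stations, dtype=torch.bool)
adj[5, [6, 7]] = True     # toy adjacency: station 5 connects to 6 and 7
adj[8, [5, 9]] = True

dec_input_clean = torch.tensor([[5, 8]])
# Corrupt some input tokens (here: random stations at a 30% rate)...
corrupt = torch.rand(dec_input_clean.shape) < 0.3
dec_input = torch.where(
    corrupt, torch.randint(0, n_stations, dec_input_clean.shape), dec_input_clean
)

logits = torch.randn(1, 2, n_stations)
# ...but ALWAYS build the adjacency mask from the CLEAN tokens,
# so the true target is never masked to -inf.
mask = adj[dec_input_clean]                              # (1, 2, n_stations)
masked_logits = logits.masked_fill(~mask, float("-inf"))
# The target for position 0 (a neighbour of clean token 5) stays reachable:
assert torch.isfinite(masked_logits[0, 0, 6])
```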
After fixing the mask, the model trained at full speed but showed identical results to pure teacher forcing (42.9% vs 43.0% greedy, both zero beam gap). If the corruption were actually affecting the model's learning, you'd expect some measurable change — either improved diversity or degraded accuracy. The fact that it had zero effect is strong evidence that the Transformer's self-attention was attending around the noisy tokens just as easily as it attended around the step-by-step noise.
Temperature at inference
If the distribution is too sharp, soften it at decode time: divide the logits by a temperature T > 1 before the softmax. This required no retraining — just a one-line change in beam search.
Zero effect. The logit gap between the model's top choice and the alternatives was so large that even at temperature 1.5, beam search found nothing.
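The failure is easy to reproduce with illustrative numbers. When one logit dominates by a wide margin, even T = 1.5 barely moves the distribution:

```python
import torch

logits = torch.tensor([10.0, 2.0, 1.0])   # one dominant candidate, as in the text
for T in (1.0, 1.5):
    probs = torch.softmax(logits / T, dim=-1)
    print(f"T={T}: top prob = {probs[0].item():.4f}")
# The top probability stays above 0.99 at both temperatures --
# there is nothing left for beam search to explore.
```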
Min-loss over valid routes
Train against all valid routes simultaneously: for each example, compute cross-entropy against every valid label and backpropagate through the minimum. The intuition is that the model should never be punished for predicting any valid route — you pick the label it's closest to and only penalise the gap to that one.
A typical OD pair has 2–3 valid routes (average 2.6 across 73,712 pairs), generated by BFS with up to 2 transfers, sorted by travel time, top 3 stored.
get_all_labels_batch retrieves all stored routes for a batch of OD indices — it's a lookup, not a runtime search.
This was previously only used for evaluation; now it ran in the training loop, computing a per-example loop over the 2–3 valid labels.
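A minimal sketch of the min-loss objective: cross-entropy against each valid label, backprop through the minimum. Shapes and the random labels are illustrative; get_all_labels_batch would supply the real labels:

```python
import torch
import torch.nn.functional as F

n_stations, seq_len = 272, 4
logits = torch.randn(1, seq_len, n_stations, requires_grad=True)
# Stand-in for the 2-3 stored valid routes per OD pair:
valid_labels = [torch.randint(0, n_stations, (1, seq_len)) for _ in range(3)]

losses = torch.stack([
    F.cross_entropy(logits.view(-1, n_stations), lab.view(-1))
    for lab in valid_labels
])
loss = losses.min()   # only the closest valid route is penalised
loss.backward()       # gradient flows only through the minimum branch
assert logits.grad is not None
```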
The result: 44.7% greedy, 44.7% beam. Marginally better exact match, still zero beam gap.
The Transformer simply found whichever of the 2–3 routes was easiest to predict from the encoder embeddings and committed to it completely. Min-loss was never going to fix this: the objective permits distributing probability across routes but doesn't encourage it. The loss landscape has multiple global minima (one per valid route) and the Transformer, being high-capacity with perfect teacher forcing, converges cleanly to one of them. The GRU distributes probability because scheduled sampling forces it to, not because the loss function asks for it.
Why Transformer decoders fail on this problem
The four failures share a common root cause.
Transformer self-attention excels at tasks where there is a single correct sequence to predict. Because every position can attend to every other position, the model can maintain extremely coherent representations of long sequences. This is precisely why Transformers replaced recurrent networks in machine translation and language modelling.
Route planning, however, is not a single-target task. Many origin–destination pairs have several equally valid routes. A useful decoder must therefore represent a distribution over plausible routes rather than collapsing onto a single deterministic plan.
The Transformer's strength — its ability to maintain perfectly consistent representations — becomes a liability in this setting. Once it commits to a particular route, the architecture reinforces that commitment at every step through self-attention. With 95% per-token accuracy the model has effectively decided there is only one correct path. The output distribution collapses onto a single trajectory, leaving beam search with nothing to explore.
The GRU's hidden state bottleneck is what keeps its distribution soft. It can't perfectly remember which of the 2–3 valid routes it saw most recently — the lossy compression and the noise from scheduled sampling mean that exposure to route B doesn't fully erase the memory of route A. The hedging looks like imprecision — 72% per-token accuracy — but it's actually the model maintaining a distribution over alternatives. Beam search converts that distribution into concrete route predictions.
The full-profile run confirmed the pattern: 54.8% greedy, 54.8% beam with zero gap, followed by collapse to near-zero by epoch 200 as the cosine schedule drove the learning rate down and the model overfit its single-route predictions.
The hybrid decoder
The diagnosis pointed to a clear design: keep the GRU’s sequential generation and diversity, but give it direct access to the encoder through attention. The GRU's sequential generation with scheduled sampling is necessary for distribution diversity — without it, the model collapses onto a single route. But the pointer mechanism was the bottleneck limiting per-step accuracy — a single-head scoring mechanism with no information flowing back into the representation. The natural move is to keep the GRU loop but replace the pointer with something that has direct encoder access.
The hybrid decoder is a GRU with one cross-attention layer. At each step:
- The GRU cell updates its hidden state from the previous prediction's embedding (exactly as before — start_input, station_emb, sampling_p, all identical).
- The hidden state queries the full encoder output via multi-head cross-attention (8 heads; 16 dims per head at dev, 64 at full).
- The cross-attended representation is residual-connected and layer-normed with the hidden state.
- The result projects to station logits via a linear layer.
- The adjacency mask applies identically.
The GRU loop is unchanged: scheduled sampling works, teacher-tracked masking works, the existing beam search structure works.
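One decoding step of the hybrid can be sketched as follows. The module wiring mirrors the steps above but is an illustration, not the project's exact code:

```python
import torch
import torch.nn as nn

d_model, n_stations, n_heads = 128, 272, 8   # dev scale: 16 dims per head
gru = nn.GRUCell(d_model, d_model)
cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
norm = nn.LayerNorm(d_model)
out_proj = nn.Linear(d_model, n_stations)

h = torch.zeros(4, d_model)               # hidden state, batch of 4
prev_emb = torch.randn(4, d_model)        # embedding of the previous prediction
H = torch.randn(4, n_stations, d_model)   # encoder output (computed once, shared)

h = gru(prev_emb, h)                            # recurrent update, as before
attended, _ = cross_attn(h.unsqueeze(1), H, H)  # hidden state queries all 272 stations
fused = norm(h + attended.squeeze(1))           # residual connection + layer norm
logits = out_proj(fused)                        # (4, 272); adjacency mask applied after
assert logits.shape == (4, n_stations)
```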
The cross-attention replaces the pointer mechanism's single query–key dot product with a richer interaction. Instead of producing one similarity score per station, the decoder performs multi-head attention over all 272 encoder station embeddings. Each head can focus on different structural features of the graph, and — unlike the pointer — the value projections mean the decoder actually reads information from the encoder and incorporates it via the residual connection before making the routing decision. The pointer asked "which station matches my current state?" The cross-attention asks "which station matches, and what does it tell me?" — then uses that answer to predict.
For beam search, the encoder output is computed once and shared across all beams.
Each beam candidate maintains its own GRU hidden state.
At each step, all candidates' GRU states are updated in parallel (batched GRU cell), then each candidate's hidden state independently queries the shared encoder output via batched cross-attention, then logits are computed and adjacency-masked per candidate, and standard top- pruning selects the surviving beams.
The beam dispatcher checks isinstance(dec, HybridStationDecoder) and calls the cross-attention inside the per-step loop, whereas the GRU dispatcher calls the pointer mechanism's key_proj/query_proj instead.
The total additional parameters over the base GRU are modest: the cross-attention weights (roughly 4·d_model² for the query/key/value/output projections) plus a layer norm (2·d_model). At d_model = 128 that's about 66K extra parameters; at 512 it's about 1M. The full hybrid comes to 6.6M parameters — far less than the pure Transformer's 20.6M, since there's only one cross-attention layer instead of 4 full Transformer decoder layers.
Dev result
| | Greedy | Beam | Gap |
|---|---|---|---|
| GRU (pointer) | 18.2% | 28.2% | +10.0 |
| Transformer | 43.0% | 43.0% | +0.0 |
| Hybrid | 32.3% | 57.4% | +25.1 |
The beam gap returned — 25 points, two and a half times the GRU's. The model learned genuine route diversity (the GRU backbone with scheduled sampling) and better per-step decisions (cross-attention to the encoder).
The greedy number (32.3%) sits between the GRU (18.2%) and the Transformer (43.0%). This makes sense: cross-attention improves per-step accuracy over the pointer mechanism, but scheduled sampling costs some per-step precision in exchange for diversity. The beam number (57.4%) exceeds both — confirming that the diversity isn't noise, it's the model expressing knowledge of multiple valid routes.
The stratified breakdown at dev scale:
| Route length | Beam accuracy | Count |
|---|---|---|
| 2–5 stations | 55% | 51 |
| 6–10 | 51% | 139 |
| 11–20 | 30% | 352 |
| 21–30 | 19% | 166 |
| 31–50 | 7% | 29 |
At dev scale and 20 epochs this is undertrained. In earlier experiments the GRU's beam accuracy increased by a factor of roughly 2.9 between the dev profile and the full training profile (28.2% → 81.0%). A literal 2.9× is impossible from the hybrid's dev result of 57.4% (it would exceed 100%), but if the hybrid captures a comparable share of the remaining headroom, the full-profile run would land somewhere in the 85–95% range.
There's reason to think the hybrid might scale better than the GRU. The GRU pointer's scaling bottleneck was the hidden state itself — a larger d_model helps, but with diminishing returns because the fundamental compression problem remains. The hybrid's cross-attention benefits directly from a larger d_model (more heads, richer key/value representations) and from more encoder layers (better station embeddings to attend to). Counter-argument: the GRU backbone still limits sequential coherence, so at some point it becomes the bottleneck again regardless of how good the cross-attention is.
The full-profile run (200 epochs, cosine LR with warmup, d_model = 512, 6 encoder layers, beam width 20) is currently training.
What the Transformer experiment taught
The negative results were as informative as the positive one.
Per-token accuracy is not the bottleneck. The Transformer proved this conclusively: 95% per-token accuracy produced worse routes than 72% per-token accuracy. What matters is the distribution over routes, not the precision on any single route.
Scheduled sampling serves two purposes, not one. The obvious purpose is robustness to the model's own errors at inference. The subtler purpose — and in this setting the more important one — is preventing distribution collapse. The GRU can't overwrite its previous beliefs because noise constantly disrupts its hidden state. The Transformer can and does.
Self-attention routes around noise. Token corruption, step-by-step sampling, and min-loss all failed because the Transformer's ability to attend past corrupted tokens means it never needs to learn robustness. In a GRU, you can't attend past a corrupted hidden state — it's the only state there is. That constraint is what makes scheduled sampling work.
The pointer mechanism was a genuine bottleneck. Replacing it with cross-attention improved the beam result from 28.2% to 57.4% at the same scale. The GRU's hidden state was always the wrong place to put routing decisions — it just happened to be the only place the pointer mechanism could look.
Full-profile result
The full-profile run completed in 11 hours over 200 epochs, with d_model = 512, 6 encoder layers, batch size 256, and beam width 20.
| | Greedy | Beam | Gap |
|---|---|---|---|
| GRU (pointer) | 70.1% | 79.1% | +9.0 |
| Transformer | 54.8% | 54.8% | +0.0 |
| Hybrid | 73.7% | 86.0% | +12.3 |
New best model across every metric.
The stratified breakdown shows where the gains concentrate:
| Route length | GRU beam | Hybrid beam | Δ |
|---|---|---|---|
| 2–5 stations | 88–93% | 88% | ~0 |
| 6–10 | 88–93% | 94% | +3 |
| 11–20 | 74% | 78% | +4 |
| 21–30 | 48% | 56% | +8 |
| 31–50 | 30% | 38% | +8 |
Short routes (≤10 stations) are comparable — the adjacency mask does most of the work at that length and there isn't much room to improve. The gains concentrate on medium and long routes, exactly where the GRU's hidden state bottleneck hurt most: both the 21–30 and 31–50 station brackets jumped 8 points. Cross-attention to the encoder is letting the model maintain its plan over longer sequences without the plan degrading through the recurrent state.
The finer-grained per-station breakdown reveals the characteristic shape of autoregressive compounding under this architecture:
| Length | Beam EM | n |
|---|---|---|
| 6 | 100% | 27 |
| 7 | 95% | 44 |
| 8 | 91% | 43 |
| 9 | 93% | 40 |
| 10 | 95% | 75 |
| 11 | 88% | 73 |
| 12 | 90% | 70 |
| 13 | 84% | 88 |
| 14 | 83% | 82 |
| 15 | 80% | 80 |
| 16 | 78% | 90 |
| 17 | 67% | 61 |
| 18 | 74% | 73 |
| 19 | 69% | 65 |
| 20 | 56% | 50 |
Routes up to about 12 stations are above 88%. Between 13 and 16 stations accuracy degrades roughly linearly, consistent with compounding error at a low per-step rate. The jump down at 17 stations and again at 20 stations likely reflects specific junction-heavy OD pairs — probably Northern line forks or the District/Circle/Hammersmith & City shared section — where the routing decision is genuinely ambiguous.
Convergence and the remaining ceiling
The training curves tell a nuanced story.
Validation loss dropped quickly to roughly 0.12 by epoch 42, then plateaued for the remaining 158 epochs. The model essentially stopped improving in terms of per-token cross-entropy around epoch 40–50.
Greedy exact match, however, kept climbing through the plateau — from roughly 40% at epoch 50 to 73.7% at epoch 200. The divergence makes sense for a structured prediction task. A tiny improvement at a critical junction — the Northern line fork at Camden Town, the District/Piccadilly choice at Earl's Court — can flip an entire downstream route from wrong to right. Exact match is sensitive to these marginal gains in ways that average per-token loss is not.
The beam gap (12.3 points) is wider than the GRU's (9.0 points), confirming that the hybrid maintains healthy route diversity. Beam search is doing real work: for 12% of validation journeys, the correct route exists somewhere in the top 20 hypotheses even though it isn't the greedy prediction.
The training cost was substantially higher than previous experiments. The GRU pointer decoder's full-profile runs completed in under an hour. The hybrid took 11 hours, because the cross-attention over 272 encoder stations at every decoding step, combined with the smaller batch size forced by memory constraints (256 vs the GRU's 2048), meant many more gradient steps per epoch with more compute per step.
Whether there are further gains from extended training (say 500 epochs) or architectural changes (stacking multiple cross-attention layers, adding travel time as an encoder edge feature) remains to be tested. But the core finding stands: a hybrid GRU–attention decoder that combines the GRU's sequential dynamics with direct encoder access outperforms both a pure GRU and a pure Transformer on autoregressive route prediction over a transport graph.
- Note: The code for this PR is at lmmx/tubeulator-models#5.