<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.4.1">Jekyll</generator><link href="https://www.iesl.cs.umass.edu/diffusion/feed.xml" rel="self" type="application/atom+xml"/><link href="https://www.iesl.cs.umass.edu/diffusion/" rel="alternate" type="text/html" hreflang="en"/><updated>2026-06-02T14:59:42+00:00</updated><id>https://www.iesl.cs.umass.edu/diffusion/feed.xml</id><title type="html">dIESL</title><subtitle>dIESL — firing on all tokens. A working group at the Information Extraction and Synthesis Lab (UMass Amherst) studying non-autoregressive language models that generate every token at once, rather than one at a time, left to right: masked diffusion, insertion, and edit-based generation. </subtitle><entry><title type="html">Learned Relay Representations</title><link href="https://www.iesl.cs.umass.edu/diffusion/blog/2026/relay/" rel="alternate" type="text/html" title="Learned Relay Representations"/><published>2026-05-13T00:00:00+00:00</published><updated>2026-05-13T00:00:00+00:00</updated><id>https://www.iesl.cs.umass.edu/diffusion/blog/2026/relay</id><content type="html" xml:base="https://www.iesl.cs.umass.edu/diffusion/blog/2026/relay/"><![CDATA[<p>This post is a companion walkthrough of our paper <strong>Learned Relay Representations for Forward-Thinking Discrete Diffusion Models</strong><d-cite key="rozonoyer2026relay"></d-cite>. We summarize the motivation, the method, and the empirical findings in less technical prose; the full derivations, algorithm boxes, and ablations live in the paper.</p> <p><strong>Code:</strong> <a href="https://github.com/jacopo-minniti/relay">github.com/jacopo-minniti/relay</a></p> <div class="relay-tldr"> <p><strong>TL;DR.</strong> Standard masked diffusion <strong>throws away</strong> the rich hidden state computed at every denoising step and restarts from the discrete tokens alone. 
<span class="sc">Relay</span> keeps that state alive: at each step, the model’s last-layer hidden vector is projected and added back as a second input channel at the <em>next</em> step. Train the relay end-to-end with <strong>two-step truncated BPTT</strong> and it learns <em>what to remember</em>. On Sudoku-Extreme, this shifts the accuracy–NFE frontier outward; on a 1.5B-parameter diffusion LLM evaluated on coding (<span class="sc">Relay</span>-finetuned Fast-dLLM v2), it beats vanilla SFT on both HumanEval and MBPP while using fewer forward passes (<span class="kw">32% fewer on HumanEval</span>).</p> </div> <h2 id="introduction">Introduction</h2> <p>Masked Diffusion Models (MDMs)<d-cite key="austin_structured_2021"></d-cite><d-cite key="campbell_continuous_time_2022"></d-cite><d-cite key="sahoo_simple_2024"></d-cite><d-cite key="shi_simplified_2024"></d-cite> generate discrete sequences by <strong>iterative denoising</strong>: starting from a fully masked canvas $\mathbf{x}_{t_0}=(\textsf{[M]},\dots,\textsf{[M]})$, each forward pass unmasks a fraction of the remaining positions until the sequence is complete. Internally, the Transformer computes a rich hidden state at <strong>every</strong> position — including positions that are still masked — but at the end of each step <strong>those hidden states are thrown away</strong>. The next step starts again from the partially unmasked sequence alone.</p> <p>We call this the <strong>hard reset problem</strong>. In standard MDMs, the only information that survives a denoising step is the set of discrete tokens committed so far. The continuous, distributional reasoning the model just performed about every uncommitted position is wiped clean — and the next pass has to reconstruct it from scratch.</p> <p>This matters because <strong>recurrent computation</strong> — unrolling a fixed-parameter model across many steps — is precisely the structural property recent work has tied to improved reasoning<d-cite key="saunshi2025reasoning"></d-cite>. 
MDMs already perform many forward passes per generation; the hard reset is what prevents any of that compute from compounding.</p> <blockquote> <p><strong>The question this post is about:</strong> <em>How can the sequential unmasking structure of MDMs support recurrent computation that carries richer information across steps?</em></p> </blockquote> <p>Our answer is <strong>Learned Relay Representations</strong> (<strong><span class="sc">Relay</span></strong>), a method that makes discrete diffusion models <em>forward-thinking</em>: at each denoising step, alongside any newly unmasked tokens, the model <strong>carries its last-layer hidden state forward as a learned relay</strong>. The next forward pass gets direct access to the prior step’s continuous computation. Simply piping these states forward, however, does not by itself ensure that they encode something useful for what follows. <span class="sc">Relay</span> therefore trains the relay end-to-end with <strong>truncated backpropagation through time (BPTT)</strong>, shaping it to be maximally informative for the next few denoising steps.</p> <figure class="relay-hero-schematic"> <img src="/diffusion/assets/img/blog/relay/relay_schematic_two_panel.png" alt="Schematic of Relay over steps k and k+1: backbone f_theta, relay R_theta, hidden states h_k and h_{k+1}, progressive unmasking."/> <figcaption> <strong>Figure 1.</strong> Schematic of <span class="sc">Relay</span> over two consecutive inference steps (as in the paper). At each step <em>k</em>, the backbone <em>f<sub>θ</sub></em> consumes the sum of embedded tokens <em>emb(<strong>x</strong><sub>t<sub>k</sub></sub>)</em> and the projected relay state <em>R<sub>θ</sub>(<strong>h</strong><sub>k</sub>)</em>, producing a hidden state <strong>h</strong><sub>k+1</sub> that is both unembedded into logits for the cross-entropy loss and forwarded along the relay feedback path (highlighted in the paper figure) into the next step. 
Tokens are progressively unmasked between steps; <em><strong>h</strong></em> provides a continuous, differentiable channel for information that has not yet been committed to a discrete token. </figcaption> </figure> <h2 id="masked-diffusion-models-in-60-seconds">Masked Diffusion Models in 60 Seconds</h2> <p>Before introducing the relay, it helps to be precise about what a standard MDM does and doesn’t carry across steps. We follow the formulations in<d-cite key="sahoo_simple_2024"></d-cite><d-cite key="shi_simplified_2024"></d-cite>.</p> <p>Let $\mathcal{V}$ be a token vocabulary that includes a distinguished mask symbol $\textsf{[M]}$. Sequences live in $\mathcal{V}^L$. We write $\mathbf{x}_0 \in (\mathcal{V} \setminus \{\textsf{[M]}\})^L$ for a clean training sequence and $\mathbf{x}_t \in \mathcal{V}^L$ for its partially masked counterpart at noise level $t$.</p> <h3 id="training">Training</h3> <p>The noising process samples a time $t \in [0,1]$ and independently masks each position with probability $\alpha_t$, yielding $\mathbf{x}_t$; the $1/t$ weight in the loss below corresponds to the linear schedule, in which the masking probability is $\alpha_t = t$. A neural network with backbone $f_\theta$, embedding $\textsf{emb}$, and unembedding $\textsf{unemb}$ parameterizes the per-position posterior</p> \[p_\theta^i(w \mid \mathbf{x}_t) = \frac{e^{\ell^i(w)}}{\sum_{w' \in \mathcal{V}} e^{\ell^i(w')}}, \qquad \ell^i(w) = \textsf{unemb}(f_\theta(\textsf{emb}(\mathbf{x}_t)))^i_w,\] <p>and is trained by minimizing the weighted cross-entropy</p> \[\mathcal{L}(\theta) = \mathbb{E}_{\mathbf{x}_0,\, t,\, \mathbf{x}_t} \!\left[ \frac{1}{t} \sum_{i:\, \mathbf{x}_t^i = \textsf{[M]}} -\log p_\theta^i\!\left(x_0^i \mid \mathbf{x}_t\right) \right].\] <p>Critically, <strong>only $\mathbf{x}_t$ is fed to the model</strong>: nothing about previous denoising steps is consumed.</p> <h3 id="inference">Inference</h3> <p>Generation proceeds along a decreasing time grid $1 = t_0 &gt; t_1 &gt; \dots &gt; t_K = 0$. 
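</p> <p>Before the formal description, here is a minimal, framework-free sketch of the two policy families applied to a single step. It assumes the per-position predictions and confidences (for example, each position's maximum probability) have already been computed; the <code>MASK</code> sentinel, the function names, and the fallback of committing at least one position under a threshold are illustrative assumptions, not the API of any particular codebase. <code>linear_reveal_counts</code> splits $L$ positions across $K$ steps of a uniform time grid.</p>

```python
MASK = -1  # hypothetical sentinel for a still-masked position


def linear_reveal_counts(length, steps):
    """Positions to reveal at each step of a uniform grid t_k = 1 - k / steps.

    Integer rounding spreads the remainder so the counts always sum to length.
    """
    return [(length * (k + 1)) // steps - (length * k) // steps for k in range(steps)]


def fixed_fraction_step(tokens, preds, confidence, n_reveal):
    """Commit the n_reveal most confident masked positions (fixed budget)."""
    masked = [i for i, t in enumerate(tokens) if t == MASK]
    chosen = sorted(masked, key=lambda i: confidence[i], reverse=True)[:n_reveal]
    out = list(tokens)
    for i in chosen:
        out[i] = preds[i]
    return out


def threshold_step(tokens, preds, confidence, tau):
    """Commit every masked position whose confidence reaches tau.

    If no position qualifies, commit the single most confident one so that
    decoding always makes progress (an assumed fallback).
    """
    masked = [i for i, t in enumerate(tokens) if t == MASK]
    chosen = [i for i in masked if confidence[i] >= tau]
    if not chosen and masked:
        chosen = [max(masked, key=lambda i: confidence[i])]
    out = list(tokens)
    for i in chosen:
        out[i] = preds[i]
    return out
```

<p>For example, with predictions <code>[7, 3, 9, 1]</code> and confidences <code>[0.9, 0.2, 0.6, 0.95]</code> on an all-masked canvas, a threshold of 0.5 commits three positions in one pass, while a fixed budget of two commits only the two most confident ones. The threshold rule adapts the number of commits per forward pass to the model's confidence, which is what lets confidence-based decoding trade forward passes against accuracy.</p> <p>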
At step $k$, the model computes logits $\boldsymbol{\ell}_k$ for the masked positions and an unmasking policy $u(\cdot \mid \boldsymbol{\ell}_k, \mathbf{x}_{t_k})$ commits a subset of positions, producing $\mathbf{x}_{t_{k+1}}$. Common choices include unmasking a fixed fraction per step<d-cite key="nie_llada_2025"></d-cite> or confidence-based parallel decoding<d-cite key="wu2025fastdllmv2efficientblockdiffusion"></d-cite>.</p> <h2 id="the-hard-reset-problem">The Hard Reset Problem</h2> <p>After each inference step, MDMs <strong>discard the entire computational state</strong> used to choose the newly revealed tokens. Step $k{+}1$ starts again from $\mathbf{x}_{t_{k+1}}$ alone. Standard MDM inference therefore treats every partially masked sequence as a <em>fresh</em> prediction problem rather than as a continuation of an ongoing computation. Because models can only perform a constant number of FLOPs per forward pass, this hard reset <strong>prevents the model from amortizing reasoning across steps</strong>. Related observations have been made under the names <em>information island</em><d-cite key="xia_metastate_2026"></d-cite> and <em>sampling wall</em><d-cite key="jo_loopholing_2025"></d-cite>.</p> <h2 id="augmented-state-trajectories">Augmented State Trajectories</h2> <p>Our fix is to introduce a <strong>continuous, differentiable channel</strong> alongside the discrete one. We augment the inference state from $\mathbf{x}_{t_k}$ to the pair</p> \[\mathbf{s}_{t_k} \;=\; \bigl(\mathbf{x}_{t_k},\, \mathbf{h}_{t_k}\bigr),\] <p>where $\mathbf{h}_{t_k}$ is the relay carried into step $k$, i.e. the last-layer hidden state produced by the previous forward pass (Figure 1 abbreviates it as $\mathbf{h}_k$). We can decompose the desired behavior into two primitives the model must learn:</p> <ol> <li><strong>Produce</strong> a useful relay $\mathbf{h}_{k+1}$ at step $k$.</li> <li><strong>Consume</strong> $\mathbf{h}_{k+1}$ at step $k{+}1$.</li> </ol> <p>Concretely (cf. 
Figure 1), at each step the backbone reads the sum of token embeddings and a projected relay,</p> \[\mathbf{h}_{k+1} \;=\; f_\theta\!\bigl(\,\textsf{emb}(\mathbf{x}_{t_k}) + R_\theta(\mathbf{h}_{t_k})\,\bigr), \qquad \boldsymbol{\ell}_k \;=\; \textsf{unemb}(\mathbf{h}_{k+1}),\] <p>where $R_\theta$ is a small relay module that maps the incoming hidden state into the token-embedding space, so that its output can be added to the embeddings at the input of the residual stream.</p> <p>Two things are worth emphasizing:</p> <ul> <li>The relay $\mathbf{h}_{t_k}$ is an <strong>internal model state</strong>, not an externally observed conditioning variable. At inference, each step forwards the pair $(\mathbf{x}_{t_k}, \mathbf{h}_{t_k})$, but only $\mathbf{x}_{t_k}$ is ever decoded into text.</li> <li><span class="sc">Relay</span> leaves the inference-time decoding procedure (unmasking schedule, sampling rule) <strong>unchanged</strong>. The only addition at inference is shipping the relay alongside the committed tokens.</li> </ul> <h2 id="training-the-relay-with-truncated-bptt">Training the Relay with Truncated BPTT</h2> <p>A relay channel that is <em>passed forward</em> but never <em>trained to be useful</em> has no incentive to carry the right information. So during training we <strong>unroll the augmented state for $K$ steps</strong>, accumulate the standard cross-entropy loss across the rollout, and <strong>backpropagate through the relay path</strong> via truncated BPTT:</p> \[\mathcal{L}_K(\theta;\, \mathbf{x}_0,\, \xi_{0:K-1}) \;=\; \sum_{k=0}^{K-1} L_k(\boldsymbol{\ell}_k,\, \mathbf{x}_0).\] <p>Here $\xi_{0:K-1}$ denotes the random draws behind the sampled unmasking decisions at each rollout step, and each per-step loss $L_k$ is the cross-entropy on the masked positions of the corresponding rollout state. Importantly, <strong>we treat the discrete update $\mathbf{x}_{t_k} \mapsto \mathbf{x}_{t_{k+1}}$ as fixed</strong> after sampling; that is, we do not differentiate through the sampled unmasking decisions. 
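</p> <p>The interplay of direct and temporal credit can be seen in a deliberately tiny stand-in. The sketch below is <em>not</em> the paper's model: it replaces the Transformer with the scalar recurrence $h_{k+1} = \tanh\bigl(w\,(e_k + r\,h_k)\bigr)$, where the constant $e_k$ plays the role of $\textsf{emb}(\mathbf{x}_{t_k})$ (fixed once the unmasking decisions are sampled), $r\,h_k$ plays the role of $R_\theta(\mathbf{h}_{t_k})$, and a squared error stands in for the cross-entropy. The hand-written reverse pass adds each step's direct term to the adjoint carried back through the relay; setting <code>stop_grad=True</code> drops that carried term, mimicking a stop-gradient on the relay. All names here are hypothetical.</p>

```python
import math


def rollout(w, r, e, y, h0=0.0):
    """Forward pass of a toy K-step relay rollout with a scalar state.

    h_{k+1} = tanh(w * (e[k] + r * h_k)) and loss_k = 0.5 * (h_{k+1} - y[k])**2.
    e[k] is a constant stand-in for emb(x_{t_k}); y[k] is a stand-in target.
    Returns the states [h_0, ..., h_K] and the loss summed over the rollout.
    """
    hs, total = [h0], 0.0
    for e_k, y_k in zip(e, y):
        hs.append(math.tanh(w * (e_k + r * hs[-1])))
        total += 0.5 * (hs[-1] - y_k) ** 2
    return hs, total


def bptt_grads(w, r, e, y, stop_grad=False, h0=0.0):
    """Hand-written reverse pass: gradient of the summed loss w.r.t. (w, r).

    With stop_grad=True the relay input h_k is treated as a constant, so no
    credit flows from step k+1 back into step k (only direct terms remain).
    """
    hs, _ = rollout(w, r, e, y, h0)
    g_w = g_r = 0.0
    carried = 0.0  # adjoint dL/dh_{k+1} arriving from later steps
    for k in reversed(range(len(e))):
        h_in, h_out = hs[k], hs[k + 1]
        g_out = (h_out - y[k]) + carried    # direct term + temporal term
        g_pre = g_out * (1.0 - h_out ** 2)  # back through tanh
        g_w += g_pre * (e[k] + r * h_in)    # d(pre-activation)/dw
        g_r += g_pre * w * h_in             # d(pre-activation)/dr
        carried = 0.0 if stop_grad else g_pre * w * r  # d(pre-activation)/dh_k
    return g_w, g_r


def numeric_grads(w, r, e, y, eps=1e-6):
    """Central finite differences of the summed loss, used as a reference."""
    def loss(a, b):
        return rollout(a, b, e, y)[1]
    return ((loss(w + eps, r) - loss(w - eps, r)) / (2 * eps),
            (loss(w, r + eps) - loss(w, r - eps)) / (2 * eps))
```

<p>With $K=2$ (for instance <code>w, r = 0.7, 0.5</code>, <code>e = [0.3, -0.2]</code>, <code>y = [0.1, 0.4]</code>), <code>bptt_grads</code> agrees with <code>numeric_grads</code>, while the <code>stop_grad=True</code> gradient for <code>w</code> does not: the missing piece is exactly the credit that the second step's loss assigns to what the first step wrote into the relay.</p> <p>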
Gradients flow into $\theta$ both <em>directly</em> through the logits at each step and <em>temporally</em> through the differentiable relay path $\mathbf{h}_k \to \mathbf{h}_{k+1}$.</p> <h3 id="training-procedure">Training procedure</h3> <p><strong>Algorithm 1</strong> in the paper instantiates one outer optimization step: maintain a rollout buffer state $(z, \mathbf{h})$, run $K$ inner relay steps with teacher-forced unmasking under policy $u$, accumulate cross-entropy, and apply a gradient step. The figure below matches the manuscript.</p> <figure class="relay-algorithm-figure" id="training-procedure-algorithm-1"> <img src="/diffusion/assets/img/blog/relay/relay_algorithm1.png" alt="Algorithm 1: Relay training — outer loop over training steps, inner K-step rollout with embedding, relay R_theta, forward f_theta, unembed, loss, unmasking, gradient update."/> <figcaption> <strong>Figure 2.</strong> <strong>Algorithm 1:</strong> <span class="sc">Relay</span> training (one outer iteration). Inner loop: update hidden state via <em>f<sub>θ</sub></em> on embedded tokens plus <em>R<sub>θ</sub>(<strong>h</strong>)</em>, compute logits and loss, sample unmasking positions, teacher-force commits from <strong>x</strong><sub>0</sub>, then apply <em>θ ← θ − η ∇<sub>θ</sub>L</em>. </figcaption> </figure> <h3 id="constructing-rollouts">Constructing rollouts</h3> <p>Given the current augmented state $(\mathbf{x}_{t_k}, \mathbf{h}_{t_k})$, one rollout step proceeds as:</p> <ul> <li><strong>Unmasking.</strong> Sample positions $\mathcal{U} \sim u(\cdot \mid \boldsymbol{\ell}_k, \mathbf{x}_{t_k})$ using the same on-policy confidence sampler we use at inference.</li> <li><strong>Token selection.</strong> <strong>Teacher-force</strong> the unmasked positions to the ground-truth values: $x_{t_{k+1}}^i = x_0^i$ for $i \in \mathcal{U}$. 
Without teacher forcing, the model is exposed to its own early mistakes, which it has no mechanism to correct during the rollout<d-cite key="kim_puma_2026"></d-cite>.</li> </ul> <p>This rollout scheme leaves the ideal minimizer unchanged from the standard MDM objective; intuitively, sampling positions on-policy and teacher-forcing their values does not alter what <em>the right answer is</em>, only the distribution of contexts the model practices on.</p> <h3 id="two-step-gradients">Two-step gradients</h3> <p>In our experiments we use the <strong>minimal non-trivial truncation horizon</strong>, $K{=}2$: the model rolls out two consecutive steps, and the second step’s cross-entropy gradient flows back through the relay into the first step’s hidden state. This is enough to give the model an explicit incentive to encode information at step $k$ that pays off at step $k{+}1$, without paying a long-horizon BPTT bill.</p> <p>Concretely, the two-step gradient decomposes into the usual local term plus a BPTT correction routed through the relay path; we give the full adjoint recursion in §3.2 of the paper.</p> <h2 id="validating-the-design-on-sudoku">Validating the Design on Sudoku</h2> <p>We first study <span class="sc">Relay</span> on <strong>Sudoku-Extreme</strong><d-cite key="wang2025hierarchical"></d-cite>, a benchmark of hard 9×9 Sudoku puzzles, each with a unique solution. Sudoku is a clean test bed: it has a small vocabulary, an obvious notion of <em>legality</em> at every partial state (no row/column/box repeats), and rich global constraints that should reward a model that can carry intermediate reasoning across steps.</p> <h3 id="setup">Setup</h3> <p>All methods use the <strong>same</strong> small Transformer backbone (~7M parameters) with rotary position embeddings, trained to convergence. 
To isolate the contribution of each ingredient of Algorithm 1 (Figure 2), we compare four objectives, each in tied- and untied-embedding variants:</p> <ul> <li><strong>MLM</strong> — standard uniform masked diffusion ($K{=}1$, no relay).</li> <li><strong>Rollout</strong> — $K{=}2$ on-policy rollouts, but <strong>no</strong> relay ($R_\theta \equiv 0$).</li> <li><strong><span class="sc">Relay</span> (sg)</strong> — $K{=}2$ relay rollouts, but <strong>stop-gradient</strong> on the relay between steps (no temporal credit).</li> <li><strong><span class="sc">Relay</span></strong> — the full method: $K{=}2$ BPTT through the relay.</li> </ul> <p>At inference, we sweep deterministic confidence thresholds $\tau \in \{0.05, 0.10, 0.15, 0.20, 0.25\}$ and trace each method’s <em>accuracy–NFE</em> frontier. Lower $\tau$ commits fewer cells per forward pass and so spends more NFEs.</p> <h3 id="the-accuracynfe-frontier">The accuracy–NFE frontier</h3> <figure class="relay-sudoku-figure"> <div class="container-fluid px-0"> <div class="row align-items-start"> <div class="col-md-6 mb-3 mb-md-0 text-center"> <img class="relay-sudoku-panel img-fluid" src="/diffusion/assets/img/blog/relay/sudoku_pareto_seed_bands_token_accuracy.png" alt="Sudoku token accuracy versus average rollout steps (NFE)."/> </div> <div class="col-md-6 text-center"> <img class="relay-sudoku-panel img-fluid" src="/diffusion/assets/img/blog/relay/sudoku_pareto_seed_bands_exact_match.png" alt="Sudoku exact-match accuracy versus average rollout steps (NFE)."/> </div> </div> <div class="row justify-content-center mt-2"> <div class="col-12 col-lg-10 text-center"> <img class="relay-sudoku-legend img-fluid" src="/diffusion/assets/img/blog/relay/sudoku_pareto_seed_bands_legend.png" alt="Legend for Sudoku Pareto curves: MLM, Rollout, Relay (stop-gradient), Relay."/> </div> </div> </div> <figcaption> <strong>Figure 3.</strong> Accuracy–NFE frontiers on Sudoku-Extreme. 
Each curve traces a single training method as we sweep the inference confidence threshold; shaded ribbons are ±1 sample std. across three training seeds. </figcaption> </figure> <p>Reading the transitions in Figure 3 from baseline to best-performing:</p> <ol> <li><strong>MLM → Rollout.</strong> Replacing uniform masking with an on-policy confidence-thresholded sampler under teacher forcing already moves the frontier substantially. The lesson: aligning <em>what is masked during training</em> with <em>what gets revealed at inference</em> matters even before introducing any relay.</li> <li><strong>Rollout → <span class="sc">Relay</span> (sg).</strong> Adding the relay channel — even with a stop-gradient between steps — yields the next big jump. A continuous handoff between forward passes provides value even when it is not yet <em>trained</em> for that handoff.</li> <li><strong><span class="sc">Relay</span> (sg) → <span class="sc">Relay</span>.</strong> Replacing the stop-gradient with $K{=}2$ BPTT through the relay gives a further separation and the best frontier across thresholds. <span class="kw">Training the relay to be useful, not just letting one exist, is the dominant effect.</span></li> </ol> <h3 id="why-bptt-helps-the-relay">Why BPTT helps the relay</h3> <p>What does the relay <em>learn</em> under BPTT? We trace the separation between <strong><span class="sc">Relay</span></strong> and <strong><span class="sc">Relay</span> (sg)</strong> to a concrete behavior: at the same threshold $\tau$, <strong><span class="sc">Relay</span> commits more cells per forward pass while keeping the partial board legal</strong>. 
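</p> <p>As a concrete reading of legality, the sketch below checks a partial 9×9 board, with <code>0</code> marking a still-masked cell, and counts repeated digits in each row, column, and 3×3 box. Counting one violation per (unit, repeated digit) pair is our illustrative assumption; the violation metric reported in the paper may be defined differently, and the function names are hypothetical.</p>

```python
from collections import Counter


def units(board):
    """Yield the 27 Sudoku units: 9 rows, 9 columns, and 9 3x3 boxes."""
    for r in range(9):
        yield board[r]
    for c in range(9):
        yield [board[r][c] for r in range(9)]
    for br in range(0, 9, 3):
        for bc in range(0, 9, 3):
            yield [board[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)]


def count_violations(board):
    """Number of (unit, digit) pairs in which a committed digit repeats.

    Masked cells (0) are ignored, so the check applies to any partial board.
    """
    total = 0
    for unit in units(board):
        counts = Counter(v for v in unit if v != 0)
        total += sum(1 for n in counts.values() if n > 1)
    return total


def is_legal(board):
    """A partial board is legal when no unit repeats a committed digit."""
    return count_violations(board) == 0
```

<p>Because masked cells are skipped, the same check can be applied after every denoising step, which is what makes legality a useful per-step diagnostic rather than only a final-answer metric.</p> <p>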
A board is <em>legal</em> when no row, column, or 3×3 box contains a repeated digit; legality is necessary (but not sufficient) for correctness, and is defined at every intermediate denoising step, not only at the end.</p> <p>At the matched threshold $\tau = 0.15$ on a deduction-only subset of 2,000 puzzles, <span class="sc">Relay</span> produces a fully legal final board <span class="kw"><strong>74.8%</strong></span> of the time versus <strong>70.7%</strong> for <span class="sc">Relay</span> (sg) <span class="kw">(+4.1 pp)</span>, and incurs <span class="kw"><strong>15% fewer</strong></span> row/column/box violations across the rollout (0.90 vs. 1.06 per puzzle). The improvement holds uniformly across both <em>Advanced</em> and <em>Master</em> difficulty strata.</p> <p>In other words, <strong>BPTT teaches the relay to keep the partial board self-consistent under more aggressive unmasking</strong>: at the same confidence threshold, <span class="sc">Relay</span> commits more cells per pass while still honoring the global constraints, so the rollout reaches the same accuracy in <strong>fewer total forward passes</strong> — exactly the outward shift of the frontier in Figure 3.</p> <h2 id="scaling-to-fast-dllm-v2">Scaling to Fast-dLLM v2</h2> <p>Sudoku tells us <em>what</em> <span class="sc">Relay</span> buys when trained from scratch. To see whether the same recipe holds at LLM scale, we <strong>adapt a pretrained DLM</strong> — Fast-dLLM v2 (1.5B)<d-cite key="wu2025fastdllmv2efficientblockdiffusion"></d-cite> — into a <span class="sc">Relay</span> model with a small amount of supervised finetuning.</p> <h3 id="block-diffusion-and-kv-caching">Block diffusion and KV caching</h3> <p>Fast-dLLM v2 combines <strong>block-autoregressive decoding</strong> in the style of Block Diffusion<d-cite key="arriola_block_2024"></d-cite> with KV caching for previously decoded blocks<d-cite key="wu2025fastdllm"></d-cite>. 
To make our two-step relay rollout compatible with this:</p> <ul> <li><strong>Per-block rollout.</strong> The $K{=}2$ rollout runs <strong>only inside the active block</strong>, leaving previously decoded blocks frozen so their inter-block KV cache is reused unchanged across both passes.</li> <li><strong>Mask-only relay updates.</strong> Within the active block we update the relay state only at positions that are <em>still masked</em>. Already-committed sub-block tokens still contribute attention, but their relay entries are not overwritten — keeping within-block sub-block KV cache entries valid as the block fills in.</li> </ul> <p>With these two adaptations, Algorithm 1 plugs in cleanly without breaking Fast-dLLM v2’s acceleration story.</p> <h3 id="coding-results">Coding results</h3> <p>We finetune all parameters for 200 optimizer steps at effective batch size 32 on a 60k-example 40/60 code/math mixture of OpenCodeInstruct and OpenMathInstruct-2 (details in the paper). Inference follows Fast-dLLM v2’s confidence-based parallel decoding at $\tau=0.85$.</p> <table> <thead> <tr> <th>Method</th> <th style="text-align: center">HumanEval (Base)</th> <th style="text-align: center">HumanEval (Plus)</th> <th style="text-align: center">HumanEval NFE</th> <th style="text-align: center">MBPP (Base)</th> <th style="text-align: center">MBPP (Plus)</th> <th style="text-align: center">MBPP NFE</th> </tr> </thead> <tbody> <tr> <td>Fast-dLLM v2 (1.5B)</td> <td style="text-align: center">38.4%</td> <td style="text-align: center">35.4%</td> <td style="text-align: center">178.1</td> <td style="text-align: center">46.8%</td> <td style="text-align: center">39.7%</td> <td style="text-align: center">133.0</td> </tr> <tr> <td>  Vanilla SFT</td> <td style="text-align: center">38.4%</td> <td style="text-align: center">34.1%</td> <td style="text-align: center">130.7</td> <td style="text-align: center">43.9%</td> <td style="text-align: center">38.1%</td> <td style="text-align: 
center">84.8</td> </tr> <tr> <td>  <span class="sc">Relay</span> (sg)</td> <td style="text-align: center">38.4%</td> <td style="text-align: center">35.4%</td> <td style="text-align: center">104.4</td> <td style="text-align: center">43.1%</td> <td style="text-align: center">39.2%</td> <td style="text-align: center">80.1</td> </tr> <tr> <td>  <strong><span class="sc">Relay</span></strong></td> <td style="text-align: center"><strong>42.1%</strong></td> <td style="text-align: center"><strong>37.2%</strong></td> <td style="text-align: center"><strong>88.3</strong></td> <td style="text-align: center"><strong>46.6%</strong></td> <td style="text-align: center"><strong>41.5%</strong></td> <td style="text-align: center"><strong>78.8</strong></td> </tr> </tbody> </table> <p>The pattern from Sudoku reappears at LLM scale: <strong><span class="sc">Relay</span> attains the best accuracies <em>and</em> the lowest NFE</strong> among adapted methods on both HumanEval and MBPP, including their EvalPlus<d-cite key="liu2023evalplus"></d-cite> (Plus) variants. Notably, on HumanEval, <strong><span class="sc">Relay</span> surpasses Vanilla-SFT’s accuracy while spending <span class="kw">32% fewer NFEs</span></strong> (88.3 vs. 130.7) — the relay-trained model is <em>simultaneously</em> more accurate and faster than its non-relay counterpart finetuned on identical data.</p> <h2 id="training-memory">Training Memory</h2> <p>A reasonable worry is that BPTT through $K{=}2$ forward passes doubles training memory. We profiled one micro-step on an A100 80GB, and it does not — and the reason turns out to be <strong>structural</strong>, rooted in how Fast-dLLM v2’s <em>vanilla SFT</em> forward is shaped versus how <span class="sc">Relay</span>’s rollout forwards are shaped.</p> <figure> <img src="/diffusion/assets/img/blog/relay/gpu_memory_vanilla_vs_relay_c40m60_rank0.png" alt="GPU memory: vanilla SFT vs. 
Relay during one micro-step."/> <figcaption> <strong>Figure 4.</strong> GPU memory during one training micro-step of Fast-dLLM v2 on an A100 80GB. Solid lines show live GPU memory at every decoder-layer fwd/bwd hook; dashed lines show the running high-water mark within the micro-step. Note that <span class="sc">Relay</span>'s first forward (red <code>fwd1</code>) plateaus at roughly <em>half</em> the forward-pass footprint of vanilla SFT's single forward (black <code>fwd1</code>) — the direct consequence of the input-shape asymmetry below. </figcaption> </figure> <p><strong>The shape asymmetry.</strong> Both regimes inherit BD3-LM’s<d-cite key="arriola_block_2024"></d-cite> $[\mathbf{x}_t \,|\, \mathbf{x}_0]$ layout, which concatenates the noised and clean sequences and so doubles the <strong>sequence</strong> dimension to $2L$. Vanilla SFT, however, <em>additionally</em> makes a complementary-mask copy along the <strong>batch</strong> dimension, doubling it to $2B$. So a single vanilla forward runs at shape <span class="kw">$(2B, 2L)$</span>, whereas each of <span class="sc">Relay</span>’s two rollout forwards runs at the <em>undoubled</em> batch, <span class="kw">$(B, 2L)$</span>. That is exactly why, in Figure 4, <span class="sc">Relay</span>’s red <code class="language-plaintext highlighter-rouge">fwd1</code> adds about half the activation memory of vanilla’s black <code class="language-plaintext highlighter-rouge">fwd1</code>: same $2L$, half the batch. Two <span class="sc">Relay</span> forwards at $(B, 2L)$ together demand memory comparable to vanilla’s <em>one</em> forward at $(2B, 2L)$.</p> <p><strong>Two asymmetries that roughly cancel.</strong> The <strong>binding peak</strong> of the micro-step in both regimes is the <strong>cross-entropy backward through the vocab projection head</strong> (<code class="language-plaintext highlighter-rouge">lm_head</code>), which materializes a $B \times T \times V$ fp32 grad-of-logits buffer. 
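</p> <p>The size of that buffer is plain arithmetic: batch × tokens × vocabulary × 4 bytes for fp32. The helper below (a hypothetical name, not part of any profiler) reproduces the figures quoted next, assuming the per-device batch size of 2, $T=2048$, and $V=151{,}936$ stated in this section, and assuming that vanilla SFT's complementary-mask copy doubles the batch seen by the <code>lm_head</code> backward.</p>

```python
GIB = 2 ** 30  # bytes per GiB


def grad_logits_gib(batch, tokens, vocab, bytes_per_elem=4):
    """Size in GiB of a dense batch x tokens x vocab gradient-of-logits buffer."""
    return batch * tokens * vocab * bytes_per_elem / GIB


B, T, V = 2, 2048, 151_936  # per-device batch, tokens per row, vocabulary size
vanilla = grad_logits_gib(2 * B, T, V)  # complementary-mask copy doubles the batch
relay = grad_logits_gib(B, T, V)        # each Relay forward keeps the batch at B
print(f"vanilla {vanilla:.2f} GiB, relay {relay:.2f} GiB")
# prints: vanilla 4.64 GiB, relay 2.32 GiB
```

<p>The factor of two between the buffers comes entirely from the batch dimension; the token count and vocabulary are identical in both regimes.</p> <p>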
Because vanilla carries the doubled batch into that backward, its buffer is <em>twice as large</em> as <span class="sc">Relay</span>’s: at per-device batch size $B=2$ (doubled to 4 in vanilla) with $T=2048$ and $V=151{,}936$, it is $[4, 2048, 151936]\times 4\,\text{B} \approx 4.6$ GiB for vanilla versus only $\approx 2.3$ GiB for <span class="sc">Relay</span>. <span class="sc">Relay</span> spends its savings on the activation side instead: live memory rises by roughly 5 GiB during <code class="language-plaintext highlighter-rouge">fwd2</code>, because forward 1’s saved activations and the relay state $\mathbf{h}$ must coexist with forward 2 so that credit can be routed through both passes. Most of that increase is autograd intermediates rather than saved-for-backward state, and PyTorch releases it in a single step <em>before</em> the <code class="language-plaintext highlighter-rouge">lm_head</code> spike fires.</p> <p>The result: peak memory is <strong><span class="kw">20.1 GiB for <span class="sc">Relay</span> vs. 21.2 GiB for vanilla SFT</span></strong> — within ~1 GiB. The near-tie is <strong>structural, not allocator-timing noise</strong>: <span class="sc">Relay</span> simply trades vanilla’s in-forward batch doubling for an explicit second pass. BPTT through $K{=}2$ therefore does <strong>not</strong> double peak memory in this setup (gradient checkpointing, ZeRO-3, non-fused CE), and we expect peak memory to stay comparable whenever the vanilla baseline already pays a doubled-batch forward and the <code class="language-plaintext highlighter-rouge">lm_head</code> backward dominates.</p> <h2 id="takeaways">Takeaways</h2> <p>The hard-reset bottleneck in masked diffusion has a clean fix: give the model a <strong>differentiable channel between steps</strong>, and <em>train it to be useful</em> with truncated BPTT.</p> <p>Stepping back:</p> <ul> <li>A continuous relay channel is the structural minimum needed to break the hard reset. 
Even with a stop-gradient it already improves the Pareto frontier.</li> <li>BPTT through the relay is what makes the relay <em>learn what to remember</em>. On Sudoku, this shows up as more aggressive but still legal unmasking; on Fast-dLLM v2, as simultaneously higher accuracy and fewer NFEs.</li> <li>The recipe is <strong>architecture-agnostic</strong>, leaves the inference decoding procedure unchanged, and is compatible with prevalent DLM acceleration techniques like block diffusion and KV caching — so it composes with the same engineering stack used in state-of-the-art diffusion LLMs.</li> </ul> <p>If you found this useful, please consider citing our paper:</p> <div class="language-bibtex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nc">@inproceedings</span><span class="p">{</span><span class="nl">rozonoyer2026relay</span><span class="p">,</span>
  <span class="na">title</span>     <span class="p">=</span> <span class="s">{Learned Relay Representations for Forward-Thinking Discrete Diffusion Models}</span><span class="p">,</span>
  <span class="na">author</span>    <span class="p">=</span> <span class="s">{Rozonoyer, Benjamin and Minniti, Jacopo and Patel, Dhruvesh and Band, Neil and Bose, Joey and Rudner, Tim G. J. and McCallum, Andrew}</span><span class="p">,</span>
  <span class="na">booktitle</span> <span class="p">=</span> <span class="s">{Structured Probabilistic Inference \&amp; Generative Modeling (SPIGM) and Frontiers in Generative AI (FoGen) Workshops at ICML}</span><span class="p">,</span>
  <span class="na">year</span>      <span class="p">=</span> <span class="s">{2026}</span>
<span class="p">}</span>
</code></pre></div></div>]]></content><author><name>Benjamin Rozonoyer</name></author><category term="discrete-diffusion"/><category term="masked-diffusion"/><category term="bptt"/><category term="language-models"/><summary type="html"><![CDATA[Forward-thinking discrete diffusion via a learned latent channel trained with truncated BPTT.]]></summary></entry></feed>