Remasking Discrete Diffusion Models with Inference-Time Scaling¶
会议: NeurIPS 2025
arXiv: 2503.00307
代码: https://github.com/guanghanwang/remdm (有)
领域: Image Generation / Discrete Diffusion
关键词: 离散扩散模型, 重掩码采样, 推理时计算缩放, 迭代精修, 可控生成
一句话总结¶
提出 ReMDM 采样器,通过在生成过程中允许已解码 token 被重新掩码(remask),赋予离散掩码扩散模型迭代纠错能力,实现推理时计算缩放,在文本、图像和分子设计任务上显著提升采样质量。
研究背景与动机¶
扩散模型在图像和视频生成上取得了巨大成功,其核心优势之一是迭代精修——在多步生成中反复修正输出、修复错误。然而,当前最先进的离散扩散模型(特别是基于掩码/吸收态的 MDLM)存在一个根本性限制:一旦 token 被解码,就无法再更新(failure to remask property)。即使解码时引入了错误,该 token 也会被"锁定",类似于自回归模型的顺序生成局限。
这个限制带来三个问题:(1)并行解码时token间的独立性导致不一致错误(如"She sell"而非"She sells"或"They sell");(2)无法通过增加采样步数来有效提升质量(推理时计算缩放受限);(3)可控生成能力受限。
本文的核心 idea 是:设计一个新的后验分布,允许已解码的 token 以一定概率被重新掩码,从而实现迭代纠错。关键的巧妙之处在于,这个新后验保持了与 MDLM 相同的边际分布,因此可以直接复用预训练的 MDLM 权重,无需重新训练。
方法详解¶
整体框架¶
ReMDM 在 MDLM 的基础上,引入一个参数 σ_t 控制重掩码概率。当 σ_t=0 时退化为标准 MDLM;当 σ_t>0 时,已解码的 token 有概率被重新掩码,从而获得迭代精修的能力。整个方法可以直接作为采样器应用于预训练的 MDLM 模型上。
关键设计¶
-
重掩码后验(Remasking Posterior):对于已解码的 token z_t ≠ m,后验为 q(z_s|z_t=x, x) = (1-σ_t)x + σ_t·m,即以 σ_t 的概率重新掩码。对于仍被掩码的 token,后验经过精心设计以保持与 MDLM 相同的边际分布 q(z_t|x)。这是一个核心定理(Theorem 3.1),保证了可以复用预训练权重。ReMDM 是一个非马尔可夫过程(类似于 DDIM 之于 DDPM 的关系)。
-
σ_t 的设计策略:作者提出了多种策略:
- Max-Capped (ReMDM-cap):将 σ_t 上限截断为常数 η_cap
- Rescaled (ReMDM-rescale):σ_t = η_rescale · σ_t^max,通过缩放因子控制重掩码强度
- Confidence-Based (ReMDM-conf):根据模型对每个 token 预测的置信度分配重掩码概率——置信度低的 token 更可能被重掩码,这是一个非常直觉的设计
-
开关策略(Turn On/Off):
- Switch:在 t_switch 之后才开启重掩码,先用标准 MDLM 生成初步候选
- Loop:分三个阶段——(1) 标准 MDLM 解码;(2) 保持 α 不变,在循环中反复重掩码和预测(纠错阶段);(3) 用标准 MDLM 解码剩余 token。这是最强的策略,直觉是先生成草稿、再反复修改
-
与预测器-校正器的关系:作者证明 ReMDM 更通用——FB 校正器和 DFM 校正器都是 ReMDM 的特例(Proposition 4.2, 4.3),且 ReMDM 可以处理 α_t 为常数的情况(Proposition 4.4),这是 DFM 无法做到的。
损失函数 / 训练策略¶
ReMDM 的 NELBO 为 MDLM 目标的一个重加权版本,仅多了 (1-σ_t) 因子。由于 σ_t=0 时退化为 MDLM,且实验验证训练目标相似时性能可比,推荐方案是直接复用预训练 MDLM 权重,仅在推理时使用 ReMDM 采样器,无需重训。
实验关键数据¶
主实验¶
| 任务/数据集 | 指标 | ReMDM | 之前SOTA | 提升 |
|---|---|---|---|---|
| OWT 文本生成 (T=4096) | MAUVE ↑ | 0.656 | 0.269 (DFM) | 2.4× |
| OWT 文本生成 (T=1024) | MAUVE ↑ | 0.403 | 0.254 (DFM) | 1.6× |
| OWT 快速采样 (T=512) | MAUVE ↑ | 0.350 | 0.211 (DFM) | 1.7× |
| ImageNet 生成 (T=64) | FID ↓ | 4.45 | 4.69 (MDLM) | 更优 |
| ImageNet 生成 (T=64) | IS ↑ | 209.45 | 196.38 (MaskGiT) | +6.7% |
| LLaDA Countdown | pass@1% | 46.1 | 45.2 (LLaDA) | +0.9 |
消融实验¶
| 配置 | MAUVE (OWT T=4096) | 说明 |
|---|---|---|
| MDLM (baseline) | 0.035 | 无重掩码 |
| + 64-bit precision | 提升 | 避免多样性受限 |
| + Nucleus sampling | 提升 | Top-p=0.9 关键质量提升 |
| + ReMDM remasking | 大幅提升 | 最大贡献来源 |
| + Loop strategy | 0.656 | 完整 ReMDM |
关键发现¶
- ReMDM 在推理时计算缩放方面表现优异:增加步数 T 能持续提升质量,而 MDLM 和校正器方法会饱和
- 在分子设计任务中,ReMDM 将可控生成的 novelty-property Pareto 前沿推向更优区域
- 将 ReMDM 应用于 LLaDA 8B 等大型 dLLM 时,也能显著提升下游任务表现
亮点与洞察¶
- 零额外训练成本:直接复用预训练 MDLM 权重,仅改变采样策略即可获得显著提升,这极大降低了使用门槛
- 理论优雅:通过保持边际分布不变,构建了与 MDLM 兼容的非马尔可夫过程,同时统一了 FB、DFM 等校正器为特例
- DDIM 的离散版本:ReMDM 之于 MDLM 就像 DDIM 之于 DDPM,为离散扩散提供了更灵活的采样空间
局限与展望¶
- 重掩码策略(σ_t schedule、loop 参数)的超参搜索空间较大,需要针对不同任务调参
- 重掩码增加了采样步数的有效使用,但相应地需要更多计算步来达到最佳效果
- 目前主要验证在文本、离散化图像和分子上,更大规模语言模型和连续信号上的扩展有待探索
相关工作与启发¶
- vs DDIM: DDIM 在连续域通过非马尔可夫过程实现灵活采样;ReMDM 将这一思想迁移到离散域的吸收态扩散
- vs FB/DFM correctors: ReMDM 更通用,能处理 αt 为常数的情况(loop策略),这些校正器做不到
- vs MaskGiT: MaskGiT 基于模型置信度解码但不允许重掩码;ReMDM的confidence-based schedule 融合了两者优势
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 从概率模型角度优雅地解决了离散扩散的核心限制,理论推导完整
- 实验充分度: ⭐⭐⭐⭐⭐ 覆盖文本、图像、分子设计三个领域,加上LLaDA大模型验证
- 写作质量: ⭐⭐⭐⭐⭐ 动机清晰、理论推导严谨、实验展示充分
- 价值: ⭐⭐⭐⭐⭐ 零训练成本即可提升离散扩散模型性能,对 dLLM 社区有直接影响
相关论文¶
- [NeurIPS 2025] Inference-Time Scaling for Flow Models via Stochastic Generation and Rollover Budget Forcing
- [NeurIPS 2025] Learnable Sampler Distillation for Discrete Diffusion Models
- [NeurIPS 2025] Progressive Inference-Time Annealing of Diffusion Models for Sampling from Boltzmann Densities
- [NeurIPS 2025] Non-Markovian Discrete Diffusion with Causal Language Models
- [NeurIPS 2025] Constrained Discrete Diffusion