跳转至

Generalized Interpolating Discrete Diffusion

会议: ICML2025
arXiv: 2503.04482
代码: dvruette/gidd
领域: 离散扩散模型 / 语言建模
关键词: 离散扩散, 掩码扩散, 均匀噪声, 自纠错, ELBO, 语言模型

一句话总结

提出广义插值离散扩散框架 GIDD,将掩码扩散 (MDM) 推广为支持任意时变混合分布的扩散族,通过结合掩码与均匀噪声赋予模型自纠错能力,在扩散语言建模中取得 compute-matched SOTA。

研究背景与动机

自回归语言模型(如 GPT)在生成过程中按顺序逐 token 预测,存在两个根本局限:(1) 生成长度 \(N\) 的序列必须 \(N\) 次前向推理;(2) 已生成的 token 无法修改,一旦出错则错误不可逆地传播到后续。

离散扩散模型通过逐步向数据添加噪声再学习反转过程来生成,将推理步数与序列长度解耦。然而目前最流行的掩码扩散模型 (MDM) 只使用 [MASK] token 作为噪声——token 一旦被填入就不再改变,本质上重新引入了自回归的"不可逆"问题。

本文的核心动机:如果在扩散过程中混入一部分均匀噪声(将 token 随机替换为其他 token 而非掩码),模型就需要学会区分"正确"与"错误" token,从而获得自纠错能力。为此需要一个统一的理论框架来支持任意类型的噪声设计。

方法详解

GIDD 前向过程

GIDD 将掩码扩散推广为广义插值形式。给定数据 \(x\),时刻 \(t\) 的边际转移为:

\[q_t(z_t|x) = \mathrm{Cat}(z_t;\;\alpha_t \mathbf{x} + \beta_t \boldsymbol{\pi}_t)\]

其中 \(\alpha_t\) 为信噪比调度(从 1 单调递减到 0),\(\boldsymbol{\pi}_t\)时变混合分布,可以是任意概率分布。当 \(\boldsymbol{\pi}_t = \mathbf{m}\)(掩码的 one-hot)时退化为标准 MDM。

论文证明了存在马尔可夫链使得上述边际成立,并推导出累积转移矩阵的闭式解:

\[Q_t = \alpha_t I + \beta_t \boldsymbol{\pi}_t \mathbf{1}^\top\]

混合调度设计

实际采用掩码 + 均匀噪声的混合:

\[q_t(z_t|x) = \frac{1}{C_t}\big((1-t)\mathbf{x} + t\mathbf{m} + c_t \mathbf{u}\big)\]

其中 \(\mathbf{u} = \frac{1}{N-1}(\mathbf{1}-\mathbf{m})\) 为均匀分布,\(c_t = Bt^{\gamma/2}(1-t)^{\gamma/2}\) 控制均匀噪声量,\(p_u\)\(t=0.5\) 时均匀 token 的期望比例。\(p_u=0\) 回退到纯掩码扩散。

GIDD ELBO

基于连续时间马尔可夫链 (CTMC),推导出通用 ELBO,由两部分组成:

  1. KL 散度项:模型预测分布 \(q_t(\cdot|\mathbf{x}_\theta)\) 与真实条件分布 \(q_t(\cdot|x)\) 之间的 KL 散度
  2. IS 散度项:在采样点 \(z_t\) 处的逐点散度

两项同时为零当且仅当模型完美匹配真实分布,保证 ELBO 有全局最小值。

损失权重重整

ELBO 权重 \(w_t(z_t, x)\)\(t\to 0\)\(t\to 1\) 时指数增长,导致优化不稳定。论文提出两种方案:

  • Clamp: \(\tilde{w}_t^{\mathrm{clamp}} = \min(w_{\max}, w_t)\),简单截断,\(w_{\max}=1\)
  • Dynamic: \(\tilde{w}_t^{\mathrm{dyn}} = w_{\max}(1 + \delta_{z_t,m} + (\frac{B}{N}e^{-\lambda_t/2}-1)\delta_{z_t,x})\),保持不同 token 类型间的相对权重

Dynamic 权重 + weight decay (0.02) 的组合(称为 GIDD+)效果最佳。

自纠错采样

生成完成后,将全部已去噪序列 \(Z_{t_0}\) 送入模型,以温度 \(\tau\) 重采样;在所有与原始不同的 token 中选取模型置信度最高的一个进行替换。迭代此过程直至收敛。

实验关键数据

在 OpenWebText 上训练 110M (small) / 320M (base) 模型,使用 DiT 架构 + GPT2 tokenizer。

验证困惑度 (PPL)

模型 训练 tokens PPL ↓
GPT2 - 23.40
Llama 110M (retrain) 262B 16.11
MDLM 262B 23.21
MDM (reimpl.) 262B 23.36
GIDD+ (p_u=0.0) 262B 22.29

消融:权重方案的影响

方案 p_u=0.0 p_u=0.1 p_u=0.2
无重整 24.36 26.88 28.22
+ clamp 23.23 25.09 26.40
+ dynamic 23.24 23.90 24.64
+ weight decay (GIDD+) 23.05 23.67 24.38

自纠错效果 (base 模型)

\(p_u=0.2\) 模型经自纠错后:生成 PPL 从 214 降至 93.3(↓56%),self-accuracy 从 62.0% 升至 73.5%。纯掩码模型自纠错反而使质量下降。

GPT-4o 质量评分

模型 清晰度 语法 事实性 文风 创造性
GIDD (p_u=0.0) + 自纠错 -20.9% -19.3% -16.2% -21.1% -19.5%
GIDD (p_u=0.2) + 自纠错 +16.5% +16.6% +8.5% +13.4% +5.5%

零样本基准

GIDD+ (p_u=0.0) 在 ARC/BoolQ/Hellaswag/PIQA 等 7 个基准的平均准确率 39.30,超过 GPT2-small (38.77) 和 MDM (38.25)。

亮点与洞察

  1. 理论优雅:GIDD 是掩码扩散的严格推广,推导出闭式累积转移和 ELBO,有完整的全局最优性证明
  2. 自纠错是离散扩散相比自回归的真正差异化能力——不是简单的后处理,而是训练时通过均匀噪声自然获得
  3. 损失权重重整带来显著增益:dynamic weighting 将 \(p_u=0.2\) 的 PPL 从 28.22 拉到 24.64
  4. 低推理步数优势:32 步去噪时,\(p_u=0.1\) 的生成 PPL (387) 远优于纯掩码 (904)
  5. 似然评估 (PPL/benchmark) 与生成质量并不完全一致——掩码模型似然更好但生成质量更差,提示扩散语言模型需要更全面的评估体系

局限与展望

  1. 规模有限:最大 320M 参数,在大规模下混合噪声的效果尚不确定
  2. 均匀噪声带来的 PPL 损失需要更多容量来弥补,scaling 行为需进一步验证
  3. 自纠错依赖迭代推理,增加了生成延迟
  4. 生成质量评价依赖"生成 PPL"(用更大模型打分),该指标本身可能有偏差
  5. 混合调度的超参(\(p_u\)\(\gamma\)\(B\))设计仍较启发式,搜索空间未充分探索
  6. 仅在语言建模验证,对代码生成、蛋白质序列等其他离散数据的适用性待探索

相关工作与启发

  • Austin et al. (2023): 首个将扩散 ELBO 引入离散马尔可夫链
  • MDLM / MD4 (Sahoo et al., 2024; Shi et al., 2024): 简化掩码扩散目标
  • BERT (Devlin et al., 2019): 掩码+随机替换的预训练启发了 GIDD 的混合噪声设计
  • 离散流匹配 (Gat et al., 2024): 将 flow matching 范式适配到离散数据

评分

  • 新颖性: ⭐⭐⭐⭐ — 将掩码扩散推广为统一框架并引入可控均匀噪声,方向新颖
  • 实验充分度: ⭐⭐⭐⭐ — 消融全面,既有似然又有生成质量,但规模受限
  • 写作质量: ⭐⭐⭐⭐⭐ — 理论推导清晰严谨,图表直观
  • 价值: ⭐⭐⭐⭐ — 自纠错是离散扩散的重要突破,框架灵活性强

相关论文