跳转至

Better Estimation of the Kullback-Leibler Divergence Between Language Models

会议: NeurIPS 2025 arXiv: 2504.10637 代码: https://github.com/rycolab/kl-rb 领域: LLM 训练 / 统计估计 关键词: KL散度估计, Rao-Blackwell化, RLHF, 方差减少, 语言模型

一句话总结

提出 KL 散度的 Rao-Blackwell 化 Monte Carlo 估计器——在每个位置对下一个 token 的分布求精确 KL(而非只用采样的 token),理论证明无偏且方差严格不超过标准 MC 估计器,零额外计算开销,在 RLHF 情感控制任务中使训练更稳定、模型更频繁出现在 Pareto 前沿(78%)。

研究背景与动机

  1. 领域现状:KL 散度在 RLHF(正则化项)、可解释性(分布偏移度量)、知识蒸馏中广泛使用。精确计算语言模型间 KL 不可行(\(\Sigma^*\) 可数无穷)。

  2. 现有痛点:(a)标准 MC 估计器 \(\mu_{mc} = \frac{1}{M}\sum_m \log\frac{p(Y^{(m)})}{q(Y^{(m)})}\) 方差高、可能为负值;(b)Schulman 提出的控制变量方法(\(\alpha=1\))保证非负但方差可能爆炸(实验证实);(c)训练中 KL 估计不稳定会导致 RLHF 不稳定。

  3. 核心矛盾:MC 估计器只用采样的完整字符串计算 log-ratio,浪费了前向传播已经产生的每个位置上完整的 next-token 分布信息。

  4. 本文要解决什么? 在零额外计算开销下显著降低 KL 估计方差。

  5. 切入角度:Rao-Blackwell 化——在每个位置 \(n\) 对 next-token 分布求精确 KL(\(|\bar{\Sigma}|\) 个 token 的求和),而非只用采样到的那个 token。

  6. 核心 idea\(\mu_{rb} = \frac{1}{M}\sum_m \sum_{n=1}^{|Y^{(m)}|} KL(\vec{p}(\cdot|Y^{(m)}_{<n}) \| \vec{q}(\cdot|Y^{(m)}_{<n}))\)

方法详解

整体框架

标准 MC 在采样字符串级别计算 log-ratio。RB 估计器在每个位置对完整 next-token 分布计算精确 KL,然后对位置求和、对样本求平均。前向传播已经产生了每个位置的完整分布,所以额外计算仅为 \(O(MN|\bar{\Sigma}|)\)(可忽略)。

关键设计

  1. Rao-Blackwell 化 KL 估计器:
  2. 做什么:在每个 token 位置精确计算分布级 KL 而非仅用采样 token
  3. 核心思路:将 MC 估计器的每一步 \(\log\frac{\vec{p}(Y_n|Y_{<n})}{\vec{q}(Y_n|Y_{<n})}\) 替换为 \(KL(\vec{p}(\cdot|Y_{<n}) \| \vec{q}(\cdot|Y_{<n}))\)——对整个 vocabulary 求和而非只看采样到的 token
  4. 设计动机:Rao-Blackwell 定理保证条件期望的方差不超过原始方差。方差减少来源:不再依赖采样到的特定 token,而是利用完整分布信息

  5. 理论保证(Theorem 2):

  6. 无偏性:\(\mathbb{E}[\mu_{rb}] = KL(p \| q)\)
  7. 方差减少:\(Var[\mu_{rb}] \leq Var[\mu_{mc}]\)
  8. 非负性:每项都是精确 KL(非负),所以总和非负——不像 MC 可能产生负值

  9. RB 梯度估计器:

  10. 做什么:将 RB 扩展到 KL 梯度估计(用于 RLHF 训练循环)
  11. 核心思路:通过 Theorem 4 推导 KL 梯度的 local decomposition,然后 RB 化得到 \(\delta_{rb}\)(Theorem 5 证明 \(\mathbb{E}[\|\delta_{rb} - G\|^2] \leq \mathbb{E}[\|\delta_{mc} - G\|^2]\)
  12. 设计动机:RLHF 循环中梯度的方差直接影响训练稳定性

损失函数 / 训练策略

  • RLHF 目标:期望奖励 - \(\beta \cdot KL(p_\theta \| q)\)
  • 用 RB 估计器替换 MC 估计器计算 KL 项及其梯度

实验关键数据

KL 估计质量(GPT-2 情感控制)

估计器 M=1 均值±std M=5 均值±std M=10 均值±std
\(\mu_{mc}\) 6.76±0.16 6.76±0.07 6.76±0.05
\(\mu_{cv}\) (\(\alpha=1\)) 6.28±2.54 6.28±1.13 6.28±0.79
\(\mu_{cv}\) (最优\(\alpha\)) 6.76±0.16 6.76±0.07 6.76±0.05
\(\mu_{rb}\) 6.76±0.11 6.76±0.05 6.76±0.03

RB 在所有样本量下标准差最低。Schulman 的 CV (\(\alpha=1\)) 有偏且方差爆炸。

RLHF 训练稳定性

估计器 奖励稳定性 KL 稳定性 Pareto 前沿占比
MC 高方差(5 次实验差异大) 高方差 22%
RB 稳定 稳定 78%

KL < 5 区域:RB 模型占 Pareto 前沿的 95%

梯度方差实验

估计器 梯度范数方差
MC 59.90
RB 45.44 (-24.6%)

关键发现

  • RB 是唯一实现有意义方差减少的方法:CV 方法在大多数 prompt 上无法减少方差甚至增加
  • \(\alpha=1\) 的 Schulman 估计器有严重问题:某些 prompt 上 \(Var[g(Y)]\) 无界,导致偏差和巨大方差
  • RB 使 RLHF 显著更稳定:5 次重复实验的奖励和 KL 曲线几乎重叠(MC 方差很大)
  • 零额外计算开销:前向传播已经产生了完整分布

亮点与洞察

  • "已有信息不用白不用"的深刻洞察——前向传播产生了完整的 next-token 分布,MC 只用了采样到的一个 token,RB 用了全部
  • 非负性是免费赠品:MC 可能为负(导致 RLHF 不稳定),Schulman 方法牺牲方差换非负。RB 既非负又低方差
  • 对 RLHF 开源库(trl/OpenRLHF 等)有直接可使用的改进——代码片段已提供

局限性 / 可改进方向

  • RLHF 实验仅在 GPT-2 上验证(计算资源限制导致需要训练 36 个模型做显著性检验)
  • 词表大小 \(|\bar{\Sigma}|\) 很大时精确 KL 计算仍有开销(虽然远小于前向传播)
  • 假设 \(KL(p \| q) < \infty\)(实践中可能不满足)

相关工作与启发

  • vs Schulman (2020) 控制变量\(\alpha=1\) 保证非负但方差可能无界。RB 同时非负+低方差
  • vs Horvitz-Thompson:另一种无偏估计器但无方差改善
  • vs 解析计算:仅对特殊模型(如 PFSA)可行。RB 对任何神经语言模型适用

评分

  • 新颖性: ⭐⭐⭐⭐ Rao-Blackwell 化是经典技术但在 LM KL 估计的应用新颖且关键
  • 实验充分度: ⭐⭐⭐⭐⭐ 估计质量+RLHF 训练+Pareto 分析+梯度方差
  • 写作质量: ⭐⭐⭐⭐⭐ 理论推导极其清晰,定理-证明结构优美
  • 价值: ⭐⭐⭐⭐⭐ 对 RLHF 实践有直接可落地的改进