Better Estimation of the Kullback-Leibler Divergence Between Language Models¶
会议: NeurIPS 2025 arXiv: 2504.10637 代码: https://github.com/rycolab/kl-rb 领域: LLM 训练 / 统计估计 关键词: KL散度估计, Rao-Blackwell化, RLHF, 方差减少, 语言模型
一句话总结¶
提出 KL 散度的 Rao-Blackwell 化 Monte Carlo 估计器——在每个位置对下一个 token 的分布求精确 KL(而非只用采样的 token),理论证明无偏且方差严格不超过标准 MC 估计器,零额外计算开销,在 RLHF 情感控制任务中使训练更稳定、模型更频繁出现在 Pareto 前沿(78%)。
研究背景与动机¶
-
领域现状:KL 散度在 RLHF(正则化项)、可解释性(分布偏移度量)、知识蒸馏中广泛使用。精确计算语言模型间 KL 不可行(\(\Sigma^*\) 可数无穷)。
-
现有痛点:(a)标准 MC 估计器 \(\mu_{mc} = \frac{1}{M}\sum_m \log\frac{p(Y^{(m)})}{q(Y^{(m)})}\) 方差高、可能为负值;(b)Schulman 提出的控制变量方法(\(\alpha=1\))保证非负但方差可能爆炸(实验证实);(c)训练中 KL 估计不稳定会导致 RLHF 不稳定。
-
核心矛盾:MC 估计器只用采样的完整字符串计算 log-ratio,浪费了前向传播已经产生的每个位置上完整的 next-token 分布信息。
-
本文要解决什么? 在零额外计算开销下显著降低 KL 估计方差。
-
切入角度:Rao-Blackwell 化——在每个位置 \(n\) 对 next-token 分布求精确 KL(\(|\bar{\Sigma}|\) 个 token 的求和),而非只用采样到的那个 token。
-
核心 idea:\(\mu_{rb} = \frac{1}{M}\sum_m \sum_{n=1}^{|Y^{(m)}|} KL(\vec{p}(\cdot|Y^{(m)}_{<n}) \| \vec{q}(\cdot|Y^{(m)}_{<n}))\)。
方法详解¶
整体框架¶
标准 MC 在采样字符串级别计算 log-ratio。RB 估计器在每个位置对完整 next-token 分布计算精确 KL,然后对位置求和、对样本求平均。前向传播已经产生了每个位置的完整分布,所以额外计算仅为 \(O(MN|\bar{\Sigma}|)\)(可忽略)。
关键设计¶
- Rao-Blackwell 化 KL 估计器:
- 做什么:在每个 token 位置精确计算分布级 KL 而非仅用采样 token
- 核心思路:将 MC 估计器的每一步 \(\log\frac{\vec{p}(Y_n|Y_{<n})}{\vec{q}(Y_n|Y_{<n})}\) 替换为 \(KL(\vec{p}(\cdot|Y_{<n}) \| \vec{q}(\cdot|Y_{<n}))\)——对整个 vocabulary 求和而非只看采样到的 token
-
设计动机:Rao-Blackwell 定理保证条件期望的方差不超过原始方差。方差减少来源:不再依赖采样到的特定 token,而是利用完整分布信息
-
理论保证(Theorem 2):
- 无偏性:\(\mathbb{E}[\mu_{rb}] = KL(p \| q)\)
- 方差减少:\(Var[\mu_{rb}] \leq Var[\mu_{mc}]\)
-
非负性:每项都是精确 KL(非负),所以总和非负——不像 MC 可能产生负值
-
RB 梯度估计器:
- 做什么:将 RB 扩展到 KL 梯度估计(用于 RLHF 训练循环)
- 核心思路:通过 Theorem 4 推导 KL 梯度的 local decomposition,然后 RB 化得到 \(\delta_{rb}\)(Theorem 5 证明 \(\mathbb{E}[\|\delta_{rb} - G\|^2] \leq \mathbb{E}[\|\delta_{mc} - G\|^2]\))
- 设计动机:RLHF 循环中梯度的方差直接影响训练稳定性
损失函数 / 训练策略¶
- RLHF 目标:期望奖励 - \(\beta \cdot KL(p_\theta \| q)\)
- 用 RB 估计器替换 MC 估计器计算 KL 项及其梯度
实验关键数据¶
KL 估计质量(GPT-2 情感控制)¶
| 估计器 | M=1 均值±std | M=5 均值±std | M=10 均值±std |
|---|---|---|---|
| \(\mu_{mc}\) | 6.76±0.16 | 6.76±0.07 | 6.76±0.05 |
| \(\mu_{cv}\) (\(\alpha=1\)) | 6.28±2.54 | 6.28±1.13 | 6.28±0.79 |
| \(\mu_{cv}\) (最优\(\alpha\)) | 6.76±0.16 | 6.76±0.07 | 6.76±0.05 |
| \(\mu_{rb}\) | 6.76±0.11 | 6.76±0.05 | 6.76±0.03 |
RB 在所有样本量下标准差最低。Schulman 的 CV (\(\alpha=1\)) 有偏且方差爆炸。
RLHF 训练稳定性¶
| 估计器 | 奖励稳定性 | KL 稳定性 | Pareto 前沿占比 |
|---|---|---|---|
| MC | 高方差(5 次实验差异大) | 高方差 | 22% |
| RB | 稳定 | 稳定 | 78% |
KL < 5 区域:RB 模型占 Pareto 前沿的 95%!
梯度方差实验¶
| 估计器 | 梯度范数方差 |
|---|---|
| MC | 59.90 |
| RB | 45.44 (-24.6%) |
关键发现¶
- RB 是唯一实现有意义方差减少的方法:CV 方法在大多数 prompt 上无法减少方差甚至增加
- \(\alpha=1\) 的 Schulman 估计器有严重问题:某些 prompt 上 \(Var[g(Y)]\) 无界,导致偏差和巨大方差
- RB 使 RLHF 显著更稳定:5 次重复实验的奖励和 KL 曲线几乎重叠(MC 方差很大)
- 零额外计算开销:前向传播已经产生了完整分布
亮点与洞察¶
- "已有信息不用白不用"的深刻洞察——前向传播产生了完整的 next-token 分布,MC 只用了采样到的一个 token,RB 用了全部
- 非负性是免费赠品:MC 可能为负(导致 RLHF 不稳定),Schulman 方法牺牲方差换非负。RB 既非负又低方差
- 对 RLHF 开源库(trl/OpenRLHF 等)有直接可使用的改进——代码片段已提供
局限性 / 可改进方向¶
- RLHF 实验仅在 GPT-2 上验证(计算资源限制导致需要训练 36 个模型做显著性检验)
- 词表大小 \(|\bar{\Sigma}|\) 很大时精确 KL 计算仍有开销(虽然远小于前向传播)
- 假设 \(KL(p \| q) < \infty\)(实践中可能不满足)
相关工作与启发¶
- vs Schulman (2020) 控制变量:\(\alpha=1\) 保证非负但方差可能无界。RB 同时非负+低方差
- vs Horvitz-Thompson:另一种无偏估计器但无方差改善
- vs 解析计算:仅对特殊模型(如 PFSA)可行。RB 对任何神经语言模型适用
评分¶
- 新颖性: ⭐⭐⭐⭐ Rao-Blackwell 化是经典技术但在 LM KL 估计的应用新颖且关键
- 实验充分度: ⭐⭐⭐⭐⭐ 估计质量+RLHF 训练+Pareto 分析+梯度方差
- 写作质量: ⭐⭐⭐⭐⭐ 理论推导极其清晰,定理-证明结构优美
- 价值: ⭐⭐⭐⭐⭐ 对 RLHF 实践有直接可落地的改进