跳转至

Q♯: Provably Optimal Distributional RL for LLM Post-Training

会议: NeurIPS 2025
arXiv: 2502.20548
代码: https://github.com/jinpz/q_sharp
领域: 强化学习
关键词: LLM后训练, 分布式强化学习, KL正则化, 值函数引导, 数学推理

一句话总结

提出 Q♯,一种基于分布式 RL 的值函数方法用于 KL 正则化 LLM 后训练,通过学习参考策略下的累积奖励分布来计算最优软 Q 函数引导生成,在数学推理任务上实现更高准确率和更低 KL 散度,并证明了方差相关的 PAC 收敛界。

研究背景与动机

领域现状:RL 后训练是 LLM 对齐和推理的核心阶段。主流方法采用策略梯度(PPO、DPO、RLOO),通过 KL 散度约束防止偏离参考策略 \(\pi^{ref}\),但这些方法计算开销大(需全量反向传播)

现有痛点: - 策略方法在 star-graph 实验中暴露致命缺陷:预训练学到的走捷径(随机选首个节点、准确率 1/d)无法被 REINFORCE/RPO 修复——策略梯度在低概率路径上梯度也低,形成恶性循环 - 现有值方法(CD/VAS)使用未正则化\(Q^{\pi^{ref},0}\) 引导 \(\pi^{ref}\),忽略了 KL 项,无法保证收敛到最优策略 - CD 对 \(\eta\) 极其敏感——\(\eta^{-1}\) 增大后 KL 急剧膨胀,性能反降

核心矛盾:策略方法无法修复预训练捷径,现有值方法目标函数不正确

切入角度:在确定性 MDP(覆盖 LLM 自回归生成)中,\(Q^{\star,\eta}\) 可通过参考策略下累积奖励分布的泛函直接计算,无需 TD 学习

核心 idea:学习 \(Z^\star\)\(\pi^{ref}\) 的累积奖励条件分布),通过简单泛函 \(Q^{\star,\eta} = \eta \ln \mathbb{E}_{z \sim Z^\star}[\exp(z/\eta)]\) 得到最优 Q 函数

方法详解

整体框架

Q♯ 是一个迭代式值函数学习算法:每轮 (1) 用当前引导策略 \(\pi^k\) roll-in 到时刻 \(h\);(2) 切换到 \(\pi^{ref}\) 完成剩余轨迹;(3) 记录各时刻的累积奖励加入聚合数据集;(4) 在聚合数据上最小化分布式损失更新 \(Z^\theta\)。推理时通过 \(\pi^{Z,\eta}(y|x) \propto \pi^{ref}(y|x) \cdot \mathbb{E}_{z \sim Z(x,y)}[\exp(z/\eta)]\) 引导生成。

关键设计

  1. 确定性 MDP 下的分布式简化

    • 功能:将最优 Q 函数计算简化为参考策略下的奖励分布学习
    • 核心思路:在确定性 MDP 中,Theorem 2.2 证明 \(Q_h^{\star,\eta}(x_h, y_h) = \eta \ln \mathbb{E}_{\pi^{ref}}[\exp(\eta^{-1} \sum_{t \geq h} r_t) | x_h, y_h]\)。对稀疏奖励(如数学题正确性判定),进一步简化为 \(\eta \ln \mathbb{E}_{\pi^{ref}}[\exp(\eta^{-1} r(x_H, y_H)) | x_h, y_h]\)
    • 设计动机:避免了 TD 学习的所有痛点——无 bootstrapping、无 changing targets、无 distributional Bellman 方程的非收缩性问题。变成了标准的固定目标监督学习
  2. 分布式监督学习

    • 功能:用 MLE 拟合 \(Z^\star\) 的条件分布
    • 核心思路:二值奖励用 binary cross-entropy 拟合 Bernoulli 分布;任意奖励用直方图离散化 + MLE
    • 设计动机:分布式 RL 在表征学习、方差减少和二阶界方面都有优势
  3. DAgger 式迭代数据收集

    • 功能:每轮用当前引导策略 roll-in、参考策略 roll-out 收集数据,解决分布偏移
    • 核心思路:随机切换时刻 \(h \sim [H]\),用 \(\pi^k\) roll-in \(h-1\) 步,用 \(\pi^{ref}\) roll-out,记录 \((x_t, y_t, R_t)\) 加入聚合数据集
    • 设计动机:CD/VAS 仅在 \(\pi^{ref}\) 数据上离线训练,推理时分布偏移导致估计不准确
  4. 多 η 推理

    • 功能:一个训练好的 \(Z^\theta\) 支持任意 \(\eta\) 值的推理
    • 核心思路:\(Z^\theta\)\(\eta\) 无关,仅在引导公式中通过 \(\exp(z/\eta)\) 引入
    • 设计动机:部署时可灵活调节 KL 约束强度而无需重训

损失函数 / 训练策略

  • 损失:\(L_{bce}(r, \hat{p}) = -r \ln \hat{p} - (1-r) \ln(1 - \hat{p})\)(二值奖励)
  • 值模型:Llama 3.2 1B 参数化,引导 8B/70B 的 \(\pi^{ref}\)
  • 默认 \(\eta = 0.1\),2 轮迭代即收敛
  • V-type 参数化(预测 \(Q^{\star,\eta}(x, \hat{y})\))优于 Q-type,因参数更少

实验关键数据

主实验 — Star-Graph(修复预训练捷径)

方法 G(5,5) 准确率 G(2,20) 准确率 能否修复?
π_ref 20% (=1/d) 50% (=1/d) -
REINFORCE 20% 50%
DPO ~0% (崩溃) ~0% ✗ (更差)
Q♯ ~100% ~100%

主实验 — 数学推理 (Llama 3.1 8B → GSM8K/MATH)

方法 GSM8K pass@1↑ GSM8K KL↓ MATH pass@1↑ MATH KL↓
π_ref 82.9 - 43.9 -
CD 84.5 7.43 45.3 26.8
Q♯ 85.1 3.67 46.7 8.69

消融实验

配置 GSM8K val MATH val 说明
Q♯ 完整 最优 最优 V-type, 分布式, 2轮, prefix
Q-type 参数化 -1~2% -1% 参数更多,效率低
MSE 回归 (非分布式) -2~3% -2% 忽略了分布信息
1 轮 (无 DAgger) -1% -1% 分布偏移影响
无 prefix 扩展 -3~4% -2% 数据量不足

关键发现

  • 1B 值模型引导 70B LLM:1B Q♯ 让 70B Llama3.1 的 MATH pass@1 提升 2.5%、maj1@8 提升 3.5%
  • Q♯ 作为奖励模型做 Best-of-8,比 pass@1 提升 10%+
  • Q♯ 在准确率-KL 平面上严格 Pareto 支配 CD——更高准确率+更低 KL
  • 8B π_ref + 1B Q♯ (共9B) 的 maj1@8 ≈ 70B π_ref 的 pass@1——9 倍参数效率

亮点与洞察

  • "学奖励分布而非 Q 值"核心创新——将 RL 化简为无 bootstrapping 的监督学习,理论上避免了深度 RL 的不稳定性。这是一种只在确定性 MDP 下成立的优雅简化
  • 值方法能修复预训练捷径而策略方法不能——策略梯度在低概率路径上梯度也低的恶性循环是根本原因。值方法直接评估路径价值,绕开了这个问题
  • 小模型引导大模型是极具实用价值的范式——"评估比生成容易",1B 评估器就能显著提升 70B 生成器
  • 分布式 RL 的方差相关界:当 \(\pi^{ref}\) 方差小(已经不错)时,Q♯ 收敛更快

局限与展望

  • 仅适用于确定性 MDP——覆盖 LLM 但不覆盖随机环境
  • 迭代数据收集增加训练时间(但值模型小于策略模型,实际开销可控)
  • 分布式参数化(Bernoulli/直方图)的表达能力有限——连续奖励可能需要更灵活模型
  • 仅在数学推理验证,对话对齐等非稀疏奖励场景待探索

相关工作与启发

  • vs PPO/DPO:策略方法修改 \(\pi^{ref}\) 权重,灵活但有捷径问题;Q♯ 保持 \(\pi^{ref}\) 不变,更稳定
  • vs CD/VAS:三维度比较——目标(\(Q^{\star,\eta}\) vs \(Q^{\pi^{ref},0}\))、训练(在线迭代 vs 离线一次)、损失(分布式 vs MSE)
  • vs DPO:DPO 有类似 softmax 形式但只在序列级(H=1)操作,实际中 partition function 不可计算

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 分布式监督学习替代 TD 的理论突破
  • 实验充分度: ⭐⭐⭐⭐⭐ star-graph + 多规模数学推理 + 大模型引导 + 详细消融
  • 写作质量: ⭐⭐⭐⭐⭐ 理论动机清晰,对比精确,实验周密
  • 价值: ⭐⭐⭐⭐⭐ 对 LLM 后训练有范式性意义

相关论文