跳转至

Hybrid Latent Reasoning via Reinforcement Learning

会议: NeurIPS 2025
arXiv: 2505.18454
代码: 无
领域: LLM推理
关键词: 潜在推理, 混合推理, 强化学习, 门控机制, 连续表示

一句话总结

HRPO 提出混合潜在推理策略优化:通过可学习的门控机制将前一步的隐藏状态表示逐步融入到采样的 token embedding 中,使 LLM 在推理阶段同时利用离散 token 和连续潜在表示,无需 CoT 标注即可通过 RL 训练,在知识密集型和 STEM 推理任务上均超越 PPO/GRPO 等基线。

研究背景与动机

  1. 领域现状:潜在推理(Latent Reasoning)作为 CoT 的替代方案正受到关注。Coconut 等方法将模型最后一层的隐藏状态作为"连续思维"反馈为下一步输入,在推理任务上取得了不错效果。但这类方法普遍依赖 CoT 轨迹进行训练。
  2. 现有痛点:(1) 现有潜在推理方法(如 Coconut、CODI)需要大量 CoT 标注数据进行多阶段训练,成本高且无法利用 LLM 本身的内在推理能力;(2) 直接将隐藏状态作为下一步的输入会破坏生成质量(重复、不连贯),因为隐藏状态和 token embedding 位于不同的表示空间。
  3. 核心矛盾:潜在推理需要连续表示以获得更丰富的信息,但 LLM 的自回归生成本质上是离散的。直接桥接两者会导致分布不匹配,破坏模型的生成能力。
  4. 本文要解决什么? 如何在不需要 CoT 标注的情况下,让现有 LLM 同时利用离散 token 和连续隐藏状态进行推理?
  5. 切入角度:设计一个门控机制,在 token embedding 中逐步混入隐藏状态信息,初始时几乎完全使用 token embedding(保持生成质量),训练过程中通过 RL 自动学习何时、多少地融入潜在表示。
  6. 核心idea一句话:用门控机制渐进融合离散 token 和连续隐藏状态,并通过 RL(而非 CoT 蒸馏)让 LLM 自主学习混合推理策略。

方法详解

整体框架

HRPO 在 LLM 的推理阶段(<think></think> 之间)引入混合输入:每个时间步的输入不再只是采样 token 的 embedding,而是 token embedding 和前一步隐藏状态的门控混合。答案部分仍使用标准自回归解码。训练采用 REINFORCE 风格的在线 RL,使用简单的结果奖励(正确=1,错误=0)。

关键设计

  1. 隐藏状态投影(Hidden State Projection):
  2. 做什么:将模型输出的隐藏状态 \(\hat{h}_t\) 映射回 embedding 空间
  3. 核心思路:使用 softmax 输出概率 \(p_{t+1} = \text{softmax}(\text{Head}(\hat{h}_t) / \tau)\) 对所有 token embedding 做加权求和:\(h_{t+1} = W_e^T \frac{p_{t+1}}{\|p_{t+1}\|}\)。温度 \(\tau\) 控制分布的锐利程度
  4. 设计动机:直接使用隐藏状态会导致分布不匹配和生成退化;通过概率加权插值,投影后的表示与模型原生 embedding 空间对齐,保持可微性

  5. 门控机制(Gating Mechanism):

  6. 做什么:控制离散 token embedding \(\hat{e}_{t+1}\) 和连续隐藏表示 \(h_{t+1}\) 的混合比例
  7. 核心思路:设计 reset gate \(r_t = \sigma(W_a \hat{e}_{t+1} + b_a)\)、input gate \(i_t = \sigma(W_x \hat{e}_{t+1} + b_x)\),以及衰减系数 \(a_t = \exp(-c \cdot \text{softplus}(\Lambda) \odot r_t)\),最终输入为 \(e_{t+1} = a_t \odot \hat{e}_{t+1} + \sqrt{1 - a_t^2} \odot (i_t \odot h_{t+1})\)\(\Lambda\) 是可学习参数
  8. 设计动机:初始化 \(a_t \to 1\) 使得训练开始时几乎完全使用 token embedding(保持 LLM 生成能力),随训练进行,门控逐步学习融入更多隐藏状态信息。这种渐进式设计避免了直接使用隐藏状态导致的生成崩溃

  9. 混合推理策略优化(HRPO):

  10. 做什么:基于 REINFORCE 风格的 RL 进行在线策略优化
  11. 核心思路:对每个问题生成 \(g\) 个混合 rollout(离散 token + 隐藏表示),用结果奖励(正确/错误)计算组内标准化优势 \(\hat{A}_i = \frac{r_i - \text{mean}([r_1,...,r_g])}{\text{std}([r_1,...,r_g])}\),策略梯度为 \(\nabla_\theta \mathcal{J} = \mathbb{E}[\frac{1}{g}\sum_i \frac{1}{|y_i|}\sum_t \nabla_\theta \log \pi_\theta(y_{i,t}|...) \hat{A}_{i,t}] - \beta \nabla_\theta D_{KL}[\pi_\theta \| \pi_{ref}]\)
  12. 设计动机:采样操作保留了随机性使得 RL rollout 可行;严格在线策略(每条轨迹只用一次)因为隐藏表示直接依赖当前参数 \(\theta\);轻量设计无需额外价值网络

损失函数 / 训练策略

使用 REINFORCE + KL 正则化,\(\beta = 0.005\)。不使用 PPO 的裁剪比率,直接用原始 log 概率,因为保守的学习率设置使得比率裁剪很少被触发。单 GPU 可运行。使用 LoRA (rank=32, \(\alpha\)=64) 进行高效微调。

实验关键数据

主实验

基础模型:Qwen2.5-1.5B-Instruct 和 Qwen2.5-3B-Instruct。

STEM 基准测试(Accuracy)

方法 GSM8k MATH MATH500 MMLU-ST ARC-C 平均
SFT (1.5B) 0.560 0.300 0.302 0.403 0.602 0.433
PPO (1.5B) 0.694 0.507 0.518 0.566 0.715 0.600
GRPO (1.5B) 0.711 0.502 0.524 0.562 0.737 0.607
HRPO (1.5B) 0.720 0.518 0.536 0.569 0.742 0.617
PPO (3B) 0.819 0.597 0.604 0.582 0.811 0.682
GRPO (3B) 0.834 0.602 0.604 0.601 0.814 0.691
HRPO (3B) 0.845 0.613 0.630 0.590 0.820 0.700

消融实验

配置 MATH (1.5B) 说明
Hidden States 直接使用 ~0 (崩溃) 隐藏状态与 embedding 空间不匹配
Interpolation (无门控) 先正常后崩溃 过多噪声导致训练不稳定
HRPO (门控) 0.518 渐进融合,稳定训练
Coconut 0.315 (GSM8k) 依赖 CoT 压缩,效果不佳
CODI 0.658 (GSM8k) CoT 自蒸馏,仍弱于 HRPO
HRPO 0.720 (GSM8k) 无需 CoT,RL 驱动

关键发现

  • HRPO 在 3B 模型上平均准确率 0.700,匹配甚至超越 7B 模型的表现(如 Qwen2.5-7B 平均 0.635)
  • 直接使用隐藏状态会导致奖励完全崩溃(约为 0),而纯插值方法虽然开始正常但最终也会崩溃——只有门控混合才能稳定训练
  • HRPO 训练的模型出现有趣的跨语言推理模式(如英中混合推理),暗示潜在表示能跨越语言边界
  • 隐藏比例(hidden ratio)在训练过程中稳步增长:模型主动学习利用更多潜在信息
  • 较小的 \(r_{min}\)(更大的初始隐藏比例)对知识任务更有利,而 STEM 任务在两极端值表现最佳
  • HRPO 训练后的模型生成更短的完成序列,因为隐藏状态有效编码了上下文信息

亮点与洞察

  • 渐进式门控设计非常精巧:从几乎纯 token embedding 开始,逐步融入隐藏状态——这种"先保持能力,再逐步增强"的策略避免了训练早期的灾难性退化,是一个可泛化的设计原则
  • 用 RL 替代 CoT 蒸馏训练潜在推理是重要范式转变:证明LLM 无需 CoT 标注就能自主发展潜在推理能力,大幅降低了训练成本
  • 跨语言推理的涌现行为:HRPO 训练后模型在推理时自然切换语言,表明潜在表示捕获了超越特定语言的推理模式,这一现象具有理论研究价值

局限性 / 可改进方向

  • 额外计算开销:混合推理引入门控计算和 embedding 投影,增加了前向传播成本
  • 严格在线策略限制了大规模训练效率:每条轨迹只能用一次,不能重用
  • 目前只在 1.5B 和 3B 模型上验证,更大规模模型的效果未知
  • 生成序列虽然更短但可能出现格式违规或重复循环
  • 可探索方向:off-policy 扩展、更大模型验证、格式奖励、与 CoT 的混合训练

相关工作与启发

  • vs Coconut: Coconut 使用多阶段 CoT 训练将 token 压缩为连续思维,HRPO 完全不需要 CoT 且通过 RL 训练。在 GSM8k 上 HRPO (0.720) vs Coconut (0.315) 差距巨大
  • vs CODI: CODI 用自蒸馏对齐显式和隐式推理 token,仍需 CoT 数据。HRPO 在两个数据集上一致优于 CODI
  • vs GRPO/PPO: HRPO 在所有基准上一致超越纯 RL 基线,表明混合连续表示确实提供了额外信息增益

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 首个用 RL 训练混合潜在推理的方法,门控设计优雅且动机充分
  • 实验充分度: ⭐⭐⭐⭐ 10 个基准 + 多模型 + 详细消融 + 显著性检验,但仅限小模型
  • 写作质量: ⭐⭐⭐⭐ 整体清晰,但部分符号定义可以更紧凑
  • 价值: ⭐⭐⭐⭐ 开辟了 RL-based 潜在推理新方向,但实际应用价值需在更大模型上验证