Hybrid Latent Reasoning via Reinforcement Learning¶
会议: NeurIPS 2025
arXiv: 2505.18454
代码: 无
领域: LLM推理
关键词: 潜在推理, 混合推理, 强化学习, 门控机制, 连续表示
一句话总结¶
HRPO 提出混合潜在推理策略优化:通过可学习的门控机制将前一步的隐藏状态表示逐步融入到采样的 token embedding 中,使 LLM 在推理阶段同时利用离散 token 和连续潜在表示,无需 CoT 标注即可通过 RL 训练,在知识密集型和 STEM 推理任务上均超越 PPO/GRPO 等基线。
研究背景与动机¶
- 领域现状:潜在推理(Latent Reasoning)作为 CoT 的替代方案正受到关注。Coconut 等方法将模型最后一层的隐藏状态作为"连续思维"反馈为下一步输入,在推理任务上取得了不错效果。但这类方法普遍依赖 CoT 轨迹进行训练。
- 现有痛点:(1) 现有潜在推理方法(如 Coconut、CODI)需要大量 CoT 标注数据进行多阶段训练,成本高且无法利用 LLM 本身的内在推理能力;(2) 直接将隐藏状态作为下一步的输入会破坏生成质量(重复、不连贯),因为隐藏状态和 token embedding 位于不同的表示空间。
- 核心矛盾:潜在推理需要连续表示以获得更丰富的信息,但 LLM 的自回归生成本质上是离散的。直接桥接两者会导致分布不匹配,破坏模型的生成能力。
- 本文要解决什么? 如何在不需要 CoT 标注的情况下,让现有 LLM 同时利用离散 token 和连续隐藏状态进行推理?
- 切入角度:设计一个门控机制,在 token embedding 中逐步混入隐藏状态信息,初始时几乎完全使用 token embedding(保持生成质量),训练过程中通过 RL 自动学习何时、多少地融入潜在表示。
- 核心idea一句话:用门控机制渐进融合离散 token 和连续隐藏状态,并通过 RL(而非 CoT 蒸馏)让 LLM 自主学习混合推理策略。
方法详解¶
整体框架¶
HRPO 在 LLM 的推理阶段(<think> 到 </think> 之间)引入混合输入:每个时间步的输入不再只是采样 token 的 embedding,而是 token embedding 和前一步隐藏状态的门控混合。答案部分仍使用标准自回归解码。训练采用 REINFORCE 风格的在线 RL,使用简单的结果奖励(正确=1,错误=0)。
关键设计¶
- 隐藏状态投影(Hidden State Projection):
- 做什么:将模型输出的隐藏状态 \(\hat{h}_t\) 映射回 embedding 空间
- 核心思路:使用 softmax 输出概率 \(p_{t+1} = \text{softmax}(\text{Head}(\hat{h}_t) / \tau)\) 对所有 token embedding 做加权求和:\(h_{t+1} = W_e^T \frac{p_{t+1}}{\|p_{t+1}\|}\)。温度 \(\tau\) 控制分布的锐利程度
-
设计动机:直接使用隐藏状态会导致分布不匹配和生成退化;通过概率加权插值,投影后的表示与模型原生 embedding 空间对齐,保持可微性
-
门控机制(Gating Mechanism):
- 做什么:控制离散 token embedding \(\hat{e}_{t+1}\) 和连续隐藏表示 \(h_{t+1}\) 的混合比例
- 核心思路:设计 reset gate \(r_t = \sigma(W_a \hat{e}_{t+1} + b_a)\)、input gate \(i_t = \sigma(W_x \hat{e}_{t+1} + b_x)\),以及衰减系数 \(a_t = \exp(-c \cdot \text{softplus}(\Lambda) \odot r_t)\),最终输入为 \(e_{t+1} = a_t \odot \hat{e}_{t+1} + \sqrt{1 - a_t^2} \odot (i_t \odot h_{t+1})\)。\(\Lambda\) 是可学习参数
-
设计动机:初始化 \(a_t \to 1\) 使得训练开始时几乎完全使用 token embedding(保持 LLM 生成能力),随训练进行,门控逐步学习融入更多隐藏状态信息。这种渐进式设计避免了直接使用隐藏状态导致的生成崩溃
-
混合推理策略优化(HRPO):
- 做什么:基于 REINFORCE 风格的 RL 进行在线策略优化
- 核心思路:对每个问题生成 \(g\) 个混合 rollout(离散 token + 隐藏表示),用结果奖励(正确/错误)计算组内标准化优势 \(\hat{A}_i = \frac{r_i - \text{mean}([r_1,...,r_g])}{\text{std}([r_1,...,r_g])}\),策略梯度为 \(\nabla_\theta \mathcal{J} = \mathbb{E}[\frac{1}{g}\sum_i \frac{1}{|y_i|}\sum_t \nabla_\theta \log \pi_\theta(y_{i,t}|...) \hat{A}_{i,t}] - \beta \nabla_\theta D_{KL}[\pi_\theta \| \pi_{ref}]\)
- 设计动机:采样操作保留了随机性使得 RL rollout 可行;严格在线策略(每条轨迹只用一次)因为隐藏表示直接依赖当前参数 \(\theta\);轻量设计无需额外价值网络
损失函数 / 训练策略¶
使用 REINFORCE + KL 正则化,\(\beta = 0.005\)。不使用 PPO 的裁剪比率,直接用原始 log 概率,因为保守的学习率设置使得比率裁剪很少被触发。单 GPU 可运行。使用 LoRA (rank=32, \(\alpha\)=64) 进行高效微调。
实验关键数据¶
主实验¶
基础模型:Qwen2.5-1.5B-Instruct 和 Qwen2.5-3B-Instruct。
STEM 基准测试(Accuracy)
| 方法 | GSM8k | MATH | MATH500 | MMLU-ST | ARC-C | 平均 |
|---|---|---|---|---|---|---|
| SFT (1.5B) | 0.560 | 0.300 | 0.302 | 0.403 | 0.602 | 0.433 |
| PPO (1.5B) | 0.694 | 0.507 | 0.518 | 0.566 | 0.715 | 0.600 |
| GRPO (1.5B) | 0.711 | 0.502 | 0.524 | 0.562 | 0.737 | 0.607 |
| HRPO (1.5B) | 0.720 | 0.518 | 0.536 | 0.569 | 0.742 | 0.617 |
| PPO (3B) | 0.819 | 0.597 | 0.604 | 0.582 | 0.811 | 0.682 |
| GRPO (3B) | 0.834 | 0.602 | 0.604 | 0.601 | 0.814 | 0.691 |
| HRPO (3B) | 0.845 | 0.613 | 0.630 | 0.590 | 0.820 | 0.700 |
消融实验¶
| 配置 | MATH (1.5B) | 说明 |
|---|---|---|
| Hidden States 直接使用 | ~0 (崩溃) | 隐藏状态与 embedding 空间不匹配 |
| Interpolation (无门控) | 先正常后崩溃 | 过多噪声导致训练不稳定 |
| HRPO (门控) | 0.518 | 渐进融合,稳定训练 |
| Coconut | 0.315 (GSM8k) | 依赖 CoT 压缩,效果不佳 |
| CODI | 0.658 (GSM8k) | CoT 自蒸馏,仍弱于 HRPO |
| HRPO | 0.720 (GSM8k) | 无需 CoT,RL 驱动 |
关键发现¶
- HRPO 在 3B 模型上平均准确率 0.700,匹配甚至超越 7B 模型的表现(如 Qwen2.5-7B 平均 0.635)
- 直接使用隐藏状态会导致奖励完全崩溃(约为 0),而纯插值方法虽然开始正常但最终也会崩溃——只有门控混合才能稳定训练
- HRPO 训练的模型出现有趣的跨语言推理模式(如英中混合推理),暗示潜在表示能跨越语言边界
- 隐藏比例(hidden ratio)在训练过程中稳步增长:模型主动学习利用更多潜在信息
- 较小的 \(r_{min}\)(更大的初始隐藏比例)对知识任务更有利,而 STEM 任务在两极端值表现最佳
- HRPO 训练后的模型生成更短的完成序列,因为隐藏状态有效编码了上下文信息
亮点与洞察¶
- 渐进式门控设计非常精巧:从几乎纯 token embedding 开始,逐步融入隐藏状态——这种"先保持能力,再逐步增强"的策略避免了训练早期的灾难性退化,是一个可泛化的设计原则
- 用 RL 替代 CoT 蒸馏训练潜在推理是重要范式转变:证明LLM 无需 CoT 标注就能自主发展潜在推理能力,大幅降低了训练成本
- 跨语言推理的涌现行为:HRPO 训练后模型在推理时自然切换语言,表明潜在表示捕获了超越特定语言的推理模式,这一现象具有理论研究价值
局限性 / 可改进方向¶
- 额外计算开销:混合推理引入门控计算和 embedding 投影,增加了前向传播成本
- 严格在线策略限制了大规模训练效率:每条轨迹只能用一次,不能重用
- 目前只在 1.5B 和 3B 模型上验证,更大规模模型的效果未知
- 生成序列虽然更短但可能出现格式违规或重复循环
- 可探索方向:off-policy 扩展、更大模型验证、格式奖励、与 CoT 的混合训练
相关工作与启发¶
- vs Coconut: Coconut 使用多阶段 CoT 训练将 token 压缩为连续思维,HRPO 完全不需要 CoT 且通过 RL 训练。在 GSM8k 上 HRPO (0.720) vs Coconut (0.315) 差距巨大
- vs CODI: CODI 用自蒸馏对齐显式和隐式推理 token,仍需 CoT 数据。HRPO 在两个数据集上一致优于 CODI
- vs GRPO/PPO: HRPO 在所有基准上一致超越纯 RL 基线,表明混合连续表示确实提供了额外信息增益
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 首个用 RL 训练混合潜在推理的方法,门控设计优雅且动机充分
- 实验充分度: ⭐⭐⭐⭐ 10 个基准 + 多模型 + 详细消融 + 显著性检验,但仅限小模型
- 写作质量: ⭐⭐⭐⭐ 整体清晰,但部分符号定义可以更紧凑
- 价值: ⭐⭐⭐⭐ 开辟了 RL-based 潜在推理新方向,但实际应用价值需在更大模型上验证