跳转至

Inference-Time Chain-of-Thought Pruning with Latent Informativeness Signals

会议: NeurIPS 2025
arXiv: 2511.00699
代码: 未开源
领域: LLM推理
关键词: inference-time scaling, chain-of-thought pruning, KL divergence, Best-of-N, reasoning efficiency

一句话总结

提出 KAPPA (KL-Adjusted Pruned Path Algorithm),利用 KL 散度、置信度和熵三个无需额外训练的信号对 Best-of-N 采样的推理分支进行渐进式剪枝,在保持准确率的同时实现最高 60% 峰值内存和 90% token 生成量的削减。

研究背景与动机

  1. 领域现状: LLM 通过 Chain-of-Thought (CoT) 和 Best-of-N (BoN) 采样提升推理准确率——采样 N 条推理路径并选最优。这是推理时缩放 (inference-time scaling) 的核心范式。
  2. 现有痛点: 标准 BoN 需要完整生成所有 N 条路径,计算和内存开销随 N 线性增长。已有方法 ST-BoN 通过一致性启发式截断不好的分支,但其一致性准则不直接评估分支质量。DeepConf 使用置信度加权投票但仍需多条路径跑完。
  3. 核心矛盾: 推理时缩放的效果依赖于采样更多路径(N 越大越好),但完整生成所有路径的开销限制了 N 的实际可用值——如何在不牺牲准确率的前提下大幅减少冗余计算?
  4. 本文解决什么: 设计一个无需训练的、基于信息论信号的推理分支剪枝算法,在推理过程中渐进式地淘汰低质量分支。
  5. 切入角度: 将 KL 散度作为分支"信息量"的自监督信号——偏离无条件分布越多的分支越有信息量——结合置信度和熵做综合评分。
  6. 核心idea: 不需要外部奖励模型,利用模型自身的 logits 分布特征就能判断哪些推理路径值得继续。

方法详解

整体框架

KAPPA 分为三个阶段:Draft(探索)→ Scoring & Gating(评分与门控剪枝)→ Continuation(利用)。

关键设计

  1. Draft Phase(草稿阶段)
  2. 做什么: 并行生成 N 条推理分支的前缀,直到截断点 \(c\)(所有分支两两不一致的最早时刻)
  3. 核心思路: 充分探索,确保分支间有足够的多样性再开始评估
  4. 设计动机: 过早剪枝会错杀有潜力的分支,需要先让分支"展开"

  5. Scoring & Gating Phase(评分与门控阶段)

  6. 做什么: 在 \([c, c+\tau)\) 的窗口内,每步对存活分支计算综合得分并渐进剪枝
  7. 核心思路: 三信号融合评分
    • KL 散度: \(D_{KL}(p_t^i \| q)\),计算当前分支 logits 与无条件(BOS token)logits 的 KL 散度,衡量信息增益
    • 置信度: \(C_t^i = \max_v p_t^i(v)\),top-1 token 概率,反映模型对当前预测的确信程度
    • : \(H_t^i\),logits 分布的不确定性
  8. 分数计算: \(s_t^i = w_{KL} \cdot \hat{EMA}_t^i + w_C \cdot \hat{C}_t^i + w_H \cdot \hat{H}_t^i\),权重 \((0.7, 0.2, 0.1)\)
  9. 稳定化: 用 Median-of-Means (MoM, 4 buckets) + 指数移动平均 (EMA, \(\alpha=0.5\)) 平滑信号,z-score 归一化并裁剪到 \([-3, 3]\)
  10. 剪枝调度: 线性调度,每步淘汰得分最低的分支,\(\tau\) 步后恰好剩 1 条
  11. 设计动机: KL 散度作为无需训练的分支质量信号,避免外部奖励模型的额外开销

  12. Continuation Phase(续写阶段)

  13. 做什么: 对唯一存活的分支继续自回归解码直到 EOS
  14. 核心思路: 剩余计算全部集中在最优分支上

实验关键数据

主实验(准确率 vs 效率)

在 GSM8K 和 MATH500 上,N=5/10/20:

模型 数据集 方法 N=20 准确率 N=20 峰值内存 Token 削减
DeepSeek-R1-1.5B MATH500 BoN ~70% 16240 MB -
DeepSeek-R1-1.5B MATH500 KAPPA 72.2% (+1-2%) 6495 MB ~90%
Qwen2.5-7B GSM8K BoN baseline baseline -
Qwen2.5-7B GSM8K KAPPA ≈baseline ~40%↓ ~65%

核心效率指标

  • 峰值内存削减: 4%~60%(取决于模型和 N)
  • Token 生成量削减: 65%~90%(相对于 BoN)
  • 最大差异: DeepSeek-R1-1.5B, MATH500, N=20 → KAPPA 仅用 2113 tokens vs BoN 的 20053 tokens (89.5% 削减)

消融/关键发现

  • 小模型受益更大: DeepSeek-R1-1.5B 上 KAPPA 持续提升 1-2% 准确率,因为小模型的低质分支更多,剪枝效果更明显
  • 大模型过度剪枝风险: Qwen2.5-7B 上准确率提升不一致,因为大模型的分支整体质量较高,线性剪枝可能错杀有潜力的分支
  • 超参数: KL 权重 0.7 最重要,EMA rate 0.5,MoM window 16,bucket 4

亮点与洞察

  • 无训练的分支质量信号: 直接利用模型 logits 计算 KL/置信度/熵,无需外部奖励模型或额外训练,是一个轻量且通用的方案
  • 渐进式剪枝优于一次性截断: 比 ST-BoN 的一次性截断更精细,逐步淘汰低分分支让评分信号有更多时间积累
  • 信息论视角的路径评估: 用 KL 散度衡量"分支偏离无条件分布的程度"作为信息量代理,直觉上有说服力——条件化的推理应该偏离先验越多越有信息量
  • "思考越多不一定越好"范式下的实用方案: 直接解决 BoN 的 N 缩放瓶颈

局限性 / 可改进方向

  • 仅测试两个模型两个数据集: 泛化性有待验证(缺少代码推理、常识推理等任务)
  • max_new_tokens 限制为 1024: DeepSeek-R1 在 MATH500 上经常需要超过 1024 tokens,截断影响了结果
  • 线性剪枝调度可能不是最优: 大模型上过于激进,cosine schedule 或自适应调度可能更好
  • N 增大时准确率可能下降: 分支增多导致过度剪枝,评分信号的噪声被放大
  • KL 散度的参考分布选择: 使用 BOS token 的无条件 logits 作为参考是否最优?可探索其他参考分布

相关工作与启发

  • ST-BoN (Wang et al., 2025): 本文的直接前驱,基于一致性的一次截断 → KAPPA 改为基于信息论的渐进剪枝
  • DeepConf (Fu et al., 2025): 置信度加权投票仍需多路径完整生成 → KAPPA 早期就剪枝
  • INFORM (Zhou et al., 2024): 自适应决定采样路径数但不干预路径内部 → KAPPA 在路径内部做细粒度管控
  • ThinkPrune (Hou et al., 2025): 用 RL 训练模型产生更短推理链 → KAPPA 是 training-free 的替代方案
  • 启发:KL 散度信号也许可以用于 speculative decoding 中的 draft 质量评估

评分

  • 新颖性: ⭐⭐⭐ 组合 KL/置信度/熵做分支剪枝的想法有新意,但各组件单独来看都不新
  • 实验充分度: ⭐⭐⭐ 仅两个模型两个数据集,且有 token 长度限制,消融不够充分
  • 写作质量: ⭐⭐⭐⭐ 算法描述清晰,三阶段结构易于理解,图表展示直观
  • 价值: ⭐⭐⭐⭐ 推理时效率优化是高需求方向,training-free 的方案实用性强