跳转至

AutoJudge: Judge Decoding Without Manual Annotation

会议: NeurIPS 2025
arXiv: 2504.20039
代码: https://github.com/garipovroma/autojudge
领域: 高效推理 / 投机解码
关键词: 投机解码, Judge Decoding, 重要token, 自动标注, 推理加速

一句话总结

AutoJudge 自动化了 Judge Decoding 中"重要 token"的标注——通过半贪心搜索替换不匹配 token 并检查答案是否改变来标注重要性,训练逻辑回归分类器预测 token 重要性,使投机解码每轮接受 40+ token(vs 标准 ~20),在 GSM8K 上加速 1.5× 且准确率损失 <1%。

研究背景与动机

  1. 领域现状:投机解码用小 draft 模型生成 token 序列,大 target 模型验证。标准方法严格验证每个 token——每轮只接受约 20 个 token。Judge Decoding 放松验证——不重要 token 允许偏差——但需要手动标注哪些 token "重要"。
  2. 现有痛点:(a) 手动标注重要 token 需要领域知识且不可扩展;(b) 全部 token 严格验证太保守——很多 token(如格式、换行)的偏差不影响最终答案;(c) 投机解码的加速比被低接受率限制。
  3. 核心矛盾:放松验证可以大幅提高接受率,但需要知道哪些 token 可以放松——这个判断之前需要人工标注。
  4. 本文要解决什么? 自动标注 token 重要性 + 训练分类器实时预测。
  5. 切入角度:如果替换某个 token 不改变最终答案,那它就是"不重要的"。半贪心搜索算法自动发现这些 token。
  6. 核心 idea 一句话:半贪心搜索替换 token → 检查答案是否变 → 自动标注重要/不重要 → 逻辑回归预测 → 不重要 token 跳过验证 → 40+ token/轮。

方法详解

整体框架

  1. 离线标注: 对训练数据,用 draft 生成 → 找不匹配 token → 逐个替换检查答案是否变 → 标注重要/不重要
  2. 训练分类器: 用 draft/target 隐藏状态拼接作为特征,逻辑回归二分类(阈值调到 ≥90% 召回)
  3. 在线推理: draft 生成 → 分类器预测重要性 → 仅对重要 token 做 target 验证 → 不重要 token 直接接受 → 大幅增加每轮 token 数

关键设计

  1. 半贪心搜索标注(Algorithm 1):
  2. 做什么:自动发现哪些 token 替换后不影响答案
  3. 核心思路:提取答案 \(\alpha\);生成 draft 响应 \(\tilde{y}\);对每个不匹配位置 \(t\):替换 target token 为 draft token → 重新运行模型 → 新答案等价?是→不重要;否→重要,保留 target 版本
  4. 设计动机:穷举搜索代价太高(\(2^n\) 组合),半贪心是实用的近似

  5. 隐藏状态分类器:

  6. 做什么:推理时实时预测 token 重要性
  7. 核心思路:输入 = draft 和 target 模型在不匹配 token 处的隐藏状态拼接;逻辑回归输出重要性概率;阈值调到 ≥90% 召回(宁可多验证也不漏掉重要 token)
  8. 设计动机:逻辑回归极轻量(毫秒级推理),隐藏状态包含 token 级别的语义信息

  9. 与 EAGLE-2 叠加:

  10. 做什么:在更快的 draft 模型上叠加 AutoJudge
  11. 核心思路:EAGLE-2 已经是更好的 draft 模型,AutoJudge 叠加后额外加速 1.01-1.20×
  12. 设计动机:两种技术正交——EAGLE 改善 draft 质量,AutoJudge 放松验证

损失函数 / 训练策略

  • 逻辑回归:标准交叉熵
  • 阈值选择:在验证集上调到 ≥90% 召回率
  • 精度:bfloat16 有 ~10% 嵌入方差,float32 更稳定

实验关键数据

主实验

任务 每轮接受 token 准确率 加速
GSM8K 0-shot (8B/70B) 40+ 92% (≤1% 损失) 1.5×
GSM8K 8-shot (8B/70B) 40+ 95.4% (<1% 损失)
LiveCodeBench 22+ ~2% 损失
vLLM 集成 1.5-2× (A100/H100)

消融实验

配置 效果
Draft+Target 隐藏状态 vs 仅 Draft 拼接更好
Token 嵌入 vs 前一 token Token 嵌入更好
规则方法(仅数学 token) 65-80%(明显弱于学习)
+ EAGLE-2 额外 1.01-1.20× 加速
GSM8K 分类器→LiveCodeBench 失败(需任务特定训练)

关键发现

  • 每轮接受 token 从 ~20 翻倍到 40+——直接转化为 1.5-2× 加速
  • 准确率损失极小(<1-2%)——"不重要"的 token 确实可以跳过
  • 任务特定——GSM8K 分类器不能迁移到 LiveCodeBench(需要重新训练)
  • 浮点精度影响大——bfloat16 的嵌入近似引入噪声

亮点与洞察

  • 自动化"什么是重要的"判断:从手动到自动的关键转变,使 Judge Decoding 实用化
  • 半贪心搜索是实用的近似:虽非最优但效果好且可计算
  • 隐藏状态包含足够的信号:简单的逻辑回归就能预测 token 重要性——说明 LLM 内部"知道"哪些 token 重要

局限性 / 可改进方向

  • 需要任务特定训练数据——不同任务的"重要 token"不同
  • 开放式任务(如创意写作)难以定义"答案等价"
  • 半贪心搜索是次优的——穷举树搜索可能更准但太慢
  • 浮点精度敏感性限制了部署灵活性

相关工作与启发

  • vs 标准投机解码: 标准方法每轮 ~20 token,AutoJudge 40+——接受率翻倍
  • vs Judge Decoding (手动): 手动标注不可扩展,AutoJudge 自动化
  • vs EAGLE/Medusa: 改善 draft 质量的方法,与 AutoJudge(放松验证)正交

评分

  • 新颖性: ⭐⭐⭐⭐ 自动化 token 重要性标注是实用创新
  • 实验充分度: ⭐⭐⭐⭐ GSM8K + LiveCodeBench + EAGLE 叠加 + vLLM 集成
  • 写作质量: ⭐⭐⭐⭐ 算法描述清晰
  • 价值: ⭐⭐⭐⭐ 使 Judge Decoding 从理论走向实践

方法补充说明

  • 半贪心搜索的复杂度:对长度 \(L\) 的序列有 \(M\) 个不匹配 token,搜索复杂度 \(O(M imes T_{gen})\)\(T_{gen}\) 是单次生成时间)。实际中 \(M pprox L/3\)(约 1/3 token 不匹配),单次搜索 ~1-2 分钟/样本
  • 分类器特征选择:拼接 draft 和 target 的隐藏状态(而非只用一个)效果更好——两个模型的不一致本身就是 token 重要性的强信号
  • 与 speculative decoding 理论的关系:标准 speculative decoding 保证与 target 分布完全一致(无损),AutoJudge 允许 <2% 的准确率损失换取 2× 加速——这是精度-速度的显式权衡