跳转至

AdaSTaR: Adaptive Data Sampling for Training Self-Taught Reasoners

会议: NeurIPS 2025
arXiv: 2505.16322
代码: GitHub
领域: LLM推理 / LLM效率
关键词: 自我改进推理, STaR, 自适应采样, 课程学习, 数据效率

一句话总结

发现 STaR(自我教学推理器)的随机数据采样导致观测训练频率严重不平衡(简单题过度训练、难题训练不足),提出 AdaSTaR——通过自适应多样性采样(优先欠训练样本)和自适应课程采样(根据模型强度调节难度),在 6 个基准上全部取得最高准确率同时减少 58.6% 训练 FLOPs。

研究背景与动机

  1. 领域现状:STaR (Self-Taught Reasoner) / RFT (Rejection sampling Fine-Tuning) 是 LLM 自我改进推理能力的核心训练范式——模型生成 CoT,验证正确答案后微调。被 DeepSeek-R1, Kimi k1.5 等前沿模型采用。
  2. 现有痛点:STaR 使用随机观测采样,导致:(a) 简单题被反复训练(10-13次),难题训练极少(1-2次)→ 计算浪费;(b) 72%的欠训练观测和91%的过训练观测在3轮迭代后保持不变 → 问题持久且不自愈。
  3. 核心矛盾:直接优先采样难题会增加 false positive(正确答案但错误 CoT)→ 需要平衡训练多样性与 CoT 质量。
  4. 切入角度:两个自适应原则——多样性(优先欠训练样本)+ 课程(模型弱时多采简单样本)。
  5. 核心 idea 一句话:用分层最小堆按"上次采样时间+难度"排序观测,同时用训练准确率 \(\alpha\) 作课程调节器自动平衡难度。

方法详解

整体框架

在 STaR 循环的数据采样步骤中插入自适应采样模块:维护每个观测的 \((\\tilde{t}_i, w_i)\) 统计 → 按分层最小堆排序 → 优先采样欠训练+困难观测 → 课程调节限制困难样本比例 → 正常训练。

关键设计

  1. 自适应多样性采样 (AdaD):
  2. 做什么:确保所有观测获得平衡的训练机会
  3. 核心数据结构:分层最小堆 HieMinHeap,排序键为 \((\tilde{t}_i, w_i)\)
    • 第一优先级:\(\tilde{t}_i\)(上次被采样的迭代),越早被采样越优先 → 促进多样性
    • 第二优先级:\(w_i\)(胜率统计),同迭代内越难(胜率低)越优先 → 聚焦难题
  4. 胜率统计:\(w_i = \frac{1}{K}\sum_{k=1}^K \mathbb{I}[y_i = \hat{y}_i]\)——在上次采样时的 K 次 CoT 采样中正确次数的比例
  5. 关键优势:\(w_i\) 的计算零额外开销,因为 K 次 CoT 采样本就是 STaR 的固有部分
  6. 非穷尽采样(Remark 1):while 循环在收集够 \(\beta^t\) 个正确样本后即停止,避免浪费

  7. 自适应课程采样 (AdaC):

  8. 做什么:模型弱时抑制过多困难样本,防止 false positive 上升
  9. 核心思路:用当前迭代的训练准确率 \(\alpha \in [0,1]\) 作为模型强度代理
  10. 实现:每迭代采样 \(m\) 个观测,但只更新前 \(\lfloor m \alpha^2 \rfloor\) 个的统计
  11. 效果:\(\alpha\) 低时,多数观测统计不更新 → 保留旧优先级 → 下次仍会被重新采样 → 实质上增加了简单样本的混入
  12. \(f(\alpha) = \alpha^2\):允许模型弱时多重复简单题,随模型变强迅速放开
  13. 零计算开销:\(\alpha\) 是训练步的副产品

训练策略

  • 基座模型:Llama 3.2 3B, Qwen 2.5 3B, Gemma 7B
  • 采用累积式 STaR (STaR-Acc):从上一轮模型继续训练
  • K=2(标准 CoT 采样数),公平对比
  • 评估:zero-shot greedy decoding

实验关键数据

主实验 (Llama 3.2 3B)

方法 ARC-C CQA CLadder ANLI GSM8K SVAMP Avg Acc. Avg FLOPs
STaR 基线 基线 基线 基线 - 基线 基线 基线
STaR-Acc 较好 较好 较好 较好 - 较好 较好 较多
B-STaR* - 很多
AdaSTaR 最佳 最佳 最佳 最佳 最佳 最佳 6/6 最佳 -58.6%

消融实验

配置 效果 说明
STaR-Acc(基线) 基线 随机采样
AdaD(仅多样性) +准确率 but ↑false positive 优先难题但 CoT 质量下降
AdaSTaR (AdaD + AdaC) 最佳 课程调节避免了 false positive
\(f(\alpha) = \alpha\) 接近最佳但略差 \(\alpha^2\) 更保守更好

关键发现

  • 6/6 基准全部最佳:AdaSTaR 在所有测试数据集上达到最高准确率
  • 58.6% FLOPs 节省:相比最强准确率基线,计算量减少近 60%
  • 训练多样性的量化影响:仅 AdaD 可增加 9% 的 false positive,加上 AdaC 后恢复
  • 泛化性强:在 Llama、Qwen、Gemma 三个模型族上一致有效
  • 效率+效果双赢:不是以性能换效率,而是同时提升两者

亮点与洞察

  • 零开销的难度估计:利用 STaR 固有的 K 次 CoT 采样计算胜率作为难度估计,不增加任何前向传播——巧妙地复用了系统中已有的计算
  • 训练准确率作课程信号\(\alpha\) 是训练过程的免费副产品,用它调节采样难度是零成本的自适应课程学习
  • 分层最小堆的数据结构选择:将多样性(\(\tilde{t}_i\))和难度(\(w_i\))编码在分层堆中,O(log N) 的采样效率
  • 对 STaR 训练动态的深入分析:发现了训练频率不平衡的持久性问题(72%/91% 保持不变),这个 observation 对理解 STaR 系统很有价值

局限性 / 可改进方向

  • 仅 outcome verification:只检查最终答案正确性,未使用 PRM(process reward model)
  • \(\alpha^2\) 的选择:课程函数 \(f(\alpha) = \alpha^2\) 是人工选择,可能存在更优形式
  • 未与 RL-based 方法比较:AdaSTaR 关注 SFT/STaR 管线,未与 GRPO 等 RL 方法直接对比
  • 改进方向:(1) 结合 PRM 做更精确的假阳性过滤;(2) 学习 \(f(\alpha)\) 而非手动选择;(3) 将自适应采样思路迁移到 RL-based 推理训练

相关工作与启发

  • vs STaR:AdaSTaR 是 STaR 的采样增强版,核心创新在于自适应数据采样策略
  • vs ReSTEM:ReSTEM 也注意到过度/不足训练问题,用截断阈值解决;AdaSTaR 的分层堆+课程方案更优
  • vs B-STaR:B-STaR 用 PRM 做更精细的验证但计算量巨大;AdaSTaR 不需要额外的 reward model
  • vs 课程学习:传统课程学习需要预定义难度度量,AdaSTaR 用胜率自然估计难度

评分

  • 新颖性: ⭐⭐⭐⭐ 将自适应采样和课程学习融入 STaR,零开销难度估计设计巧妙
  • 实验充分度: ⭐⭐⭐⭐⭐ 6 基准 × 3 模型族,大量 baseline,FLOPs 和准确率双指标,充分消融
  • 写作质量: ⭐⭐⭐⭐⭐ 动机分析深入(训练不平衡的量化+持久性),方法描述精确(算法伪代码+复杂度分析)
  • 价值: ⭐⭐⭐⭐⭐ 对广泛使用的 STaR/RFT 训练范式的实用改进,58.6% 效率提升意义重大