跳转至

LASeR: Learning to Adaptively Select Reward Models with Multi-Armed Bandits

会议: NeurIPS 2025
arXiv: 2410.01735
代码: https://github.com/duykhuongnguyen/LASeR-MAB
领域: 对齐RLHF
关键词: 奖励模型选择, 多臂老虎机, 迭代训练, DPO, 多RM对齐

一句话总结

将多个奖励模型(RM)的选择建模为上下文多臂老虎机(LinUCB)问题,在迭代 LLM 训练中自适应地为每个 batch 选择最合适的 RM,在推理、指令跟随和长上下文任务上以 2-3 倍效率优势全面超越 RM 集成和单 RM 基线。

研究背景与动机

  1. 领域现状:RLHF/DPO 迭代训练管线依赖 RM 评分 LLM 输出来构建偏好数据。RewardBench 上有大量可选 RM,但不同 RM 在不同任务/领域上的泛化性未知。
  2. 现有痛点
  3. 单 RM 风险:一个 RM 可能不泛化到所有任务域,且长期使用易导致 reward hacking
  4. 多 RM 集成低效:需同时加载和推理多个大模型,计算成本成倍增加;且 RM 之间存在偏好冲突(Qwen vs OLMo 的 agreement 仅 0.43),简单集成会引入噪声
  5. 手动选择不可行:RM 组合的空间是指数级的,exhaustive search 不现实
  6. 核心矛盾:需要多个 RM 的多样性来避免单 RM 的局限,但多 RM 的计算开销和信号冲突又会降低训练效果。
  7. 本文要解决什么? 设计一个高效机制,在训练过程中自动为每个 batch 选择最适合的 RM。
  8. 切入角度:将 RM 选择类比为多臂老虎机(MAB)——每个 RM 是一个"臂",用 LinUCB 算法在 exploration(尝试不同 RM)和 exploitation(使用已知最优 RM)之间平衡。
  9. 核心 idea 一句话:每步训练只选一个 RM(节省计算),但用 MAB 算法保证选择的自适应性和全局最优性。

方法详解

整体框架

迭代训练流程:每轮 (1) 给一批 prompt 用 LLM 生成多个回复 → (2) MAB 的 LinUCB 根据 prompt 嵌入选择一个 RM → (3) 用选中的 RM 评分回复并构建偏好对 → (4) 用 DPO (+NLL) 损失训练 LLM → (5) 用训练损失的负值作为 MAB 奖励更新 LinUCB 参数。

关键设计

  1. LinUCB 上下文多臂老虎机做 RM 选择
  2. 做什么:每个训练 batch,根据 prompt 的嵌入表示选择一个 RM
  3. 核心思路:\(j = \arg\max_k (c(t)^\top \hat{\theta}_k + \alpha \sqrt{c(t)^\top A_k^{-1} c(t)})\),其中 \(c(t)\) 是 batch 中 prompt 的平均 last-token 嵌入,\(\hat{\theta}_k\) 是每个 RM 的学习权重,\(\alpha\) 控制探索程度。\(A_k\)\(b_k\) 在每步后根据 MAB 奖励更新
  4. 设计动机:LinUCB 利用上下文信息使 RM 选择具有领域自适应性——数学题可能选特定 RM,创意写作选另一个。同时 UCB 的探索项保证不会过早锁定一个 RM

  5. MAB 奖励设计:负训练损失

  6. 做什么:用训练后的 DPO 损失的负值作为 MAB 对所选 RM 的奖励信号
  7. 核心思路:\(-\hat{\mathcal{L}}^m(t)\),更低的 DPO 损失意味着模型更清楚地学会了区分选中 RM 的偏好和非偏好回复 → 说明这个 RM 提供了更有信息量的偏好信号
  8. 设计动机:不需要额外的评估数据或人工标注就能判断 RM 是否适合——直接用训练信号反馈

  9. 每步只选一个 RM

  10. 做什么:每个 mini-batch 只加载和使用一个 RM,而非同时使用多个
  11. 设计动机:直接避免了多 RM 的计算开销和信号冲突问题。实验证明这比集成所有 RM 分数效果更好且更快

  12. Best-of-N 推理变体

  13. 做什么:对于不适合微调的场景(如长上下文任务),用 MAB 学习 RM 选择策略后在 best-of-N 采样时做 RM 选择
  14. 设计动机:扩展 LASeR 的适用范围,不仅限于训练阶段

损失函数 / 训练策略

  • 推理任务:\(\mathcal{L} = \mathcal{L}_{\text{DPO}} + \mathcal{L}_{\text{NLL}}\)(DPO + 正样本的 NLL 正则化)
  • 指令跟随任务:\(\mathcal{L} = \mathcal{L}_{\text{DPO}}\)
  • 每个 prompt 采样 30 个回复,构建 10 对偏好对
  • 使用 LoRA 微调,温度 0.8

实验关键数据

主实验(推理任务,Llama-3-8B)

方法 StrategyQA GSM8K MMLU 平均
SFT baseline 80.41 69.43 65.66 71.83
Best RM 84.29 73.16 67.15 74.87
Random RM 84.37 71.99 67.85 74.74
RM Score Ensemble 82.96 70.94 67.04 73.65
RM Agreement Ensemble 84.03 73.85 68.35 75.41
LASeR 85.96 74.75 68.24 76.32

消融/效率分析

方法 平均 Acc 训练时间(相对LASeR)
LASeR 76.32
Sequential RM 74.95
RM Score Ensemble 73.65
RM Online Ensemble 74.05

关键发现

  • LASeR 在所有三个领域全面领先:推理 (+2.67% vs 集成)、指令跟随 (72.69% 胜率 vs 集成)、长上下文 (+2.96 F1 vs 集成)
  • 多 RM 集成的失败原因是信号冲突:Qwen vs OLMo 在 MMLU 上偏好 agreement 仅 0.43,简单集成会被冲突信号拖累
  • 训练效率提升 2-3 倍:每步只用一个 RM,避免了多 RM 同时加载的 GPU 内存和计算开销
  • MAB 的探索-利用平衡很关键:Random(纯探索)和 Best RM(纯利用)都不如 LASeR
  • LASeR 对噪声 RM 具有鲁棒性:即使 RM 池中包含低质量 RM,MAB 会学会减少对它的选择

亮点与洞察

  • MAB 框架是"选择 vs 集成"的优雅折中:不是"用所有 RM"也不是"只用一个 RM",而是"每次选最合适的一个"。这个思路可以迁移到任何需要从多个评判器/奖励信号中选择的场景(如多评委打分、多指标优化)
  • 负训练损失作为 MAB 奖励不需要额外评估:避免了"需要验证集来评估 RM 质量"的鸡生蛋问题,直接用训练过程中的信号自循环
  • 解释了为什么集成 RM 有时还不如单个 RM:RM 偏好冲突的量化分析(agreement F1)首次给出了系统性的解释

局限性 / 可改进方向

  • 仅使用 4 个 7B 规模的 RM,更大规模或更多 RM 的情况未验证
  • LinUCB 假设 MAB 奖励是上下文特征的线性函数,这对复杂任务可能过于简化
  • RM 选择粒度是 batch 级别而非 instance 级别(batch size=1 时等价,但实验中 batch>1)
  • 未与最新的 RL 方法(如 GRPO、DisCO)结合验证
  • 长上下文任务只用 best-of-N 推理,未做端到端的长上下文微调

相关工作与启发

  • vs RM Score Ensemble: 集成平均分容易被低质量 RM 拖累,LASeR 只选一个避免冲突
  • vs WARM(权重平均 RM): WARM 在 RM 权重空间做平均,LASeR 在选择空间做自适应,两者角度不同
  • vs DisCO: DisCO 从优化目标角度改进 GRPO,LASeR 从奖励信号质量角度改进训练,两者正交可组合

评分

  • 新颖性: ⭐⭐⭐⭐ MAB 做 RM 选择的想法简单但有效,首次将 bandit 用于多 RM 对齐
  • 实验充分度: ⭐⭐⭐⭐⭐ 推理 + 指令跟随 + 长上下文三大领域,8 个以上 baseline,消融充分
  • 写作质量: ⭐⭐⭐⭐ 框架图清晰,实验分析到位
  • 价值: ⭐⭐⭐⭐ 对多 RM 使用场景有直接实践价值,计算效率优势明显