LASeR: Learning to Adaptively Select Reward Models with Multi-Armed Bandits¶
会议: NeurIPS 2025
arXiv: 2410.01735
代码: https://github.com/duykhuongnguyen/LASeR-MAB
领域: 对齐RLHF
关键词: 奖励模型选择, 多臂老虎机, 迭代训练, DPO, 多RM对齐
一句话总结¶
将多个奖励模型(RM)的选择建模为上下文多臂老虎机(LinUCB)问题,在迭代 LLM 训练中自适应地为每个 batch 选择最合适的 RM,在推理、指令跟随和长上下文任务上以 2-3 倍效率优势全面超越 RM 集成和单 RM 基线。
研究背景与动机¶
- 领域现状:RLHF/DPO 迭代训练管线依赖 RM 评分 LLM 输出来构建偏好数据。RewardBench 上有大量可选 RM,但不同 RM 在不同任务/领域上的泛化性未知。
- 现有痛点:
- 单 RM 风险:一个 RM 可能不泛化到所有任务域,且长期使用易导致 reward hacking
- 多 RM 集成低效:需同时加载和推理多个大模型,计算成本成倍增加;且 RM 之间存在偏好冲突(Qwen vs OLMo 的 agreement 仅 0.43),简单集成会引入噪声
- 手动选择不可行:RM 组合的空间是指数级的,exhaustive search 不现实
- 核心矛盾:需要多个 RM 的多样性来避免单 RM 的局限,但多 RM 的计算开销和信号冲突又会降低训练效果。
- 本文要解决什么? 设计一个高效机制,在训练过程中自动为每个 batch 选择最适合的 RM。
- 切入角度:将 RM 选择类比为多臂老虎机(MAB)——每个 RM 是一个"臂",用 LinUCB 算法在 exploration(尝试不同 RM)和 exploitation(使用已知最优 RM)之间平衡。
- 核心 idea 一句话:每步训练只选一个 RM(节省计算),但用 MAB 算法保证选择的自适应性和全局最优性。
方法详解¶
整体框架¶
迭代训练流程:每轮 (1) 给一批 prompt 用 LLM 生成多个回复 → (2) MAB 的 LinUCB 根据 prompt 嵌入选择一个 RM → (3) 用选中的 RM 评分回复并构建偏好对 → (4) 用 DPO (+NLL) 损失训练 LLM → (5) 用训练损失的负值作为 MAB 奖励更新 LinUCB 参数。
关键设计¶
- LinUCB 上下文多臂老虎机做 RM 选择
- 做什么:每个训练 batch,根据 prompt 的嵌入表示选择一个 RM
- 核心思路:\(j = \arg\max_k (c(t)^\top \hat{\theta}_k + \alpha \sqrt{c(t)^\top A_k^{-1} c(t)})\),其中 \(c(t)\) 是 batch 中 prompt 的平均 last-token 嵌入,\(\hat{\theta}_k\) 是每个 RM 的学习权重,\(\alpha\) 控制探索程度。\(A_k\) 和 \(b_k\) 在每步后根据 MAB 奖励更新
-
设计动机:LinUCB 利用上下文信息使 RM 选择具有领域自适应性——数学题可能选特定 RM,创意写作选另一个。同时 UCB 的探索项保证不会过早锁定一个 RM
-
MAB 奖励设计:负训练损失
- 做什么:用训练后的 DPO 损失的负值作为 MAB 对所选 RM 的奖励信号
- 核心思路:\(-\hat{\mathcal{L}}^m(t)\),更低的 DPO 损失意味着模型更清楚地学会了区分选中 RM 的偏好和非偏好回复 → 说明这个 RM 提供了更有信息量的偏好信号
-
设计动机:不需要额外的评估数据或人工标注就能判断 RM 是否适合——直接用训练信号反馈
-
每步只选一个 RM
- 做什么:每个 mini-batch 只加载和使用一个 RM,而非同时使用多个
-
设计动机:直接避免了多 RM 的计算开销和信号冲突问题。实验证明这比集成所有 RM 分数效果更好且更快
-
Best-of-N 推理变体
- 做什么:对于不适合微调的场景(如长上下文任务),用 MAB 学习 RM 选择策略后在 best-of-N 采样时做 RM 选择
- 设计动机:扩展 LASeR 的适用范围,不仅限于训练阶段
损失函数 / 训练策略¶
- 推理任务:\(\mathcal{L} = \mathcal{L}_{\text{DPO}} + \mathcal{L}_{\text{NLL}}\)(DPO + 正样本的 NLL 正则化)
- 指令跟随任务:\(\mathcal{L} = \mathcal{L}_{\text{DPO}}\)
- 每个 prompt 采样 30 个回复,构建 10 对偏好对
- 使用 LoRA 微调,温度 0.8
实验关键数据¶
主实验(推理任务,Llama-3-8B)¶
| 方法 | StrategyQA | GSM8K | MMLU | 平均 |
|---|---|---|---|---|
| SFT baseline | 80.41 | 69.43 | 65.66 | 71.83 |
| Best RM | 84.29 | 73.16 | 67.15 | 74.87 |
| Random RM | 84.37 | 71.99 | 67.85 | 74.74 |
| RM Score Ensemble | 82.96 | 70.94 | 67.04 | 73.65 |
| RM Agreement Ensemble | 84.03 | 73.85 | 68.35 | 75.41 |
| LASeR | 85.96 | 74.75 | 68.24 | 76.32 |
消融/效率分析¶
| 方法 | 平均 Acc | 训练时间(相对LASeR) |
|---|---|---|
| LASeR | 76.32 | 1× |
| Sequential RM | 74.95 | 3× |
| RM Score Ensemble | 73.65 | 2× |
| RM Online Ensemble | 74.05 | 2× |
关键发现¶
- LASeR 在所有三个领域全面领先:推理 (+2.67% vs 集成)、指令跟随 (72.69% 胜率 vs 集成)、长上下文 (+2.96 F1 vs 集成)
- 多 RM 集成的失败原因是信号冲突:Qwen vs OLMo 在 MMLU 上偏好 agreement 仅 0.43,简单集成会被冲突信号拖累
- 训练效率提升 2-3 倍:每步只用一个 RM,避免了多 RM 同时加载的 GPU 内存和计算开销
- MAB 的探索-利用平衡很关键:Random(纯探索)和 Best RM(纯利用)都不如 LASeR
- LASeR 对噪声 RM 具有鲁棒性:即使 RM 池中包含低质量 RM,MAB 会学会减少对它的选择
亮点与洞察¶
- MAB 框架是"选择 vs 集成"的优雅折中:不是"用所有 RM"也不是"只用一个 RM",而是"每次选最合适的一个"。这个思路可以迁移到任何需要从多个评判器/奖励信号中选择的场景(如多评委打分、多指标优化)
- 负训练损失作为 MAB 奖励不需要额外评估:避免了"需要验证集来评估 RM 质量"的鸡生蛋问题,直接用训练过程中的信号自循环
- 解释了为什么集成 RM 有时还不如单个 RM:RM 偏好冲突的量化分析(agreement F1)首次给出了系统性的解释
局限性 / 可改进方向¶
- 仅使用 4 个 7B 规模的 RM,更大规模或更多 RM 的情况未验证
- LinUCB 假设 MAB 奖励是上下文特征的线性函数,这对复杂任务可能过于简化
- RM 选择粒度是 batch 级别而非 instance 级别(batch size=1 时等价,但实验中 batch>1)
- 未与最新的 RL 方法(如 GRPO、DisCO)结合验证
- 长上下文任务只用 best-of-N 推理,未做端到端的长上下文微调
相关工作与启发¶
- vs RM Score Ensemble: 集成平均分容易被低质量 RM 拖累,LASeR 只选一个避免冲突
- vs WARM(权重平均 RM): WARM 在 RM 权重空间做平均,LASeR 在选择空间做自适应,两者角度不同
- vs DisCO: DisCO 从优化目标角度改进 GRPO,LASeR 从奖励信号质量角度改进训练,两者正交可组合
评分¶
- 新颖性: ⭐⭐⭐⭐ MAB 做 RM 选择的想法简单但有效,首次将 bandit 用于多 RM 对齐
- 实验充分度: ⭐⭐⭐⭐⭐ 推理 + 指令跟随 + 长上下文三大领域,8 个以上 baseline,消融充分
- 写作质量: ⭐⭐⭐⭐ 框架图清晰,实验分析到位
- 价值: ⭐⭐⭐⭐ 对多 RM 使用场景有直接实践价值,计算效率优势明显