跳转至

AutoMixAlign: Adaptive Data Mixing for Multi-Task Preference Optimization in LLMs

会议: ACL 2025
arXiv: 2506.00569
代码: 无
领域: 对齐RLHF
关键词: data mixing, multi-task DPO, minimax optimization, excess loss, preference learning

一句话总结

AutoMixAlign 提出了一种理论驱动的多任务偏好优化数据混合方法:先训练各任务的 specialist model 确定最优 loss 基线,再通过 minimax 优化自适应调整数据混合比例,优先处理 excess loss(与 specialist 的差距)最大的任务,在 helpfulness/harmlessness/reasoning 多任务 DPO 中平均提升 9.42%。

研究背景与动机

  1. 领域现状:LLM 对齐训练需要在多个任务(helpfulness、safety、coding、math)上同时表现良好,DPO 训练需要混合多个任务数据集
  2. 现有痛点
  3. 均匀混合数据会被大数据集主导,小任务被忽视
  4. 任务均等加权(按数据集大小归一化)对困难任务和简单任务一视同仁,导致资源浪费
  5. 确定最优混合比例通常需要大量消融实验,成本高昂
  6. 核心矛盾:静态数据混合无法适应训练过程中各任务学习难度的动态变化
  7. 本文要解决什么? 自动化确定多任务 DPO 的数据混合比例
  8. 切入角度:以 specialist model 的 loss 为参照基准,通过最大化"最差任务的 excess loss 最小化"来动态调整混合
  9. 核心idea一句话:先训 specialist 定下各任务的 loss 目标,再通过 minimax 优化让 generalist 追赶所有 specialist

方法详解

整体框架

AMA 分两阶段:(1) 分别在各任务数据上用 DPO 训练 specialist models \(\theta_1, \ldots, \theta_k\);(2) 训练 generalist model,通过 minimax 优化问题 \(\min_{\theta} \max_{i \in [k]} \frac{1}{|\mathcal{D}_i|}\sum_{z \in \mathcal{D}_i} \max\{\mathcal{L}(\theta, z) - \mathcal{L}(\theta_i, z), 0\}\) 来平衡所有任务性能。

关键设计

  1. Excess Loss(超额损失):
  2. 做什么:量化 generalist 与 specialist 在每个样本上的 DPO loss 差距
  3. 核心思路:\(\mathcal{E}(\theta, \theta_i, z) = \max\{\mathcal{L}(\theta, z) - \mathcal{L}(\theta_i, z), 0\}\),clipped at 0 避免过度优化已学好的任务
  4. 设计动机:直接优化 loss 的问题是不知道"多好算好"。Specialist loss 提供了每个任务的可达目标——当 generalist loss 已低于 specialist loss 时停止优化该任务,避免过拟合

  5. AMA-R: Reweighting 算法:

  6. 做什么:自适应调整各任务在目标函数中的权重 \(\alpha_i\)
  7. 核心思路:将 minimax 问题转为 \(\min_\theta \max_{\alpha \in \Delta^k} \sum_i \alpha_i \cdot \text{avg-excess-loss}_i\)。交替执行指数梯度上升更新 \(\alpha\)(增大 excess loss 大的任务权重)和梯度下降更新 \(\theta\)
  8. 收敛性:凸情况下 \(O(1/\sqrt{T})\) 收敛率(继承自 Sagawa et al., 2019)

  9. AMA-S: Resampling 算法:

  10. 做什么:自适应调整从各任务数据集采样的概率
  11. 核心思路:用 EXP3 在线学习算法维护采样概率 \(\alpha\)\(\alpha_i = (1-c)q_i + c/k\)。根据各任务的 excess loss 通过指数更新调整 \(q_i\),excess loss 大的任务被采样更多
  12. 设计动机:AMA-R 中即使权重 \(\alpha_i \approx 0\),仍需对任务 \(i\) 采样并计算梯度(浪费计算);AMA-S 直接减少低 excess loss 任务的采样量
  13. 收敛性:将其建模为两人零和博弈,利用 EXP3 证明 \(O(1/\sqrt{T})\) 收敛率

损失函数 / 训练策略

  • Specialist 预计算:训练前预计算所有 \(\mathcal{L}(\theta_i, z)\) 并缓存,训练时无需额外前向传播
  • 平滑参数 \(c\):防止采样概率退化为 0,保持探索

实验关键数据

主实验

多任务 DPO(Helpfulness + Harmlessness + Reasoning):

方法 Helpfulness Harmlessness Reasoning Avg.
Uniform DPO 基准 基准 基准 基准
Task-normalized DPO +1-2% +0-1% -1-2% ~+0.5%
Model Merging 不稳定 不稳定 不稳定 低于AMA
AMA-R 显著提升 显著提升 显著提升 +7-9%
AMA-S 显著提升 显著提升 显著提升 +8-9.42%

AMA 在所有任务上均优于 uniform mixing 和 model merging,最大提升 9.42%。

消融实验

配置 效果 说明
w/o excess loss (用原始 loss) 下降 会过度优化已学好的任务
w/o specialist clipping 下降 无停止准则导致过拟合
AMA-R vs AMA-S AMA-S 略优 更高计算效率
不同 specialist 质量 影响小 specialist 只需"够好"

关键发现

  • Excess loss clipping 是关键——没有它 generalist 会在已学好的任务上继续优化,挤占困难任务的资源
  • AMA-S 比 AMA-R 计算效率更高(避免采样无用任务),且性能略优
  • Model merging(SLERP/TIES/DARE)在多任务对齐上表现不稳定,常牺牲某些任务

亮点与洞察

  • Specialist-as-target 的思路:用 specialist 的 loss 作为 generalist 的学习目标是非常优雅的设计——提供了"学到什么程度够了"的明确信号,避免了传统多任务学习中"不知道何时该停"的问题
  • 理论保证+实际有效:AMA-S 的 EXP3 收敛证明为自适应数据混合提供了理论基础,同时实验效果显著
  • 通用性:虽然本文聚焦 DPO,但 specialist + minimax excess loss 的框架可以迁移到 SFT、RLHF 等任何多任务学习场景

局限性 / 可改进方向

  • Specialist 训练成本:需要为每个任务训练一个 specialist model,\(k\) 大时成本线性增长。可探索只用少量训练步数得到"近似 specialist"
  • 凸收敛假设\(O(1/\sqrt{T})\) 保证仅在凸设置下成立,实际 LLM 训练是高度非凸的
  • 任务粒度定义:需要预先定义任务划分(哪些数据属于哪个任务),对任务边界模糊的场景可能不适用
  • 仅在 DPO 上验证:未验证与 PPO/GRPO 等 RL 方法的兼容性

相关工作与启发

  • vs Uniform/Heuristic Mixing: 传统方法静态或靠直觉;AMA 自适应且有理论保证
  • vs Model Merging (TIES/SLERP): 后处理合并方法对多任务对齐效果不稳定;AMA 在训练时动态平衡更可靠
  • vs Group DRO (Sagawa et al.): AMA-R 继承了 Group DRO 的 minimax 框架,但引入了 excess loss clipping 和 specialist 目标
  • 可以将这个框架用于 pre-training 的数据混合问题(替换 DPO loss 为 LM loss)

评分

  • 新颖性: ⭐⭐⭐⭐ specialist + minimax excess loss 是新颖组合,理论贡献扎实
  • 实验充分度: ⭐⭐⭐⭐ 多种设置和消融,但实验主要在中等规模模型上
  • 写作质量: ⭐⭐⭐⭐⭐ 数学推导严谨,问题动机清晰
  • 价值: ⭐⭐⭐⭐ 多任务数据混合是实际部署中的核心问题,AMA 提供了自动化方案