跳转至

Mastering Massive Multi-Task Reinforcement Learning via Mixture-of-Expert Decision Transformer

会议: ICML2025
arXiv: 2505.24378
代码: 待确认
领域: reinforcement_learning
关键词: 多任务强化学习, Decision Transformer, Mixture-of-Experts, 任务可扩展性, 参数可扩展性

一句话总结

提出 M3DT 框架,将 MoE 引入 Decision Transformer 实现参数分离——通过任务分组让每个专家只学习一个小任务子集的特定知识,配合三阶段训练机制(骨干→专家→路由器)避免梯度冲突,增加专家数既扩展参数又降低任务负载,成功将离线多任务 RL 扩展到 160 个仿真控制任务。

研究背景与动机

领域现状:Decision Transformer (DT) 将离线 RL 建模为序列预测问题,高容量 Transformer 架构也被用于多任务 RL (MTRL)。受 LLM 在海量任务上泛化能力的启发,研究者期望训练能掌握大量多样任务的 RL 智能体。

现有痛点: - 任务可扩展性不足:大多数工作仅处理几十个任务(Atari 或 Meta-World),扩展后性能显著退化。Gato 覆盖 600+ 任务但控制任务表现差。 - 参数扩展效率低:简单增大 DT 模型尺寸会迅速触及性能天花板。在 20M 参数后继续增大几乎无改善——这与 NLP 中模型越大越好的 scaling law 截然不同。

核心矛盾:共享参数面临的梯度冲突随任务数增多而加剧,导致性能退化;而朴素增大共享参数无法缓解——因为根本问题不是参数容量不足,而是所有任务共享参数导致的优化冲突。

本文目标:如何在任务数量极大(160 个)时既保持每个任务的学习质量,又实现高效的参数扩展?

切入角度:通过系统的实证分析,作者发现两个关键洞察: - 性能退化在任务数较少(<40)时最剧烈,任务数大后退化趋于平缓 → 反过来看,将每个参数子集需学习的任务数减少到足够小即可显著提升性能 - 参数扩展在"同时减少任务负载"时效果最好 → 扩展参数的同时必须减少每个参数子集的任务数

核心 idea:用 MoE 实现"参数分离 + 任务分组",增加专家数量同时扩展参数和降低任务负载,再配合三阶段训练避免梯度干扰。

方法详解

整体框架

M3DT 的整体 pipeline 如下:

  • 输入:与 Prompt-DT 一致,轨迹序列 \((r^*_{1}, s^*_1, a^*_1, \ldots, \hat{r}_{t-K+1}, s_{t-K+1}, a_{t-K+1}, \ldots, \hat{r}_t, s_t, a_t)\),前缀为 trajectory prompt
  • 骨干网络:5.29M 参数的 Prompt-DT(6层、8头、256维),学习跨任务共享知识
  • MoE 层:每个 Transformer block 中 FFN 旁并行添加专家 FFN 和路由器 MLP,前馈结果为 \(f(x) = x + f_{\text{FFN}}(x) + f_{\text{MoE}}(x)\)
  • 路由器:5层 MLP,对隐藏状态做 softmax 得到各专家权重
  • 输出:预测动作 \(a_t\),损失为 MSE

关键设计

  1. MoE 增强的 DT 架构

    • 功能:在 DT 每个 Transformer block 的 FFN 旁并行添加多个专家 FFN,保留原始 FFN 以维护骨干知识
    • 核心思路:MoE 输出为 \(f_{\text{MoE}}(x) = \sum_{i=1}^N \text{Softmax}(f_{\theta_r}(x))_i \cdot f_{\theta_i}(x)\),各专家结构与原 FFN 相同。原始 FFN 保留不动,MoE 输出作为残差加入,这样骨干的共享知识不会被破坏
    • 设计动机:作者在附录 B.2 中发现,Transformer block 中 MLP 层的梯度冲突远比 Attention 层严重,且单纯扩大 MLP 效果有限。因此选择在 MLP/FFN 侧加入 MoE 进行参数分离,每个专家只处理少量任务,大幅缓解 FFN 内的梯度冲突
  2. 任务分组策略

    • 功能:将所有任务显式分为若干组,每组分配给一个专家独立训练
    • 核心思路:提供两种分组方式——(1)随机分组:将任务随机等分,简单有效;(2)基于梯度的分组:计算每个任务的 agreement vector \(A(\mathcal{T}_i) = g_i \odot \frac{1}{N}\sum_{i=1}^N g_i\)(任务梯度与平均梯度的逐元素乘积),反映梯度一致性,再用 K-means 聚类。实验中梯度分组比随机分组在 160 任务上高约 2% (77.89 vs 80.14)
    • 设计动机:标准 MoE 的 top-k 动态路由在 RL 场景下不稳定——实验显示 top-4 routing 无法随专家数扩展,性能随专家增加反而下降。显式分组避免了路由不稳定和负载不均衡问题,同时利用任务结构信息使组内梯度冲突最小化
  3. 三阶段训练机制

    • 功能:将训练分为三个阶段,依次独立优化骨干、专家、路由器
    • 核心思路:
      • 阶段 1 — 骨干训练(400k 步):在所有任务上训练 Prompt-DT 骨干,学习跨任务共享表征。关键:在梯度冲突达到峰值时提前停止(约 400k 步),避免参数过拟合到梯度主导的任务
      • 阶段 2 — 专家训练(200k 步):冻结骨干,每个专家独立用其任务子集训练。各专家可并行训练(~1.8h/专家)
      • 阶段 3 — 路由器训练(400k 步):冻结骨干和专家,在所有任务上训练路由器学习动态权重分配。总训练约 24.2h(RTX 4090)
    • 设计动机:端到端训练 MoE 在 RL 中效果很差——实验显示端到端训练的 MoE 模型梯度冲突甚至比纯 PromptDT 的 MLP 还严重。分阶段训练让每个模块在无干扰环境下学习,骨干的提前停止策略避免了梯度冲突峰值后的性能下降

损失函数 / 训练策略

  • 动作预测 MSE 损失:\(\mathcal{L}_{DT} = \mathbb{E}[\frac{1}{K}\sum_{m=t-K+1}^t (a_{i,m} - \pi(\tau_i^*, \tau_{i,m}))^2]\)
  • 三个阶段均使用 Adam 优化器,学习率 1e-4,batch size 16
  • 不使用额外的 MoE 负载均衡损失(因为显式分组已保证均衡)

实验关键数据

主实验:不同任务规模下的性能对比

在 Meta-World (50 任务) + DMControl (30 任务) + Mujoco Locomotion (80 任务) 共 160 任务上,统一归一化评分。

任务规模 方法 归一化分数 参数量
10 tasks PromptDT-Small ~88% 1.47M
10 tasks HarmoDT-Large ~89% 173.30M
10 tasks M3DT-Gradient 89.23% 47.87M
80 tasks PromptDT-Large ~73% 173.30M
80 tasks M3DT-Gradient 79.58% (+6.6%) 98.37M
160 tasks PromptDT-Large ~68% 173.30M
160 tasks HarmoDT-Large ~72% 173.30M
160 tasks M3DT-Gradient 80.14% (+7.5%) 174.12M

M3DT 在 160 任务上的分数甚至高于其他 baseline 在 80 任务上的分数。

消融实验

配置 归一化分数 (160 tasks) 说明
M3DT-Random 77.89 随机分组
M3DT-Gradient 80.14 梯度分组,最优
M3DT w/o 3-stage training ~68 端到端训练,退化到 PromptDT 水平
M3DT w/o grouping 67.34 所有专家在所有任务上联合训练,比 baseline 还差
M3DT-R w/o expert freezing 次优 第三阶段同时微调专家,知识被覆盖
M3DT-G w/o expert freezing 次优 同上

关键发现

  • MoE 结构本身不够:单纯加 MoE 端到端训练(w/o 3-stage)效果与 PromptDT 相当,梯度冲突甚至更严重。任务分组 + 三阶段训练才是核心
  • 专家数与性能正相关但有上限:8→40 专家性能持续提升(160 任务上提升 11.7%),但 40 专家后收益递减。上限由三方面决定:(1) 骨干共享知识有限;(2) 任务子集足够小后再拆收益递减;(3) 路由器在专家数多时分配难度增大
  • 骨干提前停止很关键:最优停止点恰好在梯度冲突达到峰值时(400k 步)。训练太少(200k)骨干知识不充分;训练太多(>400k)参数过拟合到主导任务,后续专家学习也受影响
  • Top-K 路由在 RL 中失败:Top-4 routing 无法随专家数扩展,验证了 RL 中 sparse gating 的不稳定性。M3DT 需要所有专家的加权组合
  • 随机分组已经足够好:M3DT-Random 已超越所有 baseline,说明"减少任务数"本身就是核心收益;梯度分组额外提升约 2%

亮点与洞察

  • "减少任务数"的逆向思维:从"任务增加→性能退化曲线"的形状中发现退化主要集中在小任务数阶段,反推出核心策略不是让模型更大,而是让每个参数子集看到更少的任务。这个洞察精辟且与直觉相反。
  • MoE 在 RL 中的关键适配:NLP 中 MoE 用动态 top-k 路由,但 RL 数据分布与 NLP 差异巨大,动态路由不稳定。M3DT 用显式任务分组 + 全专家加权替代 top-k sparse routing,是一个重要的领域适配范式。
  • 梯度冲突峰值作为骨干训练的停止信号:将梯度冲突动态(先升后平)与训练阶段对齐,在峰值处切换到分组训练,兼顾了共享知识学习和冲突规避。

局限与展望

  • 仅限仿真环境:所有实验在 Meta-World/DMC/Mujoco 仿真中进行,状态空间为低维连续向量(最大 39 维),未验证在高维视觉输入或真实机器人上的效果。
  • 推理成本随专家数线性增长:所有专家均参与前馈,不同于 NLP 中 sparse MoE 只激活部分专家。作者尝试了 top-k 但效果不好,设计适配三阶段训练的 sparse gating 是重要的工程方向。
  • 三阶段训练的超参数:各阶段训练轮次(400k/200k/400k)需要额外调优,且骨干停止点依赖对梯度冲突的监控。自适应阶段切换机制值得探索。
  • 任务分组依赖先验:梯度分组需要先训练骨干再计算 agreement vector,计算额外开销。对全新任务集合或持续到达的新任务,如何在线更新分组是开放问题。
  • 未验证泛化和持续学习:所有实验中训练和测试使用相同任务集。在 held-out 任务上的泛化能力、以及新任务到来时的持续学习能力未被探索。

相关工作与启发

  • vs Multi-Game DT (Lee et al., 2022):直接扩大模型参数处理多任务 Atari,作者证明这种朴素扩展在 20M 参数后即饱和。M3DT 通过 MoE 在等量参数下性能显著更优。
  • vs Gato (Reed et al., 2022):通用多模态智能体覆盖 600+ 任务但控制任务性能差。M3DT 聚焦控制任务场景,通过任务分组和专家专精化取得更好表现。两者思路互补——MoE 可以集成到 Gato 类框架中。
  • vs HarmoDT (Hu et al., 2024):HarmoDT 用 task-specific mask 屏蔽冲突参数来缓解梯度冲突,本质上是"共享参数 + 掩码"范式。M3DT 的"共享骨干 + 独立专家"是更彻底的参数分离方案,在 160 任务上优势明显 (+7.5%)。
  • vs Soft MoE in RL (Obando-Ceron et al., 2024):该工作发现 top-k MoE 在深度 RL 中难以扩展,与本文的 top-4 实验结论一致。M3DT 的三阶段训练 + 全专家加权是绕过这一障碍的有效方案。

评分

  • 新颖性: ⭐⭐⭐⭐ MoE 在 RL 中的应用有创新,核心贡献在于RL场景下的关键适配(显式分组+三阶段训练)和实证洞察(任务数vs性能曲线)
  • 实验充分度: ⭐⭐⭐⭐⭐ 覆盖 10-160 任务规模的系统实验,消融分析极为详尽(分组策略/专家数/骨干停止点/路由器设计/专家结构),domain-wise 分析消除了偏差
  • 写作质量: ⭐⭐⭐⭐ 从实证观察到方法设计的逻辑链清晰完整,Figure 2/3 对问题的可视化说服力强
  • 价值: ⭐⭐⭐⭐ 对多任务 RL 的可扩展性研究有重要推动,"参数分离 + 任务分组 + 分阶段训练"范式可推广到其他多任务场景

相关论文