跳转至

Confounding Robust Deep Reinforcement Learning: A Causal Approach

会议: NeurIPS 2025 arXiv: 2510.21110 代码: 有(附补充材料含游戏视频) 领域: 强化学习 / 因果推理 关键词: confounded MDP, off-policy learning, partial identification, causal DQN, Atari

一句话总结

基于部分辨识(partial identification)理论扩展 DQN,提出 Causal DQN 从含有未观测混淆因子的离线数据中学习鲁棒策略——通过优化最坏情况下的价值函数下界来获得安全策略,在 12 个混淆 Atari 游戏中一致性地超越标准 DQN。

研究背景与动机

  1. 领域现状:Deep RL(如 DQN)在高维状态空间中表现出色,但隐含地假设 No Unmeasured Confounder(NUC)——即行为策略的数据中不存在未观测的混淆因子。Off-policy 学习依赖这一假设将观测数据中的条件分布直接等同于因果转移分布。
  2. 现有痛点:当从演示者(demonstrator)的离线数据中学习时,学习者无法控制数据收集过程。如果演示者使用了学习者观测不到的信息(如在 Pong 游戏中演示者能看到对手位置但学习者只能看到局部画面),标准 DQN 无法区分因果效应和虚假相关性,导致学到无效策略。
  3. 核心矛盾:在混淆设置下,目标策略的效果通常不可辨识(not identifiable),传统方法无法从数据中唯一确定价值函数。但部分辨识方法可以推导出价值函数的信息性上下界。
  4. 本文要解决什么? 将部分辨识方法扩展到复杂高维领域(图像输入),构建对混淆鲁棒的深度 RL 算法。
  5. 切入角度:在混淆 MDP(CMDP)中,推导最优价值函数的因果 Bellman 最优方程(下界形式),然后用神经网络近似这个下界。
  6. 核心idea一句话:用因果 Bellman 方程将标准 DQN 的 Q 值更新替换为悲观下界更新——当观测数据中的动作与目标动作匹配时走标准路径,不匹配时用最坏情况奖励和最坏情况下一状态。

方法详解

整体框架

Causal DQN 遵循标准 DQN 的 experience replay 框架,但在 Q 值更新步骤中使用因果 Bellman 最优方程(Proposition 3.1)替代标准 Bellman 方程。学习者观察演示者的轨迹(演示者可能使用学习者看不到的信息),存入 replay buffer,用小批量梯度下降优化 Q-network 的下界。

关键设计

  1. 混淆 MDP(CMDP)形式化:
  2. 做什么:将未观测混淆因子显式建模到 MDP 中——\(\langle \mathcal{S}, \mathcal{X}, \mathcal{Y}, \mathcal{U}, \mathcal{F}, P \rangle\),其中 \(\mathcal{U}\) 是未观测噪声空间
  3. 核心思路:在因果图中,双向箭头 \(X_t \leftrightarrow Y_t\)\(X_t \leftrightarrow S_{t+1}\) 表示未观测混淆因子 \(U_t\) 同时影响动作、奖励和下一状态
  4. 设计动机:标准 MDP 假设 \(P(s'|s,a) = \mathcal{T}(s,a,s')\)(NUC 下直接从条件分布辨识转移分布),但在混淆下这一等式不成立

  5. 因果 Bellman 最优方程(Proposition 3.1):

  6. 做什么:推导最优 Q 值函数的可辨识下界 \(\underline{Q_*}(s,x)\)
  7. 核心思路:分两种情况更新:
    • 当观测动作 \(x_t = x\)(匹配):使用标准更新 \(y_t + \gamma \max_{x'} \underline{Q_*}(s_{t+1}, x')\)
    • 当观测动作 \(x_t \neq x\)(不匹配):使用最坏情况 \(a + \gamma \min_{s'} \max_{x'} \underline{Q_*}(s', x')\)\(a\) 是奖励下界)
  8. 设计动机:不匹配时,由于混淆因子的存在,无法知道执行目标动作 \(x\) 后的真实转移,只能用最坏情况估计保证安全性

  9. Q-Network 下界优化:

  10. 做什么:用神经网络 \(\underline{Q_*}(s,x;\theta)\) 近似下界函数
  11. 核心思路:损失函数 \(L_i(\theta_i) = \mathbb{E}_{s \sim \rho(\cdot)}[\sum_x (W_i(x) - \underline{Q_*}(s,x;\theta_i))^2]\),对所有动作同时更新(而非仅更新匹配动作),因为因果更新中不匹配动作也提供了关于下界的信息
  12. 设计动机:标准 DQN 只更新采样到的动作对应的 Q 值,但因果 Bellman 方程中所有动作的更新都依赖当前观测,因此需要同时优化

  13. 混淆 Atari 游戏设计:

  14. 做什么:对 12 个 Atari 游戏设计混淆版本——遮挡部分画面使学习者观测不到演示者使用的信息
  15. 核心思路:通过演示者的 saliency map 定位其依赖的视觉区域,遮挡这些区域作为未观测混淆因子
  16. 设计动机:构建受控实验环境,精确控制混淆因子的存在

训练策略

  • 1M 环境步训练,20 个并行环境
  • Batch size 512,replay buffer 100K
  • 使用 Double DQN 稳定学习
  • 最坏情况状态通过从 replay buffer 中随机采样估计

实验关键数据

12 个混淆 Atari 游戏对比

游戏 Demonstrator Conf.DQN Conf.LSTM-DQN Interv.DQN Causal DQN
Pong 21.0 -19.4 -20.5 -19.7 -1.8
Boxing 99.8 1.3 -2.6 -1.7 97.8
Gopher 6780 380 420 350 8140
ChopperCommand 4560 800 920 750 5280
Amidar 232.4 37.8 59.0 44.0 282.6

消融实验(标准化 return 聚合)

方法 归一化 Return
Conf. DQN ~0.05
Conf. LSTM-DQN ~0.08
Interv. DQN ~0.06
Causal DQN ~0.85

关键发现

  • 一致性优势:Causal DQN 在全部 12 个混淆游戏中都超越标准 DQN 变体,且差距巨大
  • 保守策略的涌现:Causal DQN 在信息缺失时学到了合理的保守行为——Pong 中只跟踪球(不追踪对手),Boxing 中采用"绳上战术"(rope-the-dope)防守
  • 超越演示者:在 Gopher 和 ChopperCommand 中,Causal DQN 甚至超越了拥有完整信息的演示者,因为保守策略在某些游戏中更有效
  • 标准 DQN 变体(包括 LSTM-DQN)完全无法处理混淆,几乎不收敛

亮点与洞察

  • 因果推理 × 深度 RL 的优雅结合:将部分辨识理论的严格数学框架无缝集成到 DQN 的实际算法中,证明了因果方法在高维领域的可行性
  • 保守策略的自然涌现:算法没有被显式编程为保守,但悲观下界优化自然导致了安全、合理的策略——这是一种"安全 RL"的副产品
  • 混淆 Atari 游戏作为基准:通过 saliency map 引导的画面遮挡构建受控混淆环境,这种方法可推广到其他混淆 RL 基准的构建
  • 所有动作同时更新的 insight:因果 Bellman 方程中不匹配动作也提供了信息,这与标准 off-policy RL 只更新匹配动作的做法形成鲜明对比

局限性 / 可改进方向

  • 最坏情况过于悲观:在某些游戏中,最坏情况假设可能导致过度保守,不能充分利用数据中虽不精确但仍有信息的信号
  • 最坏状态估计依赖 replay buffer 采样:可能不准确,尤其在状态空间大时
  • 仅验证 Atari 环境:更复杂的连续控制任务和现实世界场景未评估
  • 计算开销:对所有动作同时更新 + 最坏状态搜索增加了训练时间

相关工作与启发

  • vs 标准 DQN: 标准 DQN 假设 NUC,混淆下完全失败;Causal DQN 通过悲观下界保证鲁棒性
  • vs Zhang & Bareinboim (2025) 因果 Bellman 方程: 本文将其从表格设置扩展到深度网络实现,证明了在高维视觉输入下的可行性
  • vs 悲观离线 RL(如 CQL): CQL 对分布偏移悲观,本文对因果混淆悲观——两种悲观的来源和处理方式不同

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 因果部分辨识 × DQN 的首次高维实现
  • 实验充分度: ⭐⭐⭐⭐⭐ 12 个混淆 Atari 游戏 + 详细消融 + saliency 可视化 + 游戏视频
  • 写作质量: ⭐⭐⭐⭐⭐ 问题动机→理论→算法→实验的叙事非常流畅
  • 价值: ⭐⭐⭐⭐⭐ 对因果 RL 和安全 RL 领域有重大推动