跳转至

Causally Reliable Concept Bottleneck Models

会议: NeurIPS 2025
arXiv: 2503.04363
代码: 提交时附带补充材料
领域: ai_safety
关键词: concept bottleneck model, causal reasoning, structural causal model, interpretability, fairness

一句话总结

提出 C2BM(Causally reliable Concept Bottleneck Models),将概念瓶颈(concept bottleneck)按照因果图结构化组织,通过结合观测数据与背景知识自动学习因果关系,在保持分类精度的同时显著提升因果可靠性、干预响应和公平性。

研究背景与动机

  1. 领域现状:Concept Bottleneck Models(CBMs)是可解释深度学习的代表范式,通过强制模型经由人类可理解的概念层进行推理来实现透明性。CBM 将预测分为两步:编码器将输入映射到概念,解码器基于概念预测任务标签。
  2. 现有痛点:现有 CBMs 采用二部图结构(bipartite),假设所有概念独立且直接影响输出。这一假设过于简化:(1) 忽略概念间的因果依赖关系,导致解释可能误导(如将肺癌归因于"咳嗽"和"吸烟",暗示减少咳嗽可降低癌症风险);(2) 概念独立假设阻止干预效果在相关概念间传播;(3) 仅学习统计相关性而非因果关系,容易受虚假相关影响。
  3. 核心矛盾:CBMs 本质上是关联模型(associative models),其决策过程反映的是数据中的统计相关性而非真实世界的因果机制。这导致它们无法支持因果推理、限制分布外泛化、阻碍公平性约束的实施。
  4. 本文要解决什么? (a) 如何将概念瓶颈按因果机制结构化?(b) 如何在无人类专家标注的情况下自动发现概念和因果图?(c) 如何在不牺牲精度的前提下提升因果可靠性?
  5. 切入角度:利用结构因果模型(SCM)将概念瓶颈组织为因果图,并通过 LLM + RAG 自动从非结构化背景知识中发现因果关系。
  6. 核心idea一句话:在概念瓶颈中嵌入因果图结构,用超网络自适应参数化结构方程,使模型推理过程与真实因果机制对齐。

方法详解

整体框架

C2BM 由三个核心模块组成:(1) 概念发现与标注——利用 LLM 发现相关概念并用 CLIP 标注数据集;(2) 因果图发现——结合因果发现算法(GES)与 LLM+RAG 查询来确定概念间的因果关系;(3) C2BM 模型本身——包含神经编码器 \(\mathbf{g}(\cdot)\) 和参数化 SCM \(\mathcal{M}_{\boldsymbol{\Theta}}\)

输入为原始数据 \(X\),编码器预测外生变量(高维嵌入)\(\mathcal{U} = \{U_i\}_{i=1}^C\),然后信息沿因果图从源节点流向汇节点。每个内生变量 \(V_i\) 的值由其因果父节点通过结构方程预测。

关键设计

  1. 因果瓶颈(Causal Bottleneck):
  2. 做什么:替代 CBM 的扁平二部图结构,用 DAG 组织概念间的因果关系
  3. 核心思路:将 C2BM 定义为 \(\langle \mathbf{g}, \mathcal{M}_{\boldsymbol{\Theta}} \rangle\),其中 SCM 为 \(\langle \mathcal{V}, \mathcal{U}, \mathcal{F}_{\boldsymbol{\Theta}}, P(\mathcal{U}|X) \rangle\)。信息先从编码器获得外生变量,再沿 DAG 拓扑序逐层计算每个概念
  4. 设计动机:因果结构使干预可以沿图传播,阻断虚假相关路径,支持公平性约束

  5. 自适应结构方程(Adaptive Structural Equations):

  6. 做什么:为每个概念学习可解释的因果机制
  7. 核心思路:结构方程取线性形式 \(V_i = \sum_{V_j \in \text{PA}_i} [\boldsymbol{\theta}_{f_i}]_j V_j\),但参数 \(\boldsymbol{\theta}_{f_i}\) 不是固定的,而是由超网络 \(\mathbf{r}_i(U_i)\) 对每个输入自适应预测,实现局部线性但全局非线性的表达能力
  8. 设计动机:线性结构方程保证可解释性(每个概念是父节点的加权线性组合),超网络保证表达能力(论文证明 C2BM 是通用近似器)

  9. 自动因果图构建流水线:

  10. 做什么:无需人类专家即可自动发现概念和因果图
  11. 核心思路:(a) 用 LLM 查询相关概念并用 CLIP 标注;(b) 用 GES 算法从观测数据学习 CPDAG(含无向边的部分图);(c) 用 LLM+RAG 查询背景知识来定向无向边并删除虚假边,每个查询重复 10 次取多数投票
  12. 设计动机:纯数据驱动的因果发现无法唯一确定 DAG;结合背景知识(如科学文献)可有效缩小候选图空间

损失函数 / 训练策略

训练目标为最大化经验对数似然:

\[\boldsymbol{\phi}^* = \arg\max_{\boldsymbol{\phi}} \sum_{\mathcal{D}} \sum_{i=1}^C \log P(V_i \mid \text{PA}_i, U_i; \mathbf{r}(\mathcal{E}, \mathcal{U})_i)\]

其中联合条件分布按因果图的马尔可夫条件分解为独立因子的乘积。编码器和超网络的参数端到端联合学习。

实验关键数据

主实验

数据集 OpaqNN CBM+lin CEM SCBM C2BM
Asia 71.0 71.2 71.1 70.7 71.4
Sachs 65.83 65.44 65.93 66.30 65.33
Hailfinder 72.0 72.2 71.5 73.4 74.1
cMNIST 91.24 93.92 93.72 94.02 94.18
CelebA 74.97 71.07 74.72 72.15 74.73
Pneumoth. 80.0 76.6 80.1 78.4 80.5

C2BM 在大多数数据集上与最强基线(OpaqNN、CEM)持平或更优,同时是唯一具备因果可靠性的模型。

消融实验 — 因果图质量

度量 扁平 CBM 仅因果发现(CD) CD + LLM
Hamming (Asia) 6.5 0.7 0.3
错误边数 (Sachs, 共17边) 23 17 7
错误边数 (Alarm, 共46边) 78 13 10

LLM 背景知识有效减少了因果图中的错误边,Sachs 上额外正确识别 10 条边。

关键发现

  • 干预实验:C2BM 在逐层干预概念时,下游准确率提升最快且幅度最大,因为因果结构使干预效果自然传播到子节点
  • 去偏(Debiasing):在 biased cMNIST 上,C2BM 通过因果图发现正确删除了 Color→Parity 的虚假边,干预 number 概念后准确率达 ~90%
  • 公平性:C2BM 是唯一能通过 do-intervention 将 CaCE(Attractive→Should be Hired)降至 0.0% 的模型,其他 CBM/CEM/SCBM 均无法完全阻断偏见路径

亮点与洞察

  • 因果结构 = 更好的干预:因果图使概念干预效果沿 DAG 传播,而非被二部图结构限制在单个概念上。1~2 个高层概念的干预即可显著提升所有下游节点精度
  • 超网络 + 线性方程 = 可解释的非线性:结构方程是线性的(可解释),但参数由超网络动态预测(表达能力强),实现了可解释性与表达能力的精巧平衡
  • LLM+RAG 自动构建因果图:将因果发现中"定向无向边"的难题转化为 LLM 背景知识查询问题,大幅降低对人类专家的依赖

局限性 / 可改进方向

  • 因果图质量依赖 LLM 背景知识的准确性,知识库偏差会传导到模型
  • 因果结构学习的可扩展性受限于现有因果发现算法(GES 在高维场景下计算开销大)
  • 编码器的分布外泛化能力不保证,可能影响 SCM 的 OOD 性能
  • 未考虑隐混杂因素(hidden confounders),假设因果图为 DAG
  • 实验数据集规模相对较小,大规模视觉任务(如 ImageNet)上的表现未验证

相关工作与启发

  • vs CBM/CEM: CBM/CEM 采用扁平二部图结构,所有概念独立且直连任务;C2BM 引入因果图结构,概念间有层次化因果关系
  • vs SCBM: SCBM 放松了概念独立假设但仅捕获关联关系;C2BM 捕获因果关系,干预效果更精准
  • vs DiConStruct: DiConStruct 是事后方法,可能与 DNN 输出不对齐且仅用观测数据;C2BM 是设计时方法,结合背景知识
  • 与可解释 AI 和公平性约束研究高度相关,因果瓶颈可直接用于实现算法公平

评分

  • 新颖性: ⭐⭐⭐⭐ 将因果推理融入概念瓶颈是自然但重要的新方向
  • 实验充分度: ⭐⭐⭐⭐ 涵盖精度/因果可靠性/干预/去偏/公平性五个维度,消融充分
  • 写作质量: ⭐⭐⭐⭐ 形式化严谨,流水线描述清晰
  • 价值: ⭐⭐⭐⭐⭐ 因果可靠的可解释模型对安全关键领域(医疗、法律)有重要实际意义