Causally Reliable Concept Bottleneck Models¶
会议: NeurIPS 2025
arXiv: 2503.04363
代码: 提交时附带补充材料
领域: ai_safety
关键词: concept bottleneck model, causal reasoning, structural causal model, interpretability, fairness
一句话总结¶
提出 C2BM(Causally reliable Concept Bottleneck Models),将概念瓶颈(concept bottleneck)按照因果图结构化组织,通过结合观测数据与背景知识自动学习因果关系,在保持分类精度的同时显著提升因果可靠性、干预响应和公平性。
研究背景与动机¶
- 领域现状:Concept Bottleneck Models(CBMs)是可解释深度学习的代表范式,通过强制模型经由人类可理解的概念层进行推理来实现透明性。CBM 将预测分为两步:编码器将输入映射到概念,解码器基于概念预测任务标签。
- 现有痛点:现有 CBMs 采用二部图结构(bipartite),假设所有概念独立且直接影响输出。这一假设过于简化:(1) 忽略概念间的因果依赖关系,导致解释可能误导(如将肺癌归因于"咳嗽"和"吸烟",暗示减少咳嗽可降低癌症风险);(2) 概念独立假设阻止干预效果在相关概念间传播;(3) 仅学习统计相关性而非因果关系,容易受虚假相关影响。
- 核心矛盾:CBMs 本质上是关联模型(associative models),其决策过程反映的是数据中的统计相关性而非真实世界的因果机制。这导致它们无法支持因果推理、限制分布外泛化、阻碍公平性约束的实施。
- 本文要解决什么? (a) 如何将概念瓶颈按因果机制结构化?(b) 如何在无人类专家标注的情况下自动发现概念和因果图?(c) 如何在不牺牲精度的前提下提升因果可靠性?
- 切入角度:利用结构因果模型(SCM)将概念瓶颈组织为因果图,并通过 LLM + RAG 自动从非结构化背景知识中发现因果关系。
- 核心idea一句话:在概念瓶颈中嵌入因果图结构,用超网络自适应参数化结构方程,使模型推理过程与真实因果机制对齐。
方法详解¶
整体框架¶
C2BM 由三个核心模块组成:(1) 概念发现与标注——利用 LLM 发现相关概念并用 CLIP 标注数据集;(2) 因果图发现——结合因果发现算法(GES)与 LLM+RAG 查询来确定概念间的因果关系;(3) C2BM 模型本身——包含神经编码器 \(\mathbf{g}(\cdot)\) 和参数化 SCM \(\mathcal{M}_{\boldsymbol{\Theta}}\)。
输入为原始数据 \(X\),编码器预测外生变量(高维嵌入)\(\mathcal{U} = \{U_i\}_{i=1}^C\),然后信息沿因果图从源节点流向汇节点。每个内生变量 \(V_i\) 的值由其因果父节点通过结构方程预测。
关键设计¶
- 因果瓶颈(Causal Bottleneck):
- 做什么:替代 CBM 的扁平二部图结构,用 DAG 组织概念间的因果关系
- 核心思路:将 C2BM 定义为 \(\langle \mathbf{g}, \mathcal{M}_{\boldsymbol{\Theta}} \rangle\),其中 SCM 为 \(\langle \mathcal{V}, \mathcal{U}, \mathcal{F}_{\boldsymbol{\Theta}}, P(\mathcal{U}|X) \rangle\)。信息先从编码器获得外生变量,再沿 DAG 拓扑序逐层计算每个概念
-
设计动机:因果结构使干预可以沿图传播,阻断虚假相关路径,支持公平性约束
-
自适应结构方程(Adaptive Structural Equations):
- 做什么:为每个概念学习可解释的因果机制
- 核心思路:结构方程取线性形式 \(V_i = \sum_{V_j \in \text{PA}_i} [\boldsymbol{\theta}_{f_i}]_j V_j\),但参数 \(\boldsymbol{\theta}_{f_i}\) 不是固定的,而是由超网络 \(\mathbf{r}_i(U_i)\) 对每个输入自适应预测,实现局部线性但全局非线性的表达能力
-
设计动机:线性结构方程保证可解释性(每个概念是父节点的加权线性组合),超网络保证表达能力(论文证明 C2BM 是通用近似器)
-
自动因果图构建流水线:
- 做什么:无需人类专家即可自动发现概念和因果图
- 核心思路:(a) 用 LLM 查询相关概念并用 CLIP 标注;(b) 用 GES 算法从观测数据学习 CPDAG(含无向边的部分图);(c) 用 LLM+RAG 查询背景知识来定向无向边并删除虚假边,每个查询重复 10 次取多数投票
- 设计动机:纯数据驱动的因果发现无法唯一确定 DAG;结合背景知识(如科学文献)可有效缩小候选图空间
损失函数 / 训练策略¶
训练目标为最大化经验对数似然:
其中联合条件分布按因果图的马尔可夫条件分解为独立因子的乘积。编码器和超网络的参数端到端联合学习。
实验关键数据¶
主实验¶
| 数据集 | OpaqNN | CBM+lin | CEM | SCBM | C2BM |
|---|---|---|---|---|---|
| Asia | 71.0 | 71.2 | 71.1 | 70.7 | 71.4 |
| Sachs | 65.83 | 65.44 | 65.93 | 66.30 | 65.33 |
| Hailfinder | 72.0 | 72.2 | 71.5 | 73.4 | 74.1 |
| cMNIST | 91.24 | 93.92 | 93.72 | 94.02 | 94.18 |
| CelebA | 74.97 | 71.07 | 74.72 | 72.15 | 74.73 |
| Pneumoth. | 80.0 | 76.6 | 80.1 | 78.4 | 80.5 |
C2BM 在大多数数据集上与最强基线(OpaqNN、CEM)持平或更优,同时是唯一具备因果可靠性的模型。
消融实验 — 因果图质量¶
| 度量 | 扁平 CBM | 仅因果发现(CD) | CD + LLM |
|---|---|---|---|
| Hamming (Asia) | 6.5 | 0.7 | 0.3 |
| 错误边数 (Sachs, 共17边) | 23 | 17 | 7 |
| 错误边数 (Alarm, 共46边) | 78 | 13 | 10 |
LLM 背景知识有效减少了因果图中的错误边,Sachs 上额外正确识别 10 条边。
关键发现¶
- 干预实验:C2BM 在逐层干预概念时,下游准确率提升最快且幅度最大,因为因果结构使干预效果自然传播到子节点
- 去偏(Debiasing):在 biased cMNIST 上,C2BM 通过因果图发现正确删除了 Color→Parity 的虚假边,干预 number 概念后准确率达 ~90%
- 公平性:C2BM 是唯一能通过 do-intervention 将 CaCE(Attractive→Should be Hired)降至 0.0% 的模型,其他 CBM/CEM/SCBM 均无法完全阻断偏见路径
亮点与洞察¶
- 因果结构 = 更好的干预:因果图使概念干预效果沿 DAG 传播,而非被二部图结构限制在单个概念上。1~2 个高层概念的干预即可显著提升所有下游节点精度
- 超网络 + 线性方程 = 可解释的非线性:结构方程是线性的(可解释),但参数由超网络动态预测(表达能力强),实现了可解释性与表达能力的精巧平衡
- LLM+RAG 自动构建因果图:将因果发现中"定向无向边"的难题转化为 LLM 背景知识查询问题,大幅降低对人类专家的依赖
局限性 / 可改进方向¶
- 因果图质量依赖 LLM 背景知识的准确性,知识库偏差会传导到模型
- 因果结构学习的可扩展性受限于现有因果发现算法(GES 在高维场景下计算开销大)
- 编码器的分布外泛化能力不保证,可能影响 SCM 的 OOD 性能
- 未考虑隐混杂因素(hidden confounders),假设因果图为 DAG
- 实验数据集规模相对较小,大规模视觉任务(如 ImageNet)上的表现未验证
相关工作与启发¶
- vs CBM/CEM: CBM/CEM 采用扁平二部图结构,所有概念独立且直连任务;C2BM 引入因果图结构,概念间有层次化因果关系
- vs SCBM: SCBM 放松了概念独立假设但仅捕获关联关系;C2BM 捕获因果关系,干预效果更精准
- vs DiConStruct: DiConStruct 是事后方法,可能与 DNN 输出不对齐且仅用观测数据;C2BM 是设计时方法,结合背景知识
- 与可解释 AI 和公平性约束研究高度相关,因果瓶颈可直接用于实现算法公平
评分¶
- 新颖性: ⭐⭐⭐⭐ 将因果推理融入概念瓶颈是自然但重要的新方向
- 实验充分度: ⭐⭐⭐⭐ 涵盖精度/因果可靠性/干预/去偏/公平性五个维度,消融充分
- 写作质量: ⭐⭐⭐⭐ 形式化严谨,流水线描述清晰
- 价值: ⭐⭐⭐⭐⭐ 因果可靠的可解释模型对安全关键领域(医疗、法律)有重要实际意义