跳转至

Bridging Explainability and Embeddings: BEE Aware of Spuriousness

会议: ICLR 2026
arXiv: 2410.18970
代码: 公开可用
领域: AI安全 / 可解释性 / 鲁棒性
关键词: 虚假相关性检测, 权重空间分析, 嵌入几何, 线性探测, 基础模型

一句话总结

提出BEE框架,通过分析微调如何扰动预训练表征的权重空间几何结构,直接从分类器学到的权重中识别和命名虚假相关性(spurious correlations),无需反例样本即可发现隐藏的数据偏差,在ImageNet-1k上发现可导致准确率下降高达95%的虚假关联。

研究背景与动机

  1. 领域现状: 深度神经网络尤其是微调后的基础模型被广泛部署在医疗、金融等关键领域。虚假相关性(SC)会导致模型基于与任务无关的特征做决策,产生严重后果。检测SC是确保模型可靠性的关键。
  2. 现有痛点: 现有方法分两大类——数据驱动方法(如SpLiCE、Lg)分析数据集统计特征标记与类别关联的概念,但无法判断模型是否真的学到了这些关联;错误驱动方法(如B2T)从验证集错误推断SC,但依赖验证集中存在反例来暴露模型捷径。当反例缺失时(这在真实场景中很常见),这些方法都失效。
  3. 核心矛盾: 数据方法不看模型、错误方法需要反例,而实际中很多有害的SC恰恰是因为数据集中没有反例才存在的。现有的可解释方法(如CBM)需要预定义概念集并牺牲表达能力。根本问题是:如何在不依赖反例的情况下发现模型实际学到的虚假关联?
  4. 本文要解决什么: (1) 无需反例即可识别模型学到的SC;(2) 不仅检测还要命名具体是什么概念导致了SC;(3) 方法需适用于视觉和文本多种模态、多种基础模型。
  5. 切入角度: 关键观察——微调过程中,线性分类器的权重会从初始的类别名称嵌入(零样本权重)偏移,而偏移方向编码了模型学到的特征,包括虚假关联。由于权重和概念嵌入共享同一嵌入空间,可以通过几何关系直接分析哪些与类别无关的概念与权重高度相似。
  6. 核心idea一句话: 利用嵌入空间中分类权重相对于零样本初始化的漂移方向,识别与类别无关但与学到权重高度相似的概念作为虚假相关性。

方法详解

整体框架

BEE是一个权重空间诊断框架。输入为训练数据集、基础模型和概念集合,输出为每个类别学到的虚假相关概念列表。流程分两大步:(1) 在基础模型嵌入之上训练线性探测层,观察权重漂移;(2) 在嵌入空间中排名与漂移后权重相似但与类别无关的概念,自动筛选SC。

关键设计

  1. 权重初始化与漂移观察:
  2. 做什么:用类别名称的文本嵌入初始化线性层权重 \(w_k^0 = M(\text{class\_name}_k)\),训练后观察权重漂移 \(w_k^*\)
  3. 核心思路:零样本权重 \(w_k^0\) 编码了类别的"纯语义",微调后权重 \(w_k^*\) 则混合了真实特征和虚假特征。权重在嵌入空间中的漂移方向揭示了模型学到了什么。线性探测作为透明的诊断透镜,使分析变得可解释。
  4. 设计动机:线性层权重和概念嵌入处于同一空间,这使概念排名成为可能。用线性探测而非全参数微调,既确保了透明性,又通过实验证明其发现的SC在全参数微调模型中也持续存在。

  5. 概念提取与过滤(Step 2a):

  6. 做什么:从数据集中提取概念并过滤掉与类别相关的概念,只保留"类别中立"概念
  7. 核心思路:先用GIT-Large对图像生成描述(文本数据直接使用),再用YAKE关键词提取器取top-256 n-gram作为候选概念 \(C_{all}\)。然后用Llama-3.1-8B-Instruct过滤类别实例,再用WordNet的上下位词关系做二次过滤。
  8. 设计动机:只有与类别定义无关的概念才是SC的候选。例如"森林背景"与"陆地鸟"高度相关,但不应作为分类依据。两级过滤(LLM + WordNet)确保过滤的全面性。

  9. 概念排名(Step 2b):

  10. 做什么:对每个类别,根据概念嵌入与学到权重的相似度排名
  11. 核心思路:正相关SC评分 \(s_{k,i}^+ = w_k^{*\top} M(c_i) - \min_{k'} w_{k'}^{*\top} M(c_i)\)。直觉是寻找与某一类高度相似但与其他类不相似的概念。负相关SC使用不相似度 \(-w_k^{*\top} M(c_i)\)
  12. 设计动机:只看与单个类别的相似度可能产生假阳性(通用概念与所有类都相似),减去最小值消除了这种基线效应,确保只保留对特定类别有区分性的概念。

  13. 动态阈值(Step 2c):

  14. 做什么:自动确定每个类别保留多少个SC
  15. 核心思路:对排序后的分数做均值滤波(窗口大小 \(r\)),找到平滑曲线偏离首尾连线最大的拐点 \(m_k = \lfloor r/2 \rfloor + \arg\max_i (\bar{s}_{k,1} - i \frac{\bar{s}_{k,1} - \bar{s}_{k,p}}{p-1} - \bar{s}_{k,i})\)。这相当于在排序分数曲线上找最大偏差点。
  16. 设计动机:不同类别可能有不同数量的SC,硬性设置top-k不合理。动态阈值自适应地为每个类别选择合适数量的SC,无需人工调参。

  17. SC正则化(下游应用):

  18. 做什么:利用发现的SC构建正则化项来提升模型鲁棒性
  19. 核心思路:约束分类权重与SC概念等距,正则化损失 \(\mathcal{L}_{reg}(b) = \frac{\tau^2}{N} \sum_{k=1}^N [w_k^\top M(b) - sg(\frac{1}{N}\sum_j w_j^\top M(b))]^2\),总损失 \(\mathcal{L} = \mathcal{L}_{ERM} + \alpha \frac{1}{|\mathcal{B}|} \sum_{b \in \mathcal{B}} \mathcal{L}_{reg}(b)\)
  20. 设计动机:在完全虚假相关的极端设置中(训练集无反例),GroupDRO失效,而SC正则化通过显式约束降低对SC的依赖。

训练策略

  • 使用AdamW优化器(\(lr=1e\)-4, \(wd=1e\)-5),batch size 1024
  • 交叉熵损失+类别平衡权重,使用CLIP温度 \(\tau=100\) 缩放logits
  • 每次更新后权重归一化,基于验证集类别平衡准确率做早停

实验关键数据

主实验:SC增强的零样本提示

方法 Waterbirds Worst Waterbirds Avg CelebA Worst CelebA Avg CivilComments Worst
Basic zero-shot 35.2 84.2 72.8 87.7 33.1
w/ B2T 48.1 86.1 72.8 88.0 -
w/ SpLiCE 48.1 82.5 67.2 90.2 -
w/ Lg 46.1 85.9 50.6 87.2 -
w/ BEE 50.3 86.3 73.1 85.7 53.2

BEE在Waterbirds和CivilComments上的worst-group准确率显著优于所有竞争方法。

ImageNet-1k SC影响量化

正确类别 虚假概念 诱导类别 正确类识别率变化 诱导类预测率
Peafowl firemen Fire truck 100% → 5.3% (-94.7%) 0% → 93.4%
Mexican Hairless Dog reading newspaper Crossword 47.5% → 0.9% (-46.6%) 0% → 36.6%
Bernese Mountain Dog shrimp American lobster 99.8% → 10.6% (-89.2%) 0% → 37.2%

完全虚假设置下的正则化实验

方法 Waterbirds Worst CelebA Worst CivilComments Worst
ERM 43.2±5.7 9.6±1.0 18.6±0.3
GroupDRO 38.9±5.4 8.1±0.3 18.7±0.4
Reg w/ random SCs 46.6±2.7 9.4±0.0 19.1±1.6
Reg w/ Lg's SCs 50.4±0.1 8.3±0.0 -
Reg w/ BEE's SCs 57.9±0.3 10.4±0.5 31.3±0.7

在无反例的极端设置中,GroupDRO甚至不如ERM,但BEE的SC正则化持续改善worst-group表现。

关键发现

  • SC跨模型迁移:BEE在CLIP上发现的SC在AlexNet、ResNet50、ViT-L/16等多种架构上都导致显著性能下降,表明SC是数据集的属性而非模型的属性
  • MIMIC-CXR医疗笔记中的危险捷径:BEE发现"chest examination"和"chest radiograph"是"无病理发现"类的SC,添加这类词会使分类器偏向"无发现",在医疗场景中可能导致漏诊
  • 无需反例的SC发现:在移除所有少数组样本的完全虚假设置中,BEE仍能有效识别SC,而基于错误分析的方法完全失效

亮点与洞察

  • 权重空间分析是全新的SC检测范式:不看数据分布也不看预测错误,直接从分类器权重的几何漂移推断学到了什么。这个思路利用了嵌入空间的对齐性质,非常优雅,且可以发现传统方法看不到的SC。
  • 线性探测作为诊断透镜:选择最简单的分类器避免了复杂模型的不可解释性,同时实验证明发现的SC在全参数微调模型中同样存在并可迁移,说明线性探测的发现具有普适性。
  • 动态阈值的拐点检测方法:自动为每个类确定SC数量,无需人工调参,使方法在ImageNet的1000个类上自动化运行,具有很好的可扩展性。
  • MIMIC-CXR发现的实际安全意义:在医疗文本中发现的SC直接指向了可能导致漏诊的模型缺陷,展示了方法在高风险领域的实际价值。

局限性 / 可改进方向

  • 依赖线性探测假设——如果SC以非线性方式编码,可能无法检测到
  • 概念提取依赖YAKE+GIT-Large,概念覆盖范围受限于captioning模型的描述能力
  • 当前仅针对分类任务,能否扩展到检测/分割/生成等更复杂任务的SC检测?
  • SC正则化需要已知SC集合,能否将检测和缓解做成闭环迭代?
  • 在CelebA-blonde hair上BEE和B2T都未检测到SC,可能存在某些类型的短路特征检测盲区

相关工作与启发

  • vs B2T: B2T从验证错误推断SC,需要反例存在。BEE从权重漂移推断,不需要反例,发现的概念范围更广。在Waterbirds上BEE(worst 50.3%)优于B2T(48.1%)。
  • vs SpLiCE/Lg: 这些是数据驱动方法,分析数据集中的概念分布,但无法确认模型是否真的学到了这些关联。BEE直接分析模型权重,确保发现的是模型实际学到的SC。
  • vs CBM: 概念瓶颈模型需要预定义概念并修改模型架构,牺牲表达能力。BEE对模型无任何修改,分析的是原始SOTA模型。

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 从权重空间几何分析SC是全新视角,理论动机清晰,方法设计优雅
  • 实验充分度: ⭐⭐⭐⭐⭐ 覆盖视觉+文本、5种嵌入模型、5个数据集、定量+定性+生成验证,非常全面
  • 写作质量: ⭐⭐⭐⭐ 结构清晰,图示直观,但部分数学符号较密需要反复阅读
  • 价值: ⭐⭐⭐⭐⭐ 在AI安全和可信AI领域有重要意义,MIMIC-CXR的发现直接关系到医疗安全