AP-OOD: Attention Pooling for Out-of-Distribution Detection¶
会议: ICLR 2026
arXiv: 2602.06031
代码: https://github.com/ml-jku/ap-ood
领域: OOD检测 / NLP安全
关键词: 分布外检测, 注意力池化, Mahalanobis距离, token级信息, 语言模型
一句话总结¶
提出AP-OOD,将Mahalanobis距离的均值池化替换为可学习的注意力池化,解决了均值池化丢失token级异常信息的问题,在文本OOD检测中将XSUM摘要的FPR95从27.84%降至4.67%,支持无监督到半监督的平滑过渡。
研究背景与动机¶
- 领域现状:语言模型部署时可能遇到OOD输入(如训练摘要BBC文章但收到CNN文章),会导致幻觉等不可靠输出。基于token嵌入的马氏距离是主流检测方法。
- 现有痛点:现有方法(如Ren et al. 2023)对序列的token嵌入做均值池化再计算马氏距离——但均值操作会隐藏异常信息。当ID和OOD序列的均值相近(但token分布不同)时,完全无法检测。Figure 1展示了这种失败模式。
- 核心矛盾:需要将变长表示(token序列)压缩为标量OOD分数,但简单聚合会丢失区分ID/OOD的关键token级模式。
- 本文要解决什么:设计一种超越均值池化的聚合方式,保留token级信息用于OOD检测。
- 切入角度:将马氏距离分解为方向分量→将每个方向的投影从均值改为注意力加权→让模型学习关注哪些token对OOD检测最有信息量。
- 核心idea一句话:用可学习的注意力池化替代均值池化来计算马氏距离,使OOD检测能利用token级信息。
方法详解¶
整体框架¶
AP-OOD从预训练编码器-解码器模型获取token嵌入 \(Z \in \mathbb{R}^{D \times S}\),用M个可学习的查询向量 \(w_j\) 通过注意力池化提取序列级表示,计算与语料原型的注意力池化距离作为OOD分数。
关键设计¶
- 注意力池化马氏距离
- 做什么:用注意力机制替代均值来聚合token嵌入
- 核心思路:标准马氏距离的方向分解 \(d^2 = \sum_j (w_j^T \bar{z} - w_j^T \mu)^2\),将均值池化 \(\bar{z} = \frac{1}{S}\sum_s z_s\) 替换为注意力池化 \(\bar{z} = Z \cdot \text{softmax}(\beta Z^T w)\)。这样每个方向\(w_j\)不仅定义了测量方向,还定义了关注哪些token
-
设计动机:Figure 1-2对比——均值池化时ID/OOD无法区分,但注意力池化能"看到"token级别的异常模式
-
多查询多头扩展
- 做什么:用矩阵\(W_j \in \mathbb{R}^{D \times T}\)代替向量\(w_j\),每个head有T个查询
- 核心思路:\(\bar{Z} = Z \cdot \text{softmax}(\beta Z^T W)\)(softmax在S×T的全矩阵上归一化),距离用Frobenius内积 \(\text{Tr}(W_j^T \bar{Z})\) 计算
-
设计动机:多查询可以捕获更丰富的token模式,提升检测能力
-
半监督扩展
- 做什么:当有少量OOD样本可用时平滑融入训练
- 核心思路:在损失函数中加入OOD样本的距离最大化项,通过系数控制无监督到有监督的过渡
- 设计动机:实际部署中可能有所了解的OOD类型,方法应能利用这些信息
训练策略¶
- 仅训练查询向量\(W_j\)(极少参数),编码器冻结
- 损失函数 \(\mathcal{L} = \frac{1}{N}\sum_i d^2(Z_i, \tilde{Z}) - \sum_j \log(\|W_j\|^2)\)
- mini-batch注意力池化降低内存消耗
- β=0时退化为标准马氏距离(理论保证)
实验关键数据¶
主实验¶
| 任务 | 指标 | 之前SOTA | AP-OOD |
|---|---|---|---|
| XSUM摘要 | FPR95↓ | 27.84% | 4.67% |
| WMT15 En→Fr | FPR95↓ | 77.08% | 70.37% |
FPR95改善幅度巨大(XSUM降低23+个百分点)。
消融¶
- β=0(退化为马氏距离)效果显著变差→注意力池化的贡献实质性的
- 增加head数M和查询数T均带来提升
- 半监督设置下少量OOD样本可进一步提升性能
关键发现¶
- 均值池化在摘要和翻译任务中都是主要瓶颈——OOD和ID的均值嵌入高度重叠
- 注意力池化学到的\(w\)趋向于关注序列中的"异常"token——这些token携带了最多的OOD信号
- 从无监督到半监督的过渡是平滑的——方法可以灵活适应可用OOD数据的多少
亮点与洞察¶
- "均值隐藏异常"这个问题的形式化(Figure 1-2)极其直观——一图胜千言
- 将注意力池化与马氏距离统一的理论框架很优雅——β=0退化为经典方法,β>0泛化到token级
- 参数量极少(仅学习查询向量),计算代价可忽略——真正的post-hoc方法
局限性 / 可改进方向¶
- 仅在摘要和翻译两个任务上验证——更多NLP任务(QA、对话等)待测
- 依赖预训练encoder-decoder架构——对decoder-only LLM的适用性需探索
- 仅处理输入端OOD——对生成端的分布偏移问题未涉及
- 注意力温度β的选择可能需要调参
相关工作与启发¶
- vs Ren et al. (2023): 均值池化+马氏距离的基线;AP-OOD用注意力池化直接替代均值
- vs 分类器OOD方法(MSP/Energy等): 这些假设分类头存在,AP-OOD适用于生成模型
- vs Mahalanobis距离(Lee et al. 2018): 经典图像OOD方法;AP-OOD将其扩展到序列数据
评分¶
- 新颖性: ⭐⭐⭐⭐ 注意力池化+马氏距离的结合自然但此前未被探索
- 实验充分度: ⭐⭐⭐ XSUM结果极强但实验范围较窄(2个任务)
- 写作质量: ⭐⭐⭐⭐⭐ Figure 1-2的说明性例子极好,理论推导清晰
- 价值: ⭐⭐⭐⭐ 为NLP-OOD检测提供了简单有效的改进思路