跳转至

AP-OOD: Attention Pooling for Out-of-Distribution Detection

会议: ICLR 2026
arXiv: 2602.06031
代码: https://github.com/ml-jku/ap-ood
领域: OOD检测 / NLP安全
关键词: 分布外检测, 注意力池化, Mahalanobis距离, token级信息, 语言模型

一句话总结

提出AP-OOD,将Mahalanobis距离的均值池化替换为可学习的注意力池化,解决了均值池化丢失token级异常信息的问题,在文本OOD检测中将XSUM摘要的FPR95从27.84%降至4.67%,支持无监督到半监督的平滑过渡。

研究背景与动机

  1. 领域现状:语言模型部署时可能遇到OOD输入(如训练摘要BBC文章但收到CNN文章),会导致幻觉等不可靠输出。基于token嵌入的马氏距离是主流检测方法。
  2. 现有痛点:现有方法(如Ren et al. 2023)对序列的token嵌入做均值池化再计算马氏距离——但均值操作会隐藏异常信息。当ID和OOD序列的均值相近(但token分布不同)时,完全无法检测。Figure 1展示了这种失败模式。
  3. 核心矛盾:需要将变长表示(token序列)压缩为标量OOD分数,但简单聚合会丢失区分ID/OOD的关键token级模式。
  4. 本文要解决什么:设计一种超越均值池化的聚合方式,保留token级信息用于OOD检测。
  5. 切入角度:将马氏距离分解为方向分量→将每个方向的投影从均值改为注意力加权→让模型学习关注哪些token对OOD检测最有信息量。
  6. 核心idea一句话:用可学习的注意力池化替代均值池化来计算马氏距离,使OOD检测能利用token级信息。

方法详解

整体框架

AP-OOD从预训练编码器-解码器模型获取token嵌入 \(Z \in \mathbb{R}^{D \times S}\),用M个可学习的查询向量 \(w_j\) 通过注意力池化提取序列级表示,计算与语料原型的注意力池化距离作为OOD分数。

关键设计

  1. 注意力池化马氏距离
  2. 做什么:用注意力机制替代均值来聚合token嵌入
  3. 核心思路:标准马氏距离的方向分解 \(d^2 = \sum_j (w_j^T \bar{z} - w_j^T \mu)^2\),将均值池化 \(\bar{z} = \frac{1}{S}\sum_s z_s\) 替换为注意力池化 \(\bar{z} = Z \cdot \text{softmax}(\beta Z^T w)\)。这样每个方向\(w_j\)不仅定义了测量方向,还定义了关注哪些token
  4. 设计动机:Figure 1-2对比——均值池化时ID/OOD无法区分,但注意力池化能"看到"token级别的异常模式

  5. 多查询多头扩展

  6. 做什么:用矩阵\(W_j \in \mathbb{R}^{D \times T}\)代替向量\(w_j\),每个head有T个查询
  7. 核心思路:\(\bar{Z} = Z \cdot \text{softmax}(\beta Z^T W)\)(softmax在S×T的全矩阵上归一化),距离用Frobenius内积 \(\text{Tr}(W_j^T \bar{Z})\) 计算
  8. 设计动机:多查询可以捕获更丰富的token模式,提升检测能力

  9. 半监督扩展

  10. 做什么:当有少量OOD样本可用时平滑融入训练
  11. 核心思路:在损失函数中加入OOD样本的距离最大化项,通过系数控制无监督到有监督的过渡
  12. 设计动机:实际部署中可能有所了解的OOD类型,方法应能利用这些信息

训练策略

  • 仅训练查询向量\(W_j\)(极少参数),编码器冻结
  • 损失函数 \(\mathcal{L} = \frac{1}{N}\sum_i d^2(Z_i, \tilde{Z}) - \sum_j \log(\|W_j\|^2)\)
  • mini-batch注意力池化降低内存消耗
  • β=0时退化为标准马氏距离(理论保证)

实验关键数据

主实验

任务 指标 之前SOTA AP-OOD
XSUM摘要 FPR95↓ 27.84% 4.67%
WMT15 En→Fr FPR95↓ 77.08% 70.37%

FPR95改善幅度巨大(XSUM降低23+个百分点)。

消融

  • β=0(退化为马氏距离)效果显著变差→注意力池化的贡献实质性的
  • 增加head数M和查询数T均带来提升
  • 半监督设置下少量OOD样本可进一步提升性能

关键发现

  • 均值池化在摘要和翻译任务中都是主要瓶颈——OOD和ID的均值嵌入高度重叠
  • 注意力池化学到的\(w\)趋向于关注序列中的"异常"token——这些token携带了最多的OOD信号
  • 从无监督到半监督的过渡是平滑的——方法可以灵活适应可用OOD数据的多少

亮点与洞察

  • "均值隐藏异常"这个问题的形式化(Figure 1-2)极其直观——一图胜千言
  • 将注意力池化与马氏距离统一的理论框架很优雅——β=0退化为经典方法,β>0泛化到token级
  • 参数量极少(仅学习查询向量),计算代价可忽略——真正的post-hoc方法

局限性 / 可改进方向

  • 仅在摘要和翻译两个任务上验证——更多NLP任务(QA、对话等)待测
  • 依赖预训练encoder-decoder架构——对decoder-only LLM的适用性需探索
  • 仅处理输入端OOD——对生成端的分布偏移问题未涉及
  • 注意力温度β的选择可能需要调参

相关工作与启发

  • vs Ren et al. (2023): 均值池化+马氏距离的基线;AP-OOD用注意力池化直接替代均值
  • vs 分类器OOD方法(MSP/Energy等): 这些假设分类头存在,AP-OOD适用于生成模型
  • vs Mahalanobis距离(Lee et al. 2018): 经典图像OOD方法;AP-OOD将其扩展到序列数据

评分

  • 新颖性: ⭐⭐⭐⭐ 注意力池化+马氏距离的结合自然但此前未被探索
  • 实验充分度: ⭐⭐⭐ XSUM结果极强但实验范围较窄(2个任务)
  • 写作质量: ⭐⭐⭐⭐⭐ Figure 1-2的说明性例子极好,理论推导清晰
  • 价值: ⭐⭐⭐⭐ 为NLP-OOD检测提供了简单有效的改进思路