跳转至

FlexAttention: 面向高效高分辨率视觉语言模型的灵活注意力机制

会议: ECCV 2024
arXiv: 2407.20228
代码: https://vis-www.cs.umass.edu/flexattention
领域: 多模态VLM
关键词: 高分辨率VLM, 注意力机制, 动态token选择, 层次化自注意力, 计算效率

一句话总结

提出 FlexAttention,通过基于注意力图的高分辨率token动态选择和层次化自注意力融合机制,在保持甚至超越现有高分辨率VLM性能的同时,将计算成本降低近40%。

研究背景与动机

  1. 领域现状:主流VLM(如LLaVA-1.5、InstructBLIP)通常将输入图像限制在较低分辨率(224×224或336×336),因为它们依赖CLIP等固定分辨率的视觉编码器。这在需要识别小文字、小物体等细节的场景下表现很差。

  2. 现有痛点:已有的高分辨率VLM(如LLaVA-1.5-HD、CogAgent)虽然能接收高分辨率图像,但它们将所有高分辨率token全部送入注意力模块计算,导致计算代价随token数量平方级增长。例如分辨率从336提升到1008,token数量增加9倍,注意力计算量增加约81倍。

  3. 核心矛盾:高分辨率带来的视觉细节信息 vs. 自注意力\(O(N^2)\)复杂度导致的计算开销之间存在根本性的 trade-off。现有方法要么牺牲分辨率,要么承受巨大的计算代价。

  4. 人类视觉启发:人类视觉处理并非一次性记住所有像素细节,而是先维持一个粗略的整体表征,在受到外部刺激时才对感兴趣的区域进行更精细的关注(选择性注意力机制)。

  5. 切入角度:作者观察到,在VLM生成过程中,模型的注意力图天然地揭示了当前哪些图像区域是重要的。可以利用这个"免费"的信号来动态选择只需要高分辨率细节的区域,而非暴力处理全部高分辨率token。

  6. 核心idea一句话:利用注意力图动态选择约10%的关键高分辨率token,通过层次化自注意力将其融入低分辨率表征,实现计算量亚线性增长下的高分辨率感知。

方法详解

整体框架

FlexAttention 可以即插即用地替换现有VLM的自注意力模块。整体pipeline如下:

  • 输入:一张高分辨率图像 \(I_{HR}\)(如1008×1008)和文本输入 \(T\)
  • 双路编码:将高分辨率图像同时编码为高分辨率token \(f_{HR}\) 和低分辨率token \(f_{LR}\)(下采样后编码)
  • 前半段(\(N_{SA}\)层):只使用低分辨率image token和text token进行标准自注意力计算,粗略理解图像全貌
  • 后半段(\(N_{FA}\)层):启用FlexAttention,每层动态选择少量高分辨率token,通过层次化自注意力融合高分辨率细节
  • 输出:最后一层hidden state经过projector生成文本答案

这种设计的关键在于:前几层用低成本建立全局理解,后续层才引入高分辨率细节,且每层只引入一小部分(~10%)的高分辨率token。

关键设计

1. 高分辨率特征选择模块(High-Resolution Feature Selection Module)

  • 做什么:根据当前层的注意力图,从全部高分辨率token中选择与当前生成最相关的一小部分
  • 核心思路
  • 取注意力图 \(Map\) 最后一列的前 \(N_i\) 个值(即所有低分辨率image token对最后一个text token的注意力权重),这些值反映了模型在生成下一个token时对各图像区域的关注程度
  • 将这个1D注意力向量 reshape 为2D空间注意力图
  • 对该注意力图进行归一化和二值化,得到一个binary mask
  • 将该mask上采样到高分辨率feature map的空间尺寸,形成高分辨率选择mask
  • 应用mask选出被"激活"的高分辨率token \(f_{SHR}\),约占总高分辨率token的10%
  • 设计动机:自注意力图本身就是一个"免费"的区域重要性信号,利用它来选择token避免了额外的选择网络开销。这种选择是逐层动态变化的,不同层可以关注不同的图像区域

2. 层次化自注意力模块(Hierarchical Self-Attention Module)

  • 做什么:将选出的高分辨率token信息融合到原始hidden state(低分辨率token + text token)中
  • 核心思路
  • Query \(Q\) 只来自原始hidden state \(H\)\(Q = HW_Q\)
  • Key和Value通过拼接原始hidden state和选中的高分辨率token构建:
    • \(K_{all} = \text{Concat}(HW_K, f_{SHR}W_K')\)
    • \(V_{all} = \text{Concat}(HW_V, f_{SHR}W_V')\)
  • 注意高分辨率token使用独立的投影矩阵 \(W_K'\), \(W_V'\),而非共享Query投影
  • 输出注意力图 \(Map'\) 尺寸为 \(N \times (N+M)\),截取前 \(N \times N\) 部分传递给下一层做token选择
  • 设计动机:这种设计保证hidden state的维度不变(仍为 \(N \times D\)),高分辨率token只参与K/V的计算而不引入新的Query,计算复杂度从 \(O((M+N)^2D)\) 降为 \(O((M+N)ND)\),实现了对高分辨率token的线性而非平方级计算增长

3. 逐层迭代选择机制

  • 做什么:高分辨率选择和层次化注意力在每个FlexAttention层中交替执行
  • 核心思路:第 \(i\) 层的注意力图 \(Map^i\) 用于选择第 \(i+1\) 层的高分辨率token \(f_{SHR}^i\),形成一个迭代refinement过程
  • 设计动机:随着网络深度增加,模型对图像的理解逐步深入,每层关注的区域可能不同。逐层迭代允许模型在不同层"聚焦"到不同的细节区域

训练策略

  • 基于预训练的 LLaVA-1.5-7b 权重初始化
  • 在 LLaVA-1.5-7b 的微调数据集上训练1个epoch
  • batch size = 1152,学习率 2e-5,cosine scheduler
  • 高分辨率输入设为1008×1008(原始分辨率的3倍)
  • 所有评估均为 zero-shot

实验关键数据

主实验:高分辨率VQA基准

模型 分辨率 V* Bench Overall V* Bench Spatial MagnifierBench TextVQA RSVQA Overall
InstructBLIP 224² 34.0 47.4 5.6 - -
LLaVA-1.5-7b 336² 47.6 56.6 26.8 46.0 68.4
LLaVA-HD 448² 51.8 61.8 35.0 45.6 68.4
LLaVA-XAttn 1008² 48.2 56.6 32.2 45.5 71.1
LLaVA-FlexAttn 1008² 54.5 64.5 35.0 48.9 72.7
GPT-4V - 55.0 60.5 - - -

FlexAttention在V Bench上比base model提升6.9%,比LLaVA-HD提升2.7%,在TextVQA上比LLaVA-HD提升3.3%,在RSVQA上超过了专门为遥感设计的GeoChat(72.7% vs 72.3%)。在V Bench Spatial类别上甚至超过了GPT-4V(64.5% vs 60.5%)。

计算效率对比

模型 MagnifierBench TFLOPs MagnifierBench Time(s) TextVQA TFLOPs TextVQA Time(s)
LLaVA-HD 24.9 154 24.5 3273
LLaVA-XAttn 27.1 178 26.7 3741
LLaVA-FlexAttn 17.1 112 17.1 2839

FlexAttention的TFLOPs比LLaVA-HD降低约31%,比LLaVA-XAttn降低约37%。实际推理时间在MagnifierBench上快约28-37%,在TextVQA上快约13-24%。

消融实验

选择策略 MagnifierBench TextVQA
Random(随机选择10%) 31.4 44.5
Center(选中心区域) 30.7 45.9
Attn. Map(注意力图选择) 35.0 48.9
高分辨率尺寸 MagnifierBench TextVQA TFLOPs
672×672 (2x) ~32 ~45 ~13
1008×1008 (3x) 35.0 48.9 17.1
1344×1344 (4x) ~36 ~48.9 ~23
RefCOCO子集 LLaVA-1.5 LLaVA-FlexAttn 提升
大物体 75.9 78.8 +2.9
小物体 41.3 51.3 +10.0
整体 75.4 78.4 +3.0

关键发现

  • 注意力图选择远优于随机/中心选择:在TextVQA上,注意力图选择比随机选择高出4.4%,证明了基于注意力的动态选择的有效性
  • 分辨率提升呈边际递减:从672→1008性能提升显著,但1008→1344在TextVQA上几乎无提升,因为TextVQA图像平均分辨率约950×811,超出原图分辨率后收益消失
  • 小物体场景增益显著:在RefCOCO上,小物体精度提升10.0%(远高于大物体的2.9%),精确验证了高分辨率输入对细粒度视觉推理的价值
  • 通用能力几乎无损:在POPE、GQA、VQAv2等通用基准上与base model持平,说明FlexAttention的引入不干扰模型原有能力

亮点与洞察

  • 注意力图的"免费复用":巧妙地将自注意力图本身作为高分辨率token的选择信号,无需额外的选择网络或learnable gating,这个设计既简洁又高效。注意力图本来就要计算,复用为选择信号几乎零额外开销
  • 层次化注意力的不对称设计:高分辨率token只参与K/V而不生成Query,保证hidden state维度不变的同时实现了线性计算增长。这个设计可以迁移到其他需要融合多粒度信息的场景(如视频多帧融合、多传感器融合)
  • 人类视觉的计算模拟:从认知科学中选择性注意力理论获得灵感,将"先粗看再聚焦"的自然视觉过程转化为具体的计算架构(前几层粗略理解 + 后续层动态关注),这种top-down + bottom-up的双路设计思路值得借鉴
  • 即插即用设计:FlexAttention作为替换现有自注意力的模块,可以应用到多种VLM架构,不限于LLaVA

局限性

  • 选择比例固定为~10%:论文未深入讨论不同任务是否需要不同的选择比例。对于OCR等需要全图细节的任务,10%可能不够;对于只关注单个物体的任务,10%可能冗余
  • 仅在LLaVA-1.5-7b上验证:未在更大的模型(13B/70B)或不同架构(QFormer-based、Fuyu-style)上验证通用性
  • 二值化选择的信息损失:注意力图经过二值化后变成硬选择(选/不选),丢失了注意力值的连续梯度信息。软选择或加权选择可能能进一步提升性能
  • 仅处理静态图像:论文提到可以扩展到视频/音频等长序列模态,但未做实验验证
  • LLaVA-HD的公平性:LLaVA-HD的输入分辨率为448×448,而FlexAttention为1008×1008,分辨率本身就不同,对比不完全公平(虽然FlexAttention的计算量更低)

相关工作与启发

  • vs LLaVA-1.5-HD:LLaVA-HD将高分辨率token直接拼接到序列中参与全注意力计算,简单但计算开销大。FlexAttention通过动态选择和层次化注意力避免了全量计算
  • vs CogAgent:CogAgent使用交叉注意力在每层计算hidden state与全部高分辨率特征的dense对应关系,计算量仍然很大(需要处理全部高分辨率token的K/V)。FlexAttention先选后算,更加高效
  • vs Sparse Attention方法(BigBird、Reformer等):这些方法通过稀疏化注意力矩阵来降低复杂度,但是domain-agnostic的。FlexAttention利用了视觉token的空间结构和注意力图的语义信息进行有针对性的稀疏化
  • 启发:这种"先粗后细"的动态选择范式可以迁移到视频理解(先看关键帧再看细节帧)、长文档理解(先看摘要再看关键段落)等场景

评分

  • 新颖性: ⭐⭐⭐⭐ 利用注意力图做动态token选择的思路简洁优雅,层次化注意力设计合理,但整体思路在efficient attention领域尚不算颠覆性创新
  • 实验充分度: ⭐⭐⭐⭐ 覆盖了高分辨率/通用/领域专用多类基准,消融实验充分(选择策略、分辨率、物体大小),但缺少更大模型和更多架构的验证
  • 写作质量: ⭐⭐⭐⭐ 论文结构清晰,动机阐述流畅,图示直观,但相关工作部分对CogAgent的对比实验设置说明可以更清楚
  • 价值: ⭐⭐⭐⭐ 在高分辨率VLM这个实际痛点上提出了有效的效率-性能trade-off方案,40%计算量降低+性能提升是实用的工程价值