FlexAttention: 面向高效高分辨率视觉语言模型的灵活注意力机制¶
会议: ECCV 2024
arXiv: 2407.20228
代码: https://vis-www.cs.umass.edu/flexattention
领域: 多模态VLM
关键词: 高分辨率VLM, 注意力机制, 动态token选择, 层次化自注意力, 计算效率
一句话总结¶
提出 FlexAttention,通过基于注意力图的高分辨率token动态选择和层次化自注意力融合机制,在保持甚至超越现有高分辨率VLM性能的同时,将计算成本降低近40%。
研究背景与动机¶
-
领域现状:主流VLM(如LLaVA-1.5、InstructBLIP)通常将输入图像限制在较低分辨率(224×224或336×336),因为它们依赖CLIP等固定分辨率的视觉编码器。这在需要识别小文字、小物体等细节的场景下表现很差。
-
现有痛点:已有的高分辨率VLM(如LLaVA-1.5-HD、CogAgent)虽然能接收高分辨率图像,但它们将所有高分辨率token全部送入注意力模块计算,导致计算代价随token数量平方级增长。例如分辨率从336提升到1008,token数量增加9倍,注意力计算量增加约81倍。
-
核心矛盾:高分辨率带来的视觉细节信息 vs. 自注意力\(O(N^2)\)复杂度导致的计算开销之间存在根本性的 trade-off。现有方法要么牺牲分辨率,要么承受巨大的计算代价。
-
人类视觉启发:人类视觉处理并非一次性记住所有像素细节,而是先维持一个粗略的整体表征,在受到外部刺激时才对感兴趣的区域进行更精细的关注(选择性注意力机制)。
-
切入角度:作者观察到,在VLM生成过程中,模型的注意力图天然地揭示了当前哪些图像区域是重要的。可以利用这个"免费"的信号来动态选择只需要高分辨率细节的区域,而非暴力处理全部高分辨率token。
-
核心idea一句话:利用注意力图动态选择约10%的关键高分辨率token,通过层次化自注意力将其融入低分辨率表征,实现计算量亚线性增长下的高分辨率感知。
方法详解¶
整体框架¶
FlexAttention 可以即插即用地替换现有VLM的自注意力模块。整体pipeline如下:
- 输入:一张高分辨率图像 \(I_{HR}\)(如1008×1008)和文本输入 \(T\)
- 双路编码:将高分辨率图像同时编码为高分辨率token \(f_{HR}\) 和低分辨率token \(f_{LR}\)(下采样后编码)
- 前半段(\(N_{SA}\)层):只使用低分辨率image token和text token进行标准自注意力计算,粗略理解图像全貌
- 后半段(\(N_{FA}\)层):启用FlexAttention,每层动态选择少量高分辨率token,通过层次化自注意力融合高分辨率细节
- 输出:最后一层hidden state经过projector生成文本答案
这种设计的关键在于:前几层用低成本建立全局理解,后续层才引入高分辨率细节,且每层只引入一小部分(~10%)的高分辨率token。
关键设计¶
1. 高分辨率特征选择模块(High-Resolution Feature Selection Module)¶
- 做什么:根据当前层的注意力图,从全部高分辨率token中选择与当前生成最相关的一小部分
- 核心思路:
- 取注意力图 \(Map\) 最后一列的前 \(N_i\) 个值(即所有低分辨率image token对最后一个text token的注意力权重),这些值反映了模型在生成下一个token时对各图像区域的关注程度
- 将这个1D注意力向量 reshape 为2D空间注意力图
- 对该注意力图进行归一化和二值化,得到一个binary mask
- 将该mask上采样到高分辨率feature map的空间尺寸,形成高分辨率选择mask
- 应用mask选出被"激活"的高分辨率token \(f_{SHR}\),约占总高分辨率token的10%
- 设计动机:自注意力图本身就是一个"免费"的区域重要性信号,利用它来选择token避免了额外的选择网络开销。这种选择是逐层动态变化的,不同层可以关注不同的图像区域
2. 层次化自注意力模块(Hierarchical Self-Attention Module)¶
- 做什么:将选出的高分辨率token信息融合到原始hidden state(低分辨率token + text token)中
- 核心思路:
- Query \(Q\) 只来自原始hidden state \(H\):\(Q = HW_Q\)
- Key和Value通过拼接原始hidden state和选中的高分辨率token构建:
- \(K_{all} = \text{Concat}(HW_K, f_{SHR}W_K')\)
- \(V_{all} = \text{Concat}(HW_V, f_{SHR}W_V')\)
- 注意高分辨率token使用独立的投影矩阵 \(W_K'\), \(W_V'\),而非共享Query投影
- 输出注意力图 \(Map'\) 尺寸为 \(N \times (N+M)\),截取前 \(N \times N\) 部分传递给下一层做token选择
- 设计动机:这种设计保证hidden state的维度不变(仍为 \(N \times D\)),高分辨率token只参与K/V的计算而不引入新的Query,计算复杂度从 \(O((M+N)^2D)\) 降为 \(O((M+N)ND)\),实现了对高分辨率token的线性而非平方级计算增长
3. 逐层迭代选择机制¶
- 做什么:高分辨率选择和层次化注意力在每个FlexAttention层中交替执行
- 核心思路:第 \(i\) 层的注意力图 \(Map^i\) 用于选择第 \(i+1\) 层的高分辨率token \(f_{SHR}^i\),形成一个迭代refinement过程
- 设计动机:随着网络深度增加,模型对图像的理解逐步深入,每层关注的区域可能不同。逐层迭代允许模型在不同层"聚焦"到不同的细节区域
训练策略¶
- 基于预训练的 LLaVA-1.5-7b 权重初始化
- 在 LLaVA-1.5-7b 的微调数据集上训练1个epoch
- batch size = 1152,学习率 2e-5,cosine scheduler
- 高分辨率输入设为1008×1008(原始分辨率的3倍)
- 所有评估均为 zero-shot
实验关键数据¶
主实验:高分辨率VQA基准¶
| 模型 | 分辨率 | V* Bench Overall | V* Bench Spatial | MagnifierBench | TextVQA | RSVQA Overall |
|---|---|---|---|---|---|---|
| InstructBLIP | 224² | 34.0 | 47.4 | 5.6 | - | - |
| LLaVA-1.5-7b | 336² | 47.6 | 56.6 | 26.8 | 46.0 | 68.4 |
| LLaVA-HD | 448² | 51.8 | 61.8 | 35.0 | 45.6 | 68.4 |
| LLaVA-XAttn | 1008² | 48.2 | 56.6 | 32.2 | 45.5 | 71.1 |
| LLaVA-FlexAttn | 1008² | 54.5 | 64.5 | 35.0 | 48.9 | 72.7 |
| GPT-4V | - | 55.0 | 60.5 | - | - | - |
FlexAttention在V Bench上比base model提升6.9%,比LLaVA-HD提升2.7%,在TextVQA上比LLaVA-HD提升3.3%,在RSVQA上超过了专门为遥感设计的GeoChat(72.7% vs 72.3%)。在V Bench Spatial类别上甚至超过了GPT-4V(64.5% vs 60.5%)。
计算效率对比¶
| 模型 | MagnifierBench TFLOPs | MagnifierBench Time(s) | TextVQA TFLOPs | TextVQA Time(s) |
|---|---|---|---|---|
| LLaVA-HD | 24.9 | 154 | 24.5 | 3273 |
| LLaVA-XAttn | 27.1 | 178 | 26.7 | 3741 |
| LLaVA-FlexAttn | 17.1 | 112 | 17.1 | 2839 |
FlexAttention的TFLOPs比LLaVA-HD降低约31%,比LLaVA-XAttn降低约37%。实际推理时间在MagnifierBench上快约28-37%,在TextVQA上快约13-24%。
消融实验¶
| 选择策略 | MagnifierBench | TextVQA |
|---|---|---|
| Random(随机选择10%) | 31.4 | 44.5 |
| Center(选中心区域) | 30.7 | 45.9 |
| Attn. Map(注意力图选择) | 35.0 | 48.9 |
| 高分辨率尺寸 | MagnifierBench | TextVQA | TFLOPs |
|---|---|---|---|
| 672×672 (2x) | ~32 | ~45 | ~13 |
| 1008×1008 (3x) | 35.0 | 48.9 | 17.1 |
| 1344×1344 (4x) | ~36 | ~48.9 | ~23 |
| RefCOCO子集 | LLaVA-1.5 | LLaVA-FlexAttn | 提升 |
|---|---|---|---|
| 大物体 | 75.9 | 78.8 | +2.9 |
| 小物体 | 41.3 | 51.3 | +10.0 |
| 整体 | 75.4 | 78.4 | +3.0 |
关键发现¶
- 注意力图选择远优于随机/中心选择:在TextVQA上,注意力图选择比随机选择高出4.4%,证明了基于注意力的动态选择的有效性
- 分辨率提升呈边际递减:从672→1008性能提升显著,但1008→1344在TextVQA上几乎无提升,因为TextVQA图像平均分辨率约950×811,超出原图分辨率后收益消失
- 小物体场景增益显著:在RefCOCO上,小物体精度提升10.0%(远高于大物体的2.9%),精确验证了高分辨率输入对细粒度视觉推理的价值
- 通用能力几乎无损:在POPE、GQA、VQAv2等通用基准上与base model持平,说明FlexAttention的引入不干扰模型原有能力
亮点与洞察¶
- 注意力图的"免费复用":巧妙地将自注意力图本身作为高分辨率token的选择信号,无需额外的选择网络或learnable gating,这个设计既简洁又高效。注意力图本来就要计算,复用为选择信号几乎零额外开销
- 层次化注意力的不对称设计:高分辨率token只参与K/V而不生成Query,保证hidden state维度不变的同时实现了线性计算增长。这个设计可以迁移到其他需要融合多粒度信息的场景(如视频多帧融合、多传感器融合)
- 人类视觉的计算模拟:从认知科学中选择性注意力理论获得灵感,将"先粗看再聚焦"的自然视觉过程转化为具体的计算架构(前几层粗略理解 + 后续层动态关注),这种top-down + bottom-up的双路设计思路值得借鉴
- 即插即用设计:FlexAttention作为替换现有自注意力的模块,可以应用到多种VLM架构,不限于LLaVA
局限性¶
- 选择比例固定为~10%:论文未深入讨论不同任务是否需要不同的选择比例。对于OCR等需要全图细节的任务,10%可能不够;对于只关注单个物体的任务,10%可能冗余
- 仅在LLaVA-1.5-7b上验证:未在更大的模型(13B/70B)或不同架构(QFormer-based、Fuyu-style)上验证通用性
- 二值化选择的信息损失:注意力图经过二值化后变成硬选择(选/不选),丢失了注意力值的连续梯度信息。软选择或加权选择可能能进一步提升性能
- 仅处理静态图像:论文提到可以扩展到视频/音频等长序列模态,但未做实验验证
- LLaVA-HD的公平性:LLaVA-HD的输入分辨率为448×448,而FlexAttention为1008×1008,分辨率本身就不同,对比不完全公平(虽然FlexAttention的计算量更低)
相关工作与启发¶
- vs LLaVA-1.5-HD:LLaVA-HD将高分辨率token直接拼接到序列中参与全注意力计算,简单但计算开销大。FlexAttention通过动态选择和层次化注意力避免了全量计算
- vs CogAgent:CogAgent使用交叉注意力在每层计算hidden state与全部高分辨率特征的dense对应关系,计算量仍然很大(需要处理全部高分辨率token的K/V)。FlexAttention先选后算,更加高效
- vs Sparse Attention方法(BigBird、Reformer等):这些方法通过稀疏化注意力矩阵来降低复杂度,但是domain-agnostic的。FlexAttention利用了视觉token的空间结构和注意力图的语义信息进行有针对性的稀疏化
- 启发:这种"先粗后细"的动态选择范式可以迁移到视频理解(先看关键帧再看细节帧)、长文档理解(先看摘要再看关键段落)等场景
评分¶
- 新颖性: ⭐⭐⭐⭐ 利用注意力图做动态token选择的思路简洁优雅,层次化注意力设计合理,但整体思路在efficient attention领域尚不算颠覆性创新
- 实验充分度: ⭐⭐⭐⭐ 覆盖了高分辨率/通用/领域专用多类基准,消融实验充分(选择策略、分辨率、物体大小),但缺少更大模型和更多架构的验证
- 写作质量: ⭐⭐⭐⭐ 论文结构清晰,动机阐述流畅,图示直观,但相关工作部分对CogAgent的对比实验设置说明可以更清楚
- 价值: ⭐⭐⭐⭐ 在高分辨率VLM这个实际痛点上提出了有效的效率-性能trade-off方案,40%计算量降低+性能提升是实用的工程价值