跳转至

VSA: Faster Video Diffusion with Trainable Sparse Attention

会议: NeurIPS 2025
arXiv: 2505.13389
代码: https://github.com/hao-ai-lab/FastVideo
领域: 视频生成 / 注意力加速
关键词: 稀疏注意力, 视频扩散Transformer, 端到端训练, 关键token预测, 硬件对齐

一句话总结

提出 VSA (Video Sparse Attention),一种端到端可训练的硬件对齐稀疏注意力机制,通过粗粒度阶段(cube 池化预测关键 token)和细粒度阶段(在预测的块稀疏区域执行 token 级注意力)的层次化设计,在视频 DiT 的训练和推理中同时实现加速:从头预训练实现 2.53× 训练 FLOPs 减少且无质量损失,适配 Wan2.1-1.3B 实现注意力 6× 加速和端到端推理从 31s 降至 18s。

研究背景与动机

视频扩散 Transformer (DiT) 面临严峻的计算瓶颈:一个仅 5 秒的 720p 视频在潜空间展开后就超过 100K tokens,二次复杂度的 3D 注意力消耗了绝大部分计算。尽管研究表明 DiT 注意力矩阵天然是稀疏的(大部分条目接近零),但现有方法几乎都将稀疏性当作"后处理加速"——先用全注意力训练完,再在推理时替换为固定或 profile 出的稀疏掩码。

这种"先密后稀"的范式存在两个根本问题:

训练成本不变:绝大部分训练计算量无法受益于稀疏性

训练-测试不匹配:模型在密集上下文中学习参数,却在稀疏上下文中评估——质量上限被密集模型锁死,稀疏度推高后质量明显下降

核心挑战是一个"鸡生蛋蛋生鸡"问题:要精确识别关键 token 需要计算完整注意力矩阵,但一旦算了就失去了稀疏节省的意义。反之,用廉价启发式方法又可能漏掉高权重区域。而且,任何稀疏实现必须服从 GPU 块稀疏计算的硬件约束,否则理论节省无法转化为实际加速。

VSA 的核心思路:用可学习的轻量粗粒度注意力预测关键 token 位置,再在预测的块稀疏区域执行细粒度 token 注意力,两者端到端联合训练。

方法详解

整体框架

VSA 采用层次化的粗-细 (Coarse-Fine) 两阶段注意力设计。首先将视频潜变量按 (4,4,4) 的 cube 划分,每个 cube 映射到 GPU SM 的一个 tile(block size = 64)。粗粒度阶段在 cube 级别进行低成本的全注意力来预测关键位置;细粒度阶段仅在被选中的 cube 上执行 token 级块稀疏注意力。两阶段的输出通过可学习的门控向量融合为最终输出。

关键设计

  1. 粗粒度阶段(Coarse Stage):对每个 (4,4,4) cube 内的 token 做均值池化,得到压缩 64 倍的 cube 级表示 Q_c, K_c, V_c。在这个短序列上计算密集注意力得到 cube 级注意力分数 A_c 和输出 O_c。然后对 A_c 按行取 Top-K(默认 K=32),生成块稀疏掩码 M。由于粗粒度阶段在 64 倍压缩的序列上操作,计算量不到总注意力的 0.2%。关键创新是:粗粒度阶段的输出 O_c 不仅用于指导稀疏选择,还直接参与最终输出——实验证明这对维持全局上下文至关重要。

  2. 细粒度阶段(Fine Stage):基于粗粒度阶段生成的块稀疏掩码 M,对原始 token 级 Q, K, V 执行稀疏注意力,仅计算被选中的 cube 对应的 block。由于掩码天然是块结构的(粗粒度阶段在 cube 级操作),完美对齐 FlashAttention 的块稀疏计算格式,无需额外的 mask-to-index 转换。

  3. 门控融合:将输入隐状态通过线性投影得到粗/细两组门控向量 G_c 和 G_f,最终输出 O = O_c ⊙ G_c + O_f ⊙ G_f。这使模型能学习在不同头和层中动态平衡全局概览和局部细节的比重。

  4. Tile 大小的权衡:小 tile(如 16)允许更精细的稀疏粒度,但 GPU 吞吐低;大 tile(如 256)arithmetic intensity 高,但稀疏模式粗糙。系统性实验表明 tile=64(即 cube=(4,4,4))在质量和效率间取得最佳平衡。

  5. 稀疏适配与蒸馏 (Sparse Adaptation):将预训练的全注意力模型转换为 VSA 时采用渐进退火策略——初始将 G_c 权重设为零、移除 G_f(等效全注意力),然后逐步降低 K 值到目标稀疏度。还首次实现了稀疏注意力与蒸馏的兼容 (Sparse-Distill),学生模型同时作为 few-step 和 sparse 生成器,达到 50.9× 加速。

损失函数 / 训练策略

VSA 使用标准 Flow Matching 损失端到端训练。GPU 内核方面:细粒度阶段使用 ThunderKittens 实现的块稀疏注意力内核(tile=64 下达到 FlashAttention3 85% 的 MFU);粗粒度阶段将 softmax、Top-K 选择和 mask-to-index 转换融合为单个内核。稀疏适配仅需约 4000 步微调(学习率 1e-5)。预训练实验总计约 90k H200 GPU 小时。

实验关键数据

主实验

从头预训练的 Scaling 实验(60M~1.4B 参数,最高 4×10²¹ FLOPs):

模型规模 序列长度 方法 注意力 FLOPs 减少 训练 FLOPs 减少 损失差异
120M 16K Full Attention - - baseline
120M 16K VSA (87.5% sparse) ~8× 2.53× 几乎无差
410M 16K Full Attention - - baseline
410M 16K VSA (87.5% sparse) ~8× 2.53× 几乎相同
60M~1.4B 16K VSA vs Full - 2.53× 帕累托前沿一致更优

Wan2.1 适配实验:

模型 方法 VBench Quality↑ VBench Semantic↑ VBench Total↑ 推理时间
Wan-1.3B 原始 (full attn) 83.71% 77.98% 82.56% 31s
Wan-1.3B Full finetuned 84.07% 81.85% 83.63% 31s
Wan-1.3B VSA finetuned 83.60% 79.47% 82.77% 18s
Wan-14B 原始 - - baseline 1274s
Wan-14B VSA - - 人评相当 576s

消融实验

实验 配置 Loss 说明
Exp 1 Compress KV 0.14282 KV 池化,过于粗糙
Exp 2 Spatial-Temporal 0.13034 经典交替注意力,过训后劣于 Full
Exp 5 Full Attention 0.12703 基线
Exp 6 VSA 0.12687 超过全注意力!
Exp 7 固定 Local 模式 0.13330 固定模式不如 data-dependent
Exp 8 Fine only (无 O_c) 0.13296 缺失粗粒度全局信息
Exp 10 Coarse + Fine 0.13162 最简设计最优
Exp 11 C + F + Local 0.13194 额外 Local 无收益
Exp 17 tile=64 0.13162 效率-质量最佳平衡
Exp 18 tile=16 0.13155 质量略好但速度慢 2.26×

关键发现

  • 可训练稀疏可以反超全注意力:在扩展训练后,VSA 在相同 FLOPs 下达到比全注意力更优的帕累托前沿——这是首次在 DiT 上的严格缩放验证
  • 数据驱动的动态稀疏远优于固定模式:Spatial-Temporal、滑窗等固定模式在计算最优预算下看似更好,但扩展训练后被全注意力反超——只有 VSA 始终保持优势
  • 全局信息必要但局部先验不必要:粗粒度阶段的输出 O_c 对结果至关重要,但额外的 Local 窗口注意力无额外收益
  • 关键 token 预测精度高:Top-32 选择在大多数层/时间步覆盖 60%~90% 的注意力权重(随机基线仅 8%)
  • 注意力模式高度动态:可视化显示同一层不同头展现截然不同的模式——有的全局、有的局部、有的混合——进一步印证了固定模式不可行
  • 最优 Top-K 依赖序列长度和训练预算——更多训练计算需要更高 K,暗示"稀疏层面的 scaling law"值得研究

亮点与洞察

  • 首个在 DiT 上通过严格 scaling 实验证明可训练稀疏注意力优于全注意力的工作——意义重大,可能改变视频 DiT 的默认设计
  • "粗粒度预测+细粒度执行"的两阶段设计解决了鸡蛋问题:不需要算完整注意力就能找到关键 token
  • 粗粒度阶段输出直接参与最终输出是关键:与 MoBA、BiFormer 等仅用粗粒度做索引引导不同
  • Sparse-Distill 的先驱性:首次证明稀疏注意力与蒸馏兼容,同时实现 few-step + sparse acceleration,达到 50.9× 加速
  • 彻底的消融实验设计:90k H200 小时的系统性实验覆盖了 tile size、pooling 方式、局部先验、稀疏度等所有设计参数

局限与展望

  • Cube 大小固定为 (4,4,4),要求视频潜变量各维度可被 4 整除——限制了兼容的分辨率集合
  • 最优稀疏度(Top-K 值)的确定仍是开放问题——可能需要将稀疏性作为 scaling law 的额外维度
  • 粗粒度阶段虽然 FLOPs 可忽略,但在短序列下延迟占比仍达 14%——内核优化有提升空间
  • 长序列 scaling 实验有限——最大仅在 16K 序列上预训练,更长序列的行为待探索
  • 注意力模式的逐层/逐时间步自适应 Top-K 是明确的改进方向

相关工作与启发

  • NSA 和 MoBA 在 LLM 中开创了可训练稀疏注意力——VSA 将这一思路适配到双向 3D 视频注意力
  • STA (Sliding Tile Attention) 和 Sparge Attention 是推理时后处理方法——VSA 证明了训练时引入稀疏性的本质优势
  • DSV 也探索了训练时稀疏但采用多阶段 profiler 设计——VSA 的端到端训练更简洁
  • 核心启发:视频 DiT 的注意力瓶颈比 LLM 更严峻(因为序列更长且训练/推理全程密集),更迫切需要原生稀疏设计

评分

  • 新颖性: ⭐⭐⭐⭐⭐
  • 实验充分度: ⭐⭐⭐⭐⭐
  • 写作质量: ⭐⭐⭐⭐⭐
  • 价值: ⭐⭐⭐⭐⭐

相关论文