跳转至

Multi-Head Low-Rank Attention (MLRA)

会议: ICLR 2026
arXiv: 2603.02188
代码: GitHub / HuggingFace
领域: LLM Efficiency / Attention Mechanism
关键词: KV Cache, 张量并行, 低秩注意力, 解码效率, Multi-Head Latent Attention

一句话总结

提出 Multi-Head Low-Rank Attention (MLRA),通过将 MLA 的单一 latent head 分解为多个可独立分片的 latent head,并对各分支注意力输出求和,实现原生 4-way 张量并行支持,在保持 SOTA 性能的同时获得 2.8× 的解码加速。

研究背景与动机

  1. 长上下文推理瓶颈:LLM 推理时,解码阶段需要在每一步从 HBM(高带宽内存)向 SRAM 反复传输 KV cache,数据搬运(而非计算)成为长上下文推理的主要延迟来源。
  2. MLA 的成功与局限:DeepSeek 提出的 Multi-Head Latent Attention (MLA) 通过将 KV cache 压缩为单一 latent head(每 token 仅需 \(4.5 d_h\)),显著减少了 KV cache 总量。但其单一 latent head 不可分片——在张量并行(TP)解码时,每个设备被迫冗余加载完整 KV cache。
  3. TP 下 MLA 的性能停滞:无论 TP 程度如何,MLA 的每设备 KV cache 加载量始终固定在 \(4.5 d_h\),TP 带来的权重分片优势被完全抵消。
  4. GLA-2 的部分解决:GLA-2 将 latent head 二等分为两个较小的 latent head,在 2-way TP 下降低为 \(2.5 d_h\),但 TP > 2 时同样无法进一步降低。
  5. 算术强度的重要性:解码效率的关键指标是算术强度(FLOPs/byte),需要在减少内存加载的同时保持高计算密度,从而将工作负载从内存受限推向计算受限。
  6. 方差不匹配问题:MLA 中 RoPE key 与 NoPE key 存在显著的方差不匹配(\(\text{Var}(K^{\text{RoPE}}) / \text{Var}(K^{\text{NoPE}}) \approx d/d_c\)),当 latent 维度远小于隐藏维度时问题尤为突出。

方法详解

整体框架

MLRA 的核心思路是:将 MLA 中不可分片的单一 latent head 显式分解为多个独立的 latent head,每个 latent head 独立进行 up-projection 生成 NoPE KV,然后对各分支的注意力输出求和。这种设计天然支持 4-way 张量并行。

MLRA 提供两个变体: - MLRA-2:2 个 latent head,每个 latent head 服务半数注意力头,产生 2 分支输出求和 - MLRA-4:4 个 latent head,每个 latent head 通过 up-projection 服务全部注意力头,产生 4 分支输出求和

关键设计

1. 块分解(Block Decomposition)

做什么:将 MLA 的 KV latent 矩阵 \(C^{KV} \in \mathbb{R}^{n \times d_c}\) 按通道划分为 4 个块 \(C_{:,(b)}^{KV}\),同时将 up-projection 矩阵 \(W^{UK}, W^{UV}\) 按行划分为对应的 4 个子块。

为什么:MLA 中每个头的 NoPE key/value 本质上等价于 4 个子块乘积之和(Eq. 2),将这一求和从 KV 计算移至注意力输出,使得每个子块可以独立计算注意力。

怎么做(以 MLRA-4 为例): $\(O_{:,i,:} = \sum_{b=0}^{3} \text{Softmax}\left(\tau Q_{:,i,:}^{\text{NoPE}} (C_{:,(b)}^{KV} W_{(b),(i)}^{UK})^\top + \tau Q_{:,i,:}^{\text{RoPE}} (K^{\text{RoPE}})^\top \right) (C_{:,(b)}^{KV} W_{(b),(i)}^{UV})\)$

每个分支 \(b\) 仅依赖 \(C_{:,(b)}^{KV}\)(大小为 \(d_h\)),可以分配给不同 TP 设备独立计算。

2. MLRA-2 的分组映射

做什么:沿用 GLA-2 的分组策略,将 latent head 二等分,前半注意力头用第一组 latent,后半注意力头用第二组 latent。

为什么:提供一种轻量级替代方案,在 2 分支求和中兼顾模型容量与效率。

怎么做:通过分组映射函数 \(\gamma(i)\) 决定第 \(i\) 个注意力头使用哪个 latent 组,每个头的输出是 2 个分支注意力的加和。

3. 方差校准(Variance Calibration)

做什么:对 query 和 KV latent state 施加缩放因子,并对多分支注意力输出进行归一化。

为什么:理论分析表明 NoPE key 的方差为 \(d_c \sigma_w^2\),而 RoPE key 的方差为 \(d \sigma_w^2\),分支拆分进一步加剧了方差不匹配。多分支求和还改变了输出的方差。

怎么做: - Latent state 缩放:\(C^Q \leftarrow \sqrt{d/d_c'} \cdot C^Q\)\(C^{KV} \leftarrow \sqrt{4d/d_c} \cdot C^{KV}\) - 输出缩放:MLRA-2 的输出除以 \(\sqrt{2}\),MLRA-4 的输出除以 \(2\)

4. 高效解码(TP-Friendly Decoding)

做什么:实现原生 4-way TP 解码,每个设备仅加载 \(1.5 d_h\) 的 KV cache。

为什么:4 个 latent block 可自然分配到 4 个 TP 设备上,每设备仅加载一个 latent block(\(d_h\))加 共享的 RoPE key(\(0.5 d_h\))。

怎么做:沿用 MLA 的 weight absorption 技巧将 up-projection 吸收进 query 侧,然后在每个设备上独立执行 MQA-style 解码,最后 all-reduce 求和。

损失函数 / 训练策略

  • 初始化:输出投影参数使用零初始化(优于标准 \(\mathcal{N}(0, 0.02)\)),其余参数标准初始化
  • 优化器:AdamW,\((\beta_1, \beta_2)=(0.9, 0.95)\),weight decay 0.1,梯度裁剪 1.0
  • 学习率:peak \(1.6 \times 10^{-4}\),前 2000 步线性 warmup,之后余弦退火至 10%
  • 训练规模:2.9B 参数,FineWeb-Edu-100B 数据集,98.3B token,上下文长度 2048,8 × H100 GPU
  • 可选门控机制:在注意力输出投影前添加 gating,可进一步降低困惑度

实验关键数据

主实验

验证集困惑度(7 个数据集平均,越低越好)

方法 Wikipedia C4 Pile RefinedWeb Cosmopedia FineWeb FineWeb-Edu 平均
MHA 14.624 16.575 12.929 18.698 9.102 15.656 9.434 13.860
GQA 15.057 16.628 13.758 18.885 9.504 15.713 9.427 14.139
MLA 14.567 16.345 12.965 18.523 8.966 15.440 9.284 13.727
GLA-2 14.605 16.323 13.225 18.509 9.118 15.424 9.249 13.779
MLRA-4 14.407 16.286 13.124 18.398 8.937 15.361 9.193 13.672

零样本常识推理准确率(%)

方法 ARC-E ARC-C OBQA BoolQ HellaSwag Winogrande PIQA 平均
MHA 69.11 39.16 40.80 62.26 60.82 57.62 74.86 57.81
GQA 67.13 39.42 42.00 63.39 61.29 56.91 75.08 57.89
MLA 68.22 39.16 42.60 64.10 61.39 60.06 75.68 58.75
MLRA-4 67.63 41.38 43.00 61.74 62.16 61.48 74.48 58.84

消融实验

门控机制对困惑度的影响(7 个数据集平均)

方法 无门控 有门控 提升
GQA 14.139 13.806 -0.333
MLA 13.727 13.642 -0.085
GLA-2 13.779 13.701 -0.078
MLRA-2 13.804 13.651 -0.153
MLRA-4 13.672 13.621 -0.051

其他消融发现: - 零初始化 vs \(\mathcal{N}(0, 0.02)\):零初始化在所有模型上均优于随机初始化 - 方差缩放:MLA 和 GLA-2 收益显著,MLRA-2 收益边际(因分支已天然减缓方差不匹配) - 加倍注意力头数:在参数总量不变的条件下加倍 GQA/MLA/GLA-2 的头数,均未带来改善反而损害性能

关键发现

  1. MLRA-4 全面最优:在困惑度(13.672)和零样本推理准确率(58.84%)上均超越所有基线,包括 MLA
  2. 2.8× 解码加速:相比 MLA,MLRA-4 在 128K-2M token 长上下文解码中实现稳定 2.8× 加速
  3. 1.05-1.26× 相比 GQA:在长上下文解码中超越 GQA,且差距随上下文长度增加而扩大
  4. TP=4 即可达到 1.5 \(d_h\):GQA 和 GTA 需要 8-way TP 才能达到类似的每设备 KV 加载量,MLRA 仅需 4-way
  5. 门控进一步提升:加入门控后 MLRA-4 困惑度降至 13.621,仍保持最优

亮点与洞察

  • 优雅的数学动机:从 MLA 的块分解出发,将 KV 层面的求和移至注意力输出层面求和,数学上简洁且直觉清晰
  • 方差分析扎实:从理论上推导了各组件的方差,给出了明确的校准策略,避免了凭经验调参
  • 实用性极强:MLRA 与 MLA 共享相同的 KV cache 总量(每 token \(4.5 d_h\)),区别仅在于分布式解码时是否可分片,对现有 MLA 系统的迁移成本低
  • 算术强度分析:MLRA-4 的算术强度约为 \(2h\)(MLA 也是 \(2h\)),说明在减少内存加载的同时不牺牲计算效率
  • 完整的工程实现:基于 FlashAttention-3 实现 MLRA-4 kernel,并在真实 H100 集群上验证

局限性 / 可改进方向

  1. 仅在 2.9B 规模验证:未在 7B+ 或更大规模上实验,大模型上 MLRA 的优势是否保持有待验证
  2. 预训练数据单一:仅使用 FineWeb-Edu-100B,未在多语言或代码混合数据上验证
  3. Assumption 1 的局限:方差分析假设权重 i.i.d. 分布且与输入独立,训练过程中此假设不严格成立
  4. 固定 4-way 分解:MLRA-4 绑定了 4-way TP,对于 2-way 或 8-way TP 场景需要分别使用 MLRA-2 或进一步扩展
  5. 多分支求和的近似性:将求和从 softmax 内部移至外部本质上改变了注意力分布,虽然实验效果好但理论上并非等价变换
  6. 未评估指令微调和对齐后的效果:仅关注预训练,未覆盖 SFT/RLHF 阶段

相关工作与启发

  • MLA (DeepSeek-V2/V3):MLRA 的直接前驱,通过 latent compression 压缩 KV cache,但不支持 TP 分片
  • GQA:Grouped Query Attention,通过减少 KV head 数实现效率提升,但 KV cache 仍随头数线性增长
  • GLA-2 (Zadouri et al., 2025):首个尝试拆分 MLA latent head 的工作,但仅支持 2-way TP
  • TPA (Zhang et al., 2025):Tensor Product Attention,用共享 head 的线性组合构造 KV,TP 支持有限
  • FlashMLA / FlashAttention-3:高效注意力 kernel,MLRA 的 kernel 基于 FlashAttention-3 实现
  • LongCat (2025):首次观察到 RoPE key 的方差不匹配问题,MLRA 沿用并扩展了其缩放策略
  • 启发:这种"将内部求和外提为多分支独立计算"的思路可能适用于其他需要 latent compression + TP 的场景(如 KV cache 量化、稀疏注意力等)

评分

  • 新颖性: ⭐⭐⭐⭐ — 洞察精准,将 MLA 的"不可分片"问题通过分支分解优雅解决
  • 实验充分度: ⭐⭐⭐⭐ — 多组消融+多数据集评估+解码速度/吞吐量测试,但仅 2.9B 规模
  • 写作质量: ⭐⭐⭐⭐⭐ — 数学推导严谨,符号清晰,从背景到方法的逻辑链条完整
  • 价值: ⭐⭐⭐⭐ — 直接解决 MLA 部署的实际痛点,对 DeepSeek 系列和大规模推理部署有重要实用价值