Multi-Head Low-Rank Attention (MLRA)¶
会议: ICLR 2026
arXiv: 2603.02188
代码: GitHub / HuggingFace
领域: LLM Efficiency / Attention Mechanism
关键词: KV Cache, 张量并行, 低秩注意力, 解码效率, Multi-Head Latent Attention
一句话总结¶
提出 Multi-Head Low-Rank Attention (MLRA),通过将 MLA 的单一 latent head 分解为多个可独立分片的 latent head,并对各分支注意力输出求和,实现原生 4-way 张量并行支持,在保持 SOTA 性能的同时获得 2.8× 的解码加速。
研究背景与动机¶
- 长上下文推理瓶颈:LLM 推理时,解码阶段需要在每一步从 HBM(高带宽内存)向 SRAM 反复传输 KV cache,数据搬运(而非计算)成为长上下文推理的主要延迟来源。
- MLA 的成功与局限:DeepSeek 提出的 Multi-Head Latent Attention (MLA) 通过将 KV cache 压缩为单一 latent head(每 token 仅需 \(4.5 d_h\)),显著减少了 KV cache 总量。但其单一 latent head 不可分片——在张量并行(TP)解码时,每个设备被迫冗余加载完整 KV cache。
- TP 下 MLA 的性能停滞:无论 TP 程度如何,MLA 的每设备 KV cache 加载量始终固定在 \(4.5 d_h\),TP 带来的权重分片优势被完全抵消。
- GLA-2 的部分解决:GLA-2 将 latent head 二等分为两个较小的 latent head,在 2-way TP 下降低为 \(2.5 d_h\),但 TP > 2 时同样无法进一步降低。
- 算术强度的重要性:解码效率的关键指标是算术强度(FLOPs/byte),需要在减少内存加载的同时保持高计算密度,从而将工作负载从内存受限推向计算受限。
- 方差不匹配问题:MLA 中 RoPE key 与 NoPE key 存在显著的方差不匹配(\(\text{Var}(K^{\text{RoPE}}) / \text{Var}(K^{\text{NoPE}}) \approx d/d_c\)),当 latent 维度远小于隐藏维度时问题尤为突出。
方法详解¶
整体框架¶
MLRA 的核心思路是:将 MLA 中不可分片的单一 latent head 显式分解为多个独立的 latent head,每个 latent head 独立进行 up-projection 生成 NoPE KV,然后对各分支的注意力输出求和。这种设计天然支持 4-way 张量并行。
MLRA 提供两个变体: - MLRA-2:2 个 latent head,每个 latent head 服务半数注意力头,产生 2 分支输出求和 - MLRA-4:4 个 latent head,每个 latent head 通过 up-projection 服务全部注意力头,产生 4 分支输出求和
关键设计¶
1. 块分解(Block Decomposition)¶
做什么:将 MLA 的 KV latent 矩阵 \(C^{KV} \in \mathbb{R}^{n \times d_c}\) 按通道划分为 4 个块 \(C_{:,(b)}^{KV}\),同时将 up-projection 矩阵 \(W^{UK}, W^{UV}\) 按行划分为对应的 4 个子块。
为什么:MLA 中每个头的 NoPE key/value 本质上等价于 4 个子块乘积之和(Eq. 2),将这一求和从 KV 计算移至注意力输出,使得每个子块可以独立计算注意力。
怎么做(以 MLRA-4 为例): $\(O_{:,i,:} = \sum_{b=0}^{3} \text{Softmax}\left(\tau Q_{:,i,:}^{\text{NoPE}} (C_{:,(b)}^{KV} W_{(b),(i)}^{UK})^\top + \tau Q_{:,i,:}^{\text{RoPE}} (K^{\text{RoPE}})^\top \right) (C_{:,(b)}^{KV} W_{(b),(i)}^{UV})\)$
每个分支 \(b\) 仅依赖 \(C_{:,(b)}^{KV}\)(大小为 \(d_h\)),可以分配给不同 TP 设备独立计算。
2. MLRA-2 的分组映射¶
做什么:沿用 GLA-2 的分组策略,将 latent head 二等分,前半注意力头用第一组 latent,后半注意力头用第二组 latent。
为什么:提供一种轻量级替代方案,在 2 分支求和中兼顾模型容量与效率。
怎么做:通过分组映射函数 \(\gamma(i)\) 决定第 \(i\) 个注意力头使用哪个 latent 组,每个头的输出是 2 个分支注意力的加和。
3. 方差校准(Variance Calibration)¶
做什么:对 query 和 KV latent state 施加缩放因子,并对多分支注意力输出进行归一化。
为什么:理论分析表明 NoPE key 的方差为 \(d_c \sigma_w^2\),而 RoPE key 的方差为 \(d \sigma_w^2\),分支拆分进一步加剧了方差不匹配。多分支求和还改变了输出的方差。
怎么做: - Latent state 缩放:\(C^Q \leftarrow \sqrt{d/d_c'} \cdot C^Q\),\(C^{KV} \leftarrow \sqrt{4d/d_c} \cdot C^{KV}\) - 输出缩放:MLRA-2 的输出除以 \(\sqrt{2}\),MLRA-4 的输出除以 \(2\)
4. 高效解码(TP-Friendly Decoding)¶
做什么:实现原生 4-way TP 解码,每个设备仅加载 \(1.5 d_h\) 的 KV cache。
为什么:4 个 latent block 可自然分配到 4 个 TP 设备上,每设备仅加载一个 latent block(\(d_h\))加 共享的 RoPE key(\(0.5 d_h\))。
怎么做:沿用 MLA 的 weight absorption 技巧将 up-projection 吸收进 query 侧,然后在每个设备上独立执行 MQA-style 解码,最后 all-reduce 求和。
损失函数 / 训练策略¶
- 初始化:输出投影参数使用零初始化(优于标准 \(\mathcal{N}(0, 0.02)\)),其余参数标准初始化
- 优化器:AdamW,\((\beta_1, \beta_2)=(0.9, 0.95)\),weight decay 0.1,梯度裁剪 1.0
- 学习率:peak \(1.6 \times 10^{-4}\),前 2000 步线性 warmup,之后余弦退火至 10%
- 训练规模:2.9B 参数,FineWeb-Edu-100B 数据集,98.3B token,上下文长度 2048,8 × H100 GPU
- 可选门控机制:在注意力输出投影前添加 gating,可进一步降低困惑度
实验关键数据¶
主实验¶
验证集困惑度(7 个数据集平均,越低越好):
| 方法 | Wikipedia | C4 | Pile | RefinedWeb | Cosmopedia | FineWeb | FineWeb-Edu | 平均 |
|---|---|---|---|---|---|---|---|---|
| MHA | 14.624 | 16.575 | 12.929 | 18.698 | 9.102 | 15.656 | 9.434 | 13.860 |
| GQA | 15.057 | 16.628 | 13.758 | 18.885 | 9.504 | 15.713 | 9.427 | 14.139 |
| MLA | 14.567 | 16.345 | 12.965 | 18.523 | 8.966 | 15.440 | 9.284 | 13.727 |
| GLA-2 | 14.605 | 16.323 | 13.225 | 18.509 | 9.118 | 15.424 | 9.249 | 13.779 |
| MLRA-4 | 14.407 | 16.286 | 13.124 | 18.398 | 8.937 | 15.361 | 9.193 | 13.672 |
零样本常识推理准确率(%):
| 方法 | ARC-E | ARC-C | OBQA | BoolQ | HellaSwag | Winogrande | PIQA | 平均 |
|---|---|---|---|---|---|---|---|---|
| MHA | 69.11 | 39.16 | 40.80 | 62.26 | 60.82 | 57.62 | 74.86 | 57.81 |
| GQA | 67.13 | 39.42 | 42.00 | 63.39 | 61.29 | 56.91 | 75.08 | 57.89 |
| MLA | 68.22 | 39.16 | 42.60 | 64.10 | 61.39 | 60.06 | 75.68 | 58.75 |
| MLRA-4 | 67.63 | 41.38 | 43.00 | 61.74 | 62.16 | 61.48 | 74.48 | 58.84 |
消融实验¶
门控机制对困惑度的影响(7 个数据集平均):
| 方法 | 无门控 | 有门控 | 提升 |
|---|---|---|---|
| GQA | 14.139 | 13.806 | -0.333 |
| MLA | 13.727 | 13.642 | -0.085 |
| GLA-2 | 13.779 | 13.701 | -0.078 |
| MLRA-2 | 13.804 | 13.651 | -0.153 |
| MLRA-4 | 13.672 | 13.621 | -0.051 |
其他消融发现: - 零初始化 vs \(\mathcal{N}(0, 0.02)\):零初始化在所有模型上均优于随机初始化 - 方差缩放:MLA 和 GLA-2 收益显著,MLRA-2 收益边际(因分支已天然减缓方差不匹配) - 加倍注意力头数:在参数总量不变的条件下加倍 GQA/MLA/GLA-2 的头数,均未带来改善反而损害性能
关键发现¶
- MLRA-4 全面最优:在困惑度(13.672)和零样本推理准确率(58.84%)上均超越所有基线,包括 MLA
- 2.8× 解码加速:相比 MLA,MLRA-4 在 128K-2M token 长上下文解码中实现稳定 2.8× 加速
- 1.05-1.26× 相比 GQA:在长上下文解码中超越 GQA,且差距随上下文长度增加而扩大
- TP=4 即可达到 1.5 \(d_h\):GQA 和 GTA 需要 8-way TP 才能达到类似的每设备 KV 加载量,MLRA 仅需 4-way
- 门控进一步提升:加入门控后 MLRA-4 困惑度降至 13.621,仍保持最优
亮点与洞察¶
- 优雅的数学动机:从 MLA 的块分解出发,将 KV 层面的求和移至注意力输出层面求和,数学上简洁且直觉清晰
- 方差分析扎实:从理论上推导了各组件的方差,给出了明确的校准策略,避免了凭经验调参
- 实用性极强:MLRA 与 MLA 共享相同的 KV cache 总量(每 token \(4.5 d_h\)),区别仅在于分布式解码时是否可分片,对现有 MLA 系统的迁移成本低
- 算术强度分析:MLRA-4 的算术强度约为 \(2h\)(MLA 也是 \(2h\)),说明在减少内存加载的同时不牺牲计算效率
- 完整的工程实现:基于 FlashAttention-3 实现 MLRA-4 kernel,并在真实 H100 集群上验证
局限性 / 可改进方向¶
- 仅在 2.9B 规模验证:未在 7B+ 或更大规模上实验,大模型上 MLRA 的优势是否保持有待验证
- 预训练数据单一:仅使用 FineWeb-Edu-100B,未在多语言或代码混合数据上验证
- Assumption 1 的局限:方差分析假设权重 i.i.d. 分布且与输入独立,训练过程中此假设不严格成立
- 固定 4-way 分解:MLRA-4 绑定了 4-way TP,对于 2-way 或 8-way TP 场景需要分别使用 MLRA-2 或进一步扩展
- 多分支求和的近似性:将求和从 softmax 内部移至外部本质上改变了注意力分布,虽然实验效果好但理论上并非等价变换
- 未评估指令微调和对齐后的效果:仅关注预训练,未覆盖 SFT/RLHF 阶段
相关工作与启发¶
- MLA (DeepSeek-V2/V3):MLRA 的直接前驱,通过 latent compression 压缩 KV cache,但不支持 TP 分片
- GQA:Grouped Query Attention,通过减少 KV head 数实现效率提升,但 KV cache 仍随头数线性增长
- GLA-2 (Zadouri et al., 2025):首个尝试拆分 MLA latent head 的工作,但仅支持 2-way TP
- TPA (Zhang et al., 2025):Tensor Product Attention,用共享 head 的线性组合构造 KV,TP 支持有限
- FlashMLA / FlashAttention-3:高效注意力 kernel,MLRA 的 kernel 基于 FlashAttention-3 实现
- LongCat (2025):首次观察到 RoPE key 的方差不匹配问题,MLRA 沿用并扩展了其缩放策略
- 启发:这种"将内部求和外提为多分支独立计算"的思路可能适用于其他需要 latent compression + TP 的场景(如 KV cache 量化、稀疏注意力等)
评分¶
- 新颖性: ⭐⭐⭐⭐ — 洞察精准,将 MLA 的"不可分片"问题通过分支分解优雅解决
- 实验充分度: ⭐⭐⭐⭐ — 多组消融+多数据集评估+解码速度/吞吐量测试,但仅 2.9B 规模
- 写作质量: ⭐⭐⭐⭐⭐ — 数学推导严谨,符号清晰,从背景到方法的逻辑链条完整
- 价值: ⭐⭐⭐⭐ — 直接解决 MLA 部署的实际痛点,对 DeepSeek 系列和大规模推理部署有重要实用价值