QKV Projections Require a Fraction of Their Memory¶
会议: ICLR 2026
arXiv: 2506.02939
代码: 无
领域: 模型压缩
关键词: 训练内存压缩, 注意力机制, 矩阵乘法近似, 激活压缩, LLM训练
一句话总结¶
提出 PAMM(Point-Approximate Matrix Multiplication),一种激活压缩技术,通过随机选取少量代表性 token 来近似 QKV 投影层激活,实现高达 512× 压缩率且不影响模型性能。
研究背景与动机¶
LLM 训练中,注意力层的 QKV 投影占用大量内存:输入 \(X\) 需要在前向过程中保存以用于反向传播(计算 \(\nabla W = X^\top \cdot \nabla Z\))。这部分内存可占注意力块总峰值 GPU 内存的 20%。
现有内存优化方法的不足: - 高效注意力(FlashAttention 等):优化缩放点积本身,未涉及线性投影 - 低秩方法(CompAct 等):沿隐藏维度压缩,但序列维度的冗余更大 - 优化器状态压缩:不随 batch size 和序列长度扩展
核心洞察:序列维度存在巨大冗余。训练 batch 中的 token 数量 \(b = BL\)(如 16384)远大于隐藏维度 \(n\)(如 2048),\(\text{rank}(X) \leq n\),理论上仅需 \(n\) 个基向量即可表示 \(X\),压缩比可达 8×。
方法详解¶
整体框架¶
PAMM 分两阶段工作:(1) 前向时将 \(X\) 压缩为少量生成点和辅助信息;(2) 反向时用压缩表示近似计算梯度 \(\nabla W\)。
关键设计¶
-
激活压缩 (Compression Stage):
- 从 \(X \in \mathbb{R}^{b \times n}\) 中随机采样 \(k = r \cdot b\) 行作为生成点 \(C \in \mathbb{R}^{k \times n}\)
- 对每个点 \(A_i\),选择最佳生成点:\(f(i) = \arg\max_j |\text{csim}(A_i, C_j)|\)(Lemma 1)
- 计算缩放系数:\(\tilde{A}_i = \alpha(i, f(i)) \cdot C_{f(i)}\),其中 \(\alpha = \frac{\langle A_i, C_j \rangle}{\|C_j\|_2^2}\)
- 邻域条件:\(\|A_i - \tilde{A}_i\|_2 \leq \varepsilon \|A_i\|_2\),不满足则丢弃
-
近似矩阵乘法 (Approximate Multiplication):
- 不重建完整 \(\tilde{A}\),而是先聚合 \(\tilde{B}_j = \sum_{i:f(i)=j} \alpha_i B_i\)
- 计算 \(\tilde{O} = C^\top \tilde{B}\),维度从 \(b \times n\) 降为 \(k \times n\)
- 引入归一化因子 \(\beta = \frac{b}{b-\eta}\) 保证无偏估计 \(\mathbb{E}[\tilde{O}] = O\)
-
理论保证:
- Lemma 2(\(k\) 的充分条件):\(k > \frac{b}{n_{\min}} \ln(\frac{b}{\delta})\),仅需对数级别的生成点
- 近似误差上界:\(\|O - \tilde{O}\|_F^2 \leq \|B\|_2^2 (\varepsilon^2 \|A_\mathcal{I}\|_F^2 + \|A_{\bar{\mathcal{I}}}\|_F^2)\)
- 实践中 \(\varepsilon \to \infty\)(不使用邻域约束)效果最好
损失函数 / 训练策略¶
- PAMM 仅修改 QKV 投影的反向传播,前向和其他层梯度不受影响
- 与 FlashAttention、梯度检查点、LoRA 完全兼容
- 实验中压缩比 \(r\) 低至 \(1/512\)
- 微调场景中甚至可以用 \(k=1\)(仅一个生成点)
实验关键数据¶
预训练实验(LLaMA on C4)¶
| 模型 | PAMM r | 验证 PPL | QKV 内存 (MB) | 内存减少 |
|---|---|---|---|---|
| LLaMA-60M | 无 PAMM | 31.8 | 432 | - |
| LLaMA-60M | 1/512 | 31.6 | 0.85 | >99% |
| LLaMA-350M | 无 PAMM | 18.7 | 1,296 | - |
| LLaMA-350M | 1/512 | 18.5 | 2.53 | >99% |
| LLaMA-1B | 无 PAMM | 15.1 | 2,592 | - |
| LLaMA-1B | 1/512 | 15.0 | 5.06 | >99% |
微调实验(RoBERTa-base on GLUE)¶
| 方法 | QKV 内存 (MB) | GLUE 平均 | 内存减少 |
|---|---|---|---|
| Full Fine-Tuning | 288 | 86.28 | - |
| PAMM r=1/128 | 6.75 | 86.11 | 97.7% |
| PAMM r=1/256 | 3.37 | 86.18 | 98.8% |
吞吐量分析(LLaMA-1B)¶
| 阶段 | 基线 (tok/s) | PAMM (tok/s) | 吞吐量降低 |
|---|---|---|---|
| 前向 | 247.6K | 235.4K | 4.92% |
| 反向 | 141.9K | 138.3K | 2.53% |
| 总计 | 88.4K | 85.2K | 3.61% |
关键发现¶
- 512× 压缩下 PPL 不降反升(大模型更明显),说明冗余 token 可能影响训练
- 随模型增大,吞吐量损失从 19.7%(60M)降至 2.1%(7B),大模型更实用
- PAMM 在所有 batch size 和序列长度配置下均表现稳定
- 对比 CompAct(沿隐藏维度压缩):PAMM 在高压缩比下性能显著更好
亮点与洞察¶
- 洞察深刻:序列维度冗余远大于隐藏维度冗余,这是高压缩比的根本原因
- 极其简单有效:随机选取生成点就足够,无需复杂聚类
- 理论严谨:Lemma 1/2 提供了算法设计的理论指导
- 与 FlashAttention 等完全正交,可直接叠加使用
- 惊喜发现:高压缩比下 PPL 反而略有改善,暗示正则化效应
局限与展望¶
- 仅应用于 QKV 投影,未探索 FFN 层的激活压缩
- 邻域条件参数 \(\varepsilon\) 的最优设置为 \(\infty\)(即不使用),理论解释不充分
- 额外计算(余弦相似度矩阵 + argmax)对小模型影响较大
- 未在分布式训练(多节点)场景下验证
相关工作与启发¶
- 与 CompAct 的关键区别:PAMM 沿序列维度压缩(冗余更大),CompAct 沿隐藏维度
- 与梯度检查点的关系:互补——梯度检查点减少存储的层数,PAMM 减少每层存储量
- 启示:训练内存优化不应只关注优化器状态和注意力机制,激活内存同样重要
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 发现序列维度冗余的新方向,方法极简高效
- 实验充分度: ⭐⭐⭐⭐⭐ 预训练/微调/吞吐量/消融全面覆盖
- 写作质量: ⭐⭐⭐⭐⭐ 理论和实验结合好,图示清晰
- 价值: ⭐⭐⭐⭐⭐ 实际可用于 LLM 训练的内存优化工具
相关论文¶
- [ICLR 2026] LightMem: Lightweight and Efficient Memory-Augmented Generation
- [ICCV 2025] A Good Teacher Adapts Their Knowledge for Distillation
- [ACL 2026] Mem^p: Exploring Agent Procedural Memory
- [AAAI 2026] Sharp Eyes and Memory for VideoLLMs: Information-Aware Visual Token Pruning for Efficient and Reliable VideoLLM Reasoning
- [NeurIPS 2025] Beyond Higher Rank: Token-wise Input-Output Projections for Efficient Low-Rank Adaptation