跳转至

Decoder-Hybrid-Decoder Architecture for Efficient Reasoning with Long Generation

会议: NeurIPS 2025
arXiv: 2507.06607
代码: github.com/microsoft/ArchScale
领域: LLM效率
关键词: 混合架构, SSM, 解码效率, 长序列生成, KV缓存

一句话总结

SambaY 提出 Gated Memory Unit(GMU)用于跨层共享 SSM 的 token 混合表示,将 YOCO 的 cross-decoder 中一半的 cross-attention 层替换为轻量级 GMU,在保持线性预填充复杂度和长上下文检索能力的同时,大幅提升解码效率——最终产品 Phi4-mini-Flash-Reasoning (3.8B) 在推理任务上超越 Phi4-mini-Reasoning,且在 2K 提示 + 32K 生成场景下实现高达 10× 的解码吞吐提升。

研究背景与动机

  1. 领域现状:SSM/RNN(如 Mamba)与 Transformer 的混合架构(如 Samba、YOCO)已证明可以在保持性能的同时显著提升推理效率。YOCO 通过只缓存一次 KV 的 decoder-decoder 结构实现线性预填充复杂度。
  2. 现有痛点:YOCO 的 cross-decoder 仍使用完整的 cross-attention 层,在生成阶段(而非预填充阶段)的注意力内存 I/O 代价仍为 \(O(d_{kv} \cdot N)\)。对于现代推理 LLM 生成的超长 CoT(如 32K tokens),这个开销成为新的瓶颈。
  3. 核心矛盾:SSM 层间天然存在可共享的隐藏表示(token mixing 的中间结果),但此前没有工作探索如何利用这种跨层表示共享来减少解码时的内存 I/O。
  4. 本文要解决什么? 如何在不损失长上下文检索能力的前提下,减少 cross-decoder 中注意力层的内存 I/O 开销?
  5. 切入角度:设计 GMU(Gated Memory Unit)——一种简单的门控机制,让 cross-decoder 的部分层直接复用 self-decoder 中最后一个 SSM 层的 token mixing 输出 \(M^{(l')}\),通过当前层输入的门控信号进行通道级别的重新校准。
  6. 核心idea一句话:用 GMU 替换一半 cross-attention 层,将解码时的内存 I/O 从线性 \(O(d_{kv} \cdot N)\) 降为常数 \(O(d_h)\),同时通过门控实现对先前 token mixing 的细粒度重新加权。

方法详解

整体框架

SambaY 采用 decoder-hybrid-decoder 架构:(1) Self-decoder(前半层)使用 Samba(Mamba + SWA 交替),最后有一个 full attention 层生成 KV cache;(2) Cross-decoder(后半层)交替使用 cross-attention(复用 KV cache)和 GMU(复用 SSM 的 token mixing 输出)。预填充阶段只需推理 self-decoder(线性复杂度),解码阶段 GMU 层只需 \(O(d_h)\) 的常数内存 I/O。

关键设计

  1. Gated Memory Unit (GMU):
  2. 做什么:用轻量级门控机制替代 cross-attention,实现跨层表示共享
  3. 核心思路:\(Y^l = (M^{(l')} \odot \sigma(X^l W_1^T)) W_2\),其中 \(M^{(l')}\) 是先前 SSM 层的 token mixing 输出,\(X^l\) 是当前层输入,\(\sigma\) 是 SiLU 激活。门控信号将 2D 的 token mixing 矩阵 \(A^{(l')}\) 提升为 3D 张量 \(\tilde{A}_{ijk} = G_{ik}^{(l)} A_{ij}^{(l')}\),实现通道级别的 token mixing 重新加权
  4. 设计动机:解码时 SSM 只需维护一个常数大小的状态 \(\mathbf{m} \in \mathbb{R}^{d_h \times d_h}\),避免了 cross-attention 对整个 KV cache 的线性 I/O。参数量和计算量都远小于标准注意力层

  5. μP++ 超参数缩放法则:

  6. 做什么:为深度和宽度同时缩放提供统一的超参数转移方案
  7. 核心思路:集成 μP(宽度缩放)+ Depth-μP(深度缩放: 学习率 \(\eta \propto 1/d\),残差分支输出除以 \(\sqrt{2d}\))+ 对 vector-like 参数使用零权重衰减
  8. 设计动机:标准 μP 在大规模训练(600B tokens)时会出现训练不稳定(NaN loss),零权重衰减解决了这个问题

  9. 等参数方程(Iso-Parametric Equation):

  10. 做什么:公平比较不同架构的缩放行为
  11. 核心思路:通过建立不同架构的参数数量方程,求解各架构对应的 aspect ratio \(\alpha\),使得不同架构在相同参数量下拥有相同的深度,从而 KV cache 大小也一致
  12. 设计动机:直接调整深度来匹配参数量会改变 KV cache 大小,导致推理时间比较不公平

损失函数 / 训练策略

标准语言模型交叉熵损失。Phi4-mini-Flash 在 5T token 上预训练,使用标准参数化(受限于资源未用 μP++)。训练中遇到严重的 loss divergence,通过 FP32 上转和 attention dropout 缓解。推理模型通过 SFT + DPO 蒸馏(无 RL)。

实验关键数据

主实验

推理性能(Phi4-mini-Flash-Reasoning vs Phi4-mini-Reasoning,均为 3.8B)

基准 Phi4-mini-Reasoning Phi4-mini-Flash-Reasoning
AIME24 (Pass@1, avg 64) 48.13 52.29
AIME25 (Pass@1, avg 64) 31.77 33.59
Math500 (Pass@1, avg 8) 91.20 92.45
GPQA Diamond (Pass@1, avg 8) 44.51 45.08
解码吞吐(2K prompt + 32K gen) ~10×

缩放实验(FLOPs scaling, 1B-3.4B 模型)

架构 不可约损失 \(C\) 说明
Transformer++ (μP++) 0.64 最高不可约损失
Samba+YOCO (μP++) 0.60 次优
SambaY (μP++) 0.58 最低不可约损失,缩放潜力最大

消融实验

配置 PB-32K 短任务平均 说明
SambaY (SWA=256) 78.13 52.16 最佳平衡
MambaY (无SWA) 12.50 51.87 无局部注意力导致检索崩溃
SambaY-2 (Mamba→Mamba-2) 40.63 51.00 标量遗忘门损失位置信息
SambaY-MLP (GMU门控MLP表示) 64.84 52.65 短任务最优但长上下文弱
SambaY-AA (全GMU无cross-attn) 46.88 52.06 cross-attention 不可或缺

关键发现

  • Phi4-mini-Flash-Reasoning 不用 RL 就超越了有 RL 阶段的 Phi4-mini-Reasoning,说明架构改进可以补偿训练流程的差异
  • SambaY 的不可约损失(0.58)显著低于 Transformer++(0.64),意味着在无限计算资源下 SambaY 能达到更低的损失
  • 小 SWA 窗口(256)就足以让 SambaY 达到良好的长上下文检索能力,不需要大窗口
  • 无位置编码(NoPE)的混合架构可以零样本外推到 2× 训练长度
  • GMU 门控 SSM 表示在长上下文检索上显著优于门控注意力或 MLP 表示

亮点与洞察

  • GMU 的数学解释很优雅:门控操作等价于将 2D token mixing 矩阵提升为 3D 张量 \(\tilde{A}_{ijk}\),这为理解跨层表示共享提供了清晰的理论框架
  • "不需要 RL 就超越 RL 模型" 这一结果非常有说服力:说明高效架构带来的推理质量提升可能被低估了——更快的推理 → 可以在相同时间内生成更长/更多 CoT → 间接提升推理质量
  • μP++ 的工程贡献显著:解决了实际大规模训练中的不稳定问题,对社区有直接实用价值

局限性 / 可改进方向

  • 仍保留一个 full attention 层,解码复杂度仍为线性而非常数
  • 在 Phi4-mini-Flash 的大规模训练中遇到严重的 loss divergence,目前的解决方案(label smoothing + attention dropout)是启发式的
  • 未与 RL 训练的推理模型进行公平比较(Flash-Reasoning 只用了 SFT+DPO)
  • 未对优化超参数进行针对性搜索,可能还有提升空间
  • Differential Attention 的 vLLM 实现使用朴素的四次 FlashAttention 调用,效率未充分优化

相关工作与启发

  • vs YOCO: SambaY 在 YOCO 的 cross-decoder 中用 GMU 替换一半 cross-attention,将解码时的内存 I/O 从线性降为常数。缩放实验显示 SambaY 的不可约损失更低
  • vs Samba: Samba 只是 Mamba+SWA 的交替排列,SambaY 在此基础上加入 YOCO 的 decoder-decoder 结构和 GMU 共享机制
  • vs CLA (Cross-Layer Attention): CLA 跨层共享 KV cache,SambaY 共享 SSM 的 token mixing 输出,避免了具象化循环状态

评分

  • 新颖性: ⭐⭐⭐⭐ GMU 设计优雅,decoder-hybrid-decoder 是 YOCO 的有意义扩展
  • 实验充分度: ⭐⭐⭐⭐⭐ 从 1B 到 3.8B 的缩放实验、长上下文检索、消融、推理基准、吞吐量测试,非常全面
  • 写作质量: ⭐⭐⭐⭐ 架构描述清晰,但部分符号较密集
  • 价值: ⭐⭐⭐⭐⭐ 微软 Phi4 系列的架构骨干,已有实际产品落地,对长 CoT 推理模型的高效部署有重要意义