跳转至

Nemotron-Flash: Towards Latency-Optimal Hybrid Small Language Models

会议: NeurIPS 2025
arXiv: 2511.18890
代码: Hugging Face模型卡
领域: SLM 设计, 延迟优化
关键词: 混合算子, 深宽比, 权重归一化, 进化搜索

一句话总结

Nemotron-Flash 通过系统优化深宽比、进化搜索混合算子组合(DeltaNet+Mamba2+Attention)以及权重归一化训练,构建延迟最优的小语言模型家族,相比 Qwen3-1.7B/0.6B 分别实现 1.3×/1.9× 延迟下降与 +5.5% 平均准确率提升。

研究背景与动机

现有 SLM 设计主要追求参数高效(parameter-optimal),但参数量减少并不等价于实际设备延迟的等比降低。例如 MobileLLM 和 SmolLM 采用的深薄(deep-thin)架构虽然参数效率高,但在真实 GPU 推理时由于层数多导致延迟反而偏大。此外,面对近年涌现的高效注意力算子(Mamba、DeltaNet、GLA 等),学界对它们在混合模型中的协同效应缺乏系统探索,现有混合模型的算子组合依赖手工经验。

本文提出:SLM 设计应以实际设备延迟为第一优化目标,而非参数量。围绕这一目标,从架构设计和训练策略两个维度提供可泛化的方法论。

方法详解

整体框架

本文从三个层面改进 SLM 的精度-延迟权衡: 1. 深宽比优化:确定给定延迟预算下的最优模型深度与宽度 2. 混合算子搜索:进化搜索发现互补的算子组合 3. 训练改进:权重归一化 + 可学习 Meta Token

深宽比优化(Depth-Width Ratio)

作者训练了一系列 Llama 模型(深度 6/12/18/24/30,每个深度下变化宽度),在 100B token 的 Smollm-corpus 上训练,得到核心发现:

  • 深模型参数效率高但延迟差:深薄模型在准确率-参数量曲线上占优,但在准确率-延迟曲线上反而劣势
  • 存在最优深宽比:例如延迟预算 3 秒时,深度 12 达到了最优准确率
  • 最优深宽比随延迟预算增大而增大:延迟预算越大,允许更深的模型

为了更精确地确定甜点深宽比,作者扩展了已有的缩放律,将模型大小 \(P\) 解耦为深度 \(D\) 和宽度 \(W\)

\[\mathcal{L}(D, W, N) = \mathcal{L}_0 + aD^{-\alpha} + bW^{-\beta} + cN^{-\gamma}\]

其中 \(a, b, c\) 控制各维度的贡献权重,\(\alpha, \beta, \gamma\) 控制边际递减速率。实验表明该缩放律可以外推到未见过的深宽配置,PPL 误差在 5.3% 以内。

混合算子搜索(Hybrid Operators)

算子评估:在 500M 模型上统一评估了 Mamba、Mamba2、GLA、DeltaNet、Gated DeltaNet、RWKV7、滑动窗口注意力(SWA)等算子。发现 DeltaNet 和 Gated DeltaNet 位于 PPL-延迟 Pareto 前沿。

算子组合探索:将不同算子与 Mamba2 或 Attention 配对构建混合模型: - DeltaNet/Gated DeltaNet + Mamba2 组合效果最佳,PPL 和 CR 准确率均优于纯模型 - 与 Attention 配对的提升不稳定 - 混合模型中不同算子的性能差距缩小(借助互补的记忆机制)

进化搜索框架: - 搜索代理:短训练 PPL 与全训练 PPL 之间 Spearman 相关性达 88.8%,可作为可靠代理 - 搜索空间:以 DeltaNet、Attention、Mamba2 为候选算子,搜索三阶段(早/中/晚期)的算子比例、每种 block 的 FFN 数量和重复次数 - 搜索算法:老化进化搜索(Aging Evolution),通过锦标赛选择 + 单因子变异 + 短训练评估迭代 - 搜索结果:发现延迟友好的架构以 [DeltaNet-FFN-Mamba2-FFN] 和 [Attention-FFN-Mamba2-FFN] 交替堆叠

对比同延迟的纯模型基线,搜索到的混合架构在 PPL 和 CR 准确率上均占优(51.04% vs 次优 DeltaNet 的 50.38%)。

权重归一化(Weight Normalization)

观察到标准训练的权重矩阵存在大幅度异常值,导致后期学习率低时相对权重更新变小、学习停滞。受 nGPT 启发,在每次训练迭代后将权重投影到单位范数球面:

  • Case-1(作用于隐层特征的矩阵):\(\mathbf{W}_{i,:} \leftarrow \mathbf{W}_{i,:} / \|\mathbf{W}_{i,:}\|_2\)
  • Case-2(输出加回隐层的矩阵):\(\mathbf{W}_{:,j} \leftarrow \mathbf{W}_{:,j} / \|\mathbf{W}_{:,j}\|_2\)

效果:跨 Llama、DeltaNet、Mamba2 三个架构家族平均提升 CR 准确率 +1.20%,PPL 降低 0.66。相比 nGPT 的完整方案,省去了激活归一化层带来的 20%+ 训练开销,在最终任务性能上效果相当。

Meta Token

在输入序列前拼接 256 个可学习 token,既缓解注意力 sink 问题(对 softmax 注意力),又为线性注意力提供学习到的缓存初始化。一致提升 +0.45% 准确率,开销可忽略。

模型配置

模型 参数量 隐藏维度 Block 数 算子数 结构
Nemotron-Flash-1B 0.96B 2048 12 24 Block-1 × 4 + Block-2 × 2 交替
Nemotron-Flash-3B 2.7B 3072 18 36 Block-1 × 6 + Block-2 × 3 交替

其中 Block-1 = DeltaNet-FFN-Mamba2-FFN,Block-2 = Attention-FFN-Mamba2-FFN。训练使用 4.5T token,256 张 H100,Adam(无 weight decay),cosine 学习率 1e-3。

实验关键数据

Nemotron-Flash vs SOTA Base 模型(H100,decode 8k token,BS=1)

模型 参数 深度 延迟(s) 最大BS吞吐(tok/s) MMLU CR Math Coding Recall 平均
Qwen2.5-0.5B 0.5B 24 22.81 2,382 47.6 47.5 32.7 32.1 65.4 45.2%
Qwen3-0.6B 0.6B 28 27.55 160 52.4 48.9 36.9 24.3 62.9 44.1%
NF-1B 0.96B 12 14.45 7,289 44.6 54.5 34.9 37.9 67.1 49.6%
Qwen3-1.7B 1.7B 28 36.20 157 62.5 57.2 53.7 43.8 66.4 55.5%
Qwen2.5-3B 3B 36 49.40 459 65.6 58.9 53.8 49.5 73.0 59.0%
NF-3B 2.7B 18 28.71 2,939 61.2 61.0 57.6 53.3 73.3 61.0%
  • NF-1B vs Qwen3-0.6B:+5.5% 准确率,1.9× 低延迟,45.6× 高吞吐
  • NF-3B vs Qwen2.5-3B:+2.0% 准确率,1.7× 低延迟,6.4× 高吞吐
  • NF-3B 仅含 2 层全注意力但 Recall 73.3%,说明全 KV cache 覆盖所有层并非必要

Instruct 模型对比

模型 参数 延迟(s) 吞吐(tok/s) MMLU GPQA GSM8K IFEval 平均
Qwen2.5-1.5B 1.5B 34.50 687 59.7 30.1 56.0 46.8 48.2%
Qwen3-1.7B 1.7B 36.20 157 60.2 28.3 64.9 31.3 46.2%
NF-3B-Inst 2.7B 28.71 2,939 60.3 29.5 69.5 52.0 52.8%

权重归一化消融(1B 模型,100B token)

模型 设置 Wiki PPL CR 准确率
Llama 1B 无 wnorm 18.67 53.81%
Llama 1B 有 wnorm 18.03 54.85%(+1.04)
DeltaNet 1B 无 wnorm 18.86 53.46%
DeltaNet 1B 有 wnorm 18.19 54.39%(+0.93)
Mamba2 1B 无 wnorm 18.44 53.30%
Mamba2 1B 有 wnorm 17.88 54.71%(+1.41)

亮点与洞察

  1. 参数效率 ≠ 延迟效率:首次系统量化说明深薄模型在准确率-延迟权衡上劣于合理深宽比的模型
  2. 扩展缩放律:将缩放律解耦为深度和宽度两个自变量,提供确定甜点深宽比的原则性方法
  3. 混合算子黄金搭配:DeltaNet-Mamba2 交替搭配在延迟和精度上 Pareto 最优,搜索比手工设计更可靠
  4. 短训练代理有效:88.8% 的 Spearman 相关性使得进化搜索成本大幅降低
  5. 权重归一化简洁高效:比 nGPT 省 20% 训练开销,效果几乎相同;跨架构泛化
  6. 全注意力非必要:NF-3B 仅 3 层全注意力但 Recall 性能最优,支持混合注意力的 KV cache 节约论点

局限性

  • 搜索空间限于 3 阶段 × 3 种算子,更大搜索空间可能发现更好架构
  • 延迟评估仅在 A100/H100 上完成,不同硬件(GPU、TPU、端侧)的最优架构可能不同
  • 滑动窗口固定 512,对需要长距离依赖的任务可能不足
  • 缩放律拟合质量依赖训练点的覆盖范围,极端深宽比下外推精度存疑

相关工作

  • SLM 设计:MobileLLM 强调深薄架构的参数效率,MiniCPM 提出缩放律指导 SLM 训练,本文通过延迟度量反驳了直接采用深薄模型的做法
  • 高效注意力:Mamba/Mamba2 提出选择性状态空间模型,DeltaNet/GLA 提出线性注意力变体,但单独使用时 recall 能力有限
  • 混合模型:Jamba、Hymba 等手动组合 Mamba 和 Attention 层,本文用进化搜索自动化了这一过程

评分

⭐⭐⭐⭐⭐ (5/5)