Nemotron-Flash: Towards Latency-Optimal Hybrid Small Language Models¶
会议: NeurIPS 2025
arXiv: 2511.18890
代码: Hugging Face模型卡
领域: SLM 设计, 延迟优化
关键词: 混合算子, 深宽比, 权重归一化, 进化搜索
一句话总结¶
Nemotron-Flash 通过系统优化深宽比、进化搜索混合算子组合(DeltaNet+Mamba2+Attention)以及权重归一化训练,构建延迟最优的小语言模型家族,相比 Qwen3-1.7B/0.6B 分别实现 1.3×/1.9× 延迟下降与 +5.5% 平均准确率提升。
研究背景与动机¶
现有 SLM 设计主要追求参数高效(parameter-optimal),但参数量减少并不等价于实际设备延迟的等比降低。例如 MobileLLM 和 SmolLM 采用的深薄(deep-thin)架构虽然参数效率高,但在真实 GPU 推理时由于层数多导致延迟反而偏大。此外,面对近年涌现的高效注意力算子(Mamba、DeltaNet、GLA 等),学界对它们在混合模型中的协同效应缺乏系统探索,现有混合模型的算子组合依赖手工经验。
本文提出:SLM 设计应以实际设备延迟为第一优化目标,而非参数量。围绕这一目标,从架构设计和训练策略两个维度提供可泛化的方法论。
方法详解¶
整体框架¶
本文从三个层面改进 SLM 的精度-延迟权衡: 1. 深宽比优化:确定给定延迟预算下的最优模型深度与宽度 2. 混合算子搜索:进化搜索发现互补的算子组合 3. 训练改进:权重归一化 + 可学习 Meta Token
深宽比优化(Depth-Width Ratio)¶
作者训练了一系列 Llama 模型(深度 6/12/18/24/30,每个深度下变化宽度),在 100B token 的 Smollm-corpus 上训练,得到核心发现:
- 深模型参数效率高但延迟差:深薄模型在准确率-参数量曲线上占优,但在准确率-延迟曲线上反而劣势
- 存在最优深宽比:例如延迟预算 3 秒时,深度 12 达到了最优准确率
- 最优深宽比随延迟预算增大而增大:延迟预算越大,允许更深的模型
为了更精确地确定甜点深宽比,作者扩展了已有的缩放律,将模型大小 \(P\) 解耦为深度 \(D\) 和宽度 \(W\):
其中 \(a, b, c\) 控制各维度的贡献权重,\(\alpha, \beta, \gamma\) 控制边际递减速率。实验表明该缩放律可以外推到未见过的深宽配置,PPL 误差在 5.3% 以内。
混合算子搜索(Hybrid Operators)¶
算子评估:在 500M 模型上统一评估了 Mamba、Mamba2、GLA、DeltaNet、Gated DeltaNet、RWKV7、滑动窗口注意力(SWA)等算子。发现 DeltaNet 和 Gated DeltaNet 位于 PPL-延迟 Pareto 前沿。
算子组合探索:将不同算子与 Mamba2 或 Attention 配对构建混合模型: - DeltaNet/Gated DeltaNet + Mamba2 组合效果最佳,PPL 和 CR 准确率均优于纯模型 - 与 Attention 配对的提升不稳定 - 混合模型中不同算子的性能差距缩小(借助互补的记忆机制)
进化搜索框架: - 搜索代理:短训练 PPL 与全训练 PPL 之间 Spearman 相关性达 88.8%,可作为可靠代理 - 搜索空间:以 DeltaNet、Attention、Mamba2 为候选算子,搜索三阶段(早/中/晚期)的算子比例、每种 block 的 FFN 数量和重复次数 - 搜索算法:老化进化搜索(Aging Evolution),通过锦标赛选择 + 单因子变异 + 短训练评估迭代 - 搜索结果:发现延迟友好的架构以 [DeltaNet-FFN-Mamba2-FFN] 和 [Attention-FFN-Mamba2-FFN] 交替堆叠
对比同延迟的纯模型基线,搜索到的混合架构在 PPL 和 CR 准确率上均占优(51.04% vs 次优 DeltaNet 的 50.38%)。
权重归一化(Weight Normalization)¶
观察到标准训练的权重矩阵存在大幅度异常值,导致后期学习率低时相对权重更新变小、学习停滞。受 nGPT 启发,在每次训练迭代后将权重投影到单位范数球面:
- Case-1(作用于隐层特征的矩阵):\(\mathbf{W}_{i,:} \leftarrow \mathbf{W}_{i,:} / \|\mathbf{W}_{i,:}\|_2\)
- Case-2(输出加回隐层的矩阵):\(\mathbf{W}_{:,j} \leftarrow \mathbf{W}_{:,j} / \|\mathbf{W}_{:,j}\|_2\)
效果:跨 Llama、DeltaNet、Mamba2 三个架构家族平均提升 CR 准确率 +1.20%,PPL 降低 0.66。相比 nGPT 的完整方案,省去了激活归一化层带来的 20%+ 训练开销,在最终任务性能上效果相当。
Meta Token¶
在输入序列前拼接 256 个可学习 token,既缓解注意力 sink 问题(对 softmax 注意力),又为线性注意力提供学习到的缓存初始化。一致提升 +0.45% 准确率,开销可忽略。
模型配置¶
| 模型 | 参数量 | 隐藏维度 | Block 数 | 算子数 | 结构 |
|---|---|---|---|---|---|
| Nemotron-Flash-1B | 0.96B | 2048 | 12 | 24 | Block-1 × 4 + Block-2 × 2 交替 |
| Nemotron-Flash-3B | 2.7B | 3072 | 18 | 36 | Block-1 × 6 + Block-2 × 3 交替 |
其中 Block-1 = DeltaNet-FFN-Mamba2-FFN,Block-2 = Attention-FFN-Mamba2-FFN。训练使用 4.5T token,256 张 H100,Adam(无 weight decay),cosine 学习率 1e-3。
实验关键数据¶
Nemotron-Flash vs SOTA Base 模型(H100,decode 8k token,BS=1)¶
| 模型 | 参数 | 深度 | 延迟(s) | 最大BS吞吐(tok/s) | MMLU | CR | Math | Coding | Recall | 平均 |
|---|---|---|---|---|---|---|---|---|---|---|
| Qwen2.5-0.5B | 0.5B | 24 | 22.81 | 2,382 | 47.6 | 47.5 | 32.7 | 32.1 | 65.4 | 45.2% |
| Qwen3-0.6B | 0.6B | 28 | 27.55 | 160 | 52.4 | 48.9 | 36.9 | 24.3 | 62.9 | 44.1% |
| NF-1B | 0.96B | 12 | 14.45 | 7,289 | 44.6 | 54.5 | 34.9 | 37.9 | 67.1 | 49.6% |
| Qwen3-1.7B | 1.7B | 28 | 36.20 | 157 | 62.5 | 57.2 | 53.7 | 43.8 | 66.4 | 55.5% |
| Qwen2.5-3B | 3B | 36 | 49.40 | 459 | 65.6 | 58.9 | 53.8 | 49.5 | 73.0 | 59.0% |
| NF-3B | 2.7B | 18 | 28.71 | 2,939 | 61.2 | 61.0 | 57.6 | 53.3 | 73.3 | 61.0% |
- NF-1B vs Qwen3-0.6B:+5.5% 准确率,1.9× 低延迟,45.6× 高吞吐
- NF-3B vs Qwen2.5-3B:+2.0% 准确率,1.7× 低延迟,6.4× 高吞吐
- NF-3B 仅含 2 层全注意力但 Recall 73.3%,说明全 KV cache 覆盖所有层并非必要
Instruct 模型对比¶
| 模型 | 参数 | 延迟(s) | 吞吐(tok/s) | MMLU | GPQA | GSM8K | IFEval | 平均 |
|---|---|---|---|---|---|---|---|---|
| Qwen2.5-1.5B | 1.5B | 34.50 | 687 | 59.7 | 30.1 | 56.0 | 46.8 | 48.2% |
| Qwen3-1.7B | 1.7B | 36.20 | 157 | 60.2 | 28.3 | 64.9 | 31.3 | 46.2% |
| NF-3B-Inst | 2.7B | 28.71 | 2,939 | 60.3 | 29.5 | 69.5 | 52.0 | 52.8% |
权重归一化消融(1B 模型,100B token)¶
| 模型 | 设置 | Wiki PPL | CR 准确率 |
|---|---|---|---|
| Llama 1B | 无 wnorm | 18.67 | 53.81% |
| Llama 1B | 有 wnorm | 18.03 | 54.85%(+1.04) |
| DeltaNet 1B | 无 wnorm | 18.86 | 53.46% |
| DeltaNet 1B | 有 wnorm | 18.19 | 54.39%(+0.93) |
| Mamba2 1B | 无 wnorm | 18.44 | 53.30% |
| Mamba2 1B | 有 wnorm | 17.88 | 54.71%(+1.41) |
亮点与洞察¶
- 参数效率 ≠ 延迟效率:首次系统量化说明深薄模型在准确率-延迟权衡上劣于合理深宽比的模型
- 扩展缩放律:将缩放律解耦为深度和宽度两个自变量,提供确定甜点深宽比的原则性方法
- 混合算子黄金搭配:DeltaNet-Mamba2 交替搭配在延迟和精度上 Pareto 最优,搜索比手工设计更可靠
- 短训练代理有效:88.8% 的 Spearman 相关性使得进化搜索成本大幅降低
- 权重归一化简洁高效:比 nGPT 省 20% 训练开销,效果几乎相同;跨架构泛化
- 全注意力非必要:NF-3B 仅 3 层全注意力但 Recall 性能最优,支持混合注意力的 KV cache 节约论点
局限性¶
- 搜索空间限于 3 阶段 × 3 种算子,更大搜索空间可能发现更好架构
- 延迟评估仅在 A100/H100 上完成,不同硬件(GPU、TPU、端侧)的最优架构可能不同
- 滑动窗口固定 512,对需要长距离依赖的任务可能不足
- 缩放律拟合质量依赖训练点的覆盖范围,极端深宽比下外推精度存疑
相关工作¶
- SLM 设计:MobileLLM 强调深薄架构的参数效率,MiniCPM 提出缩放律指导 SLM 训练,本文通过延迟度量反驳了直接采用深薄模型的做法
- 高效注意力:Mamba/Mamba2 提出选择性状态空间模型,DeltaNet/GLA 提出线性注意力变体,但单独使用时 recall 能力有限
- 混合模型:Jamba、Hymba 等手动组合 Mamba 和 Attention 层,本文用进化搜索自动化了这一过程
评分¶
⭐⭐⭐⭐⭐ (5/5)