The Curse of Depth in Large Language Models¶
会议: NeurIPS 2025
arXiv: 2502.05795
代码: https://github.com/wenfang-sun/LayerNorm-Scaling
领域: LLM/NLP
关键词: Pre-Layer Normalization, 深度诅咒, 方差控制, LayerNorm Scaling, Transformer
一句话总结¶
揭示 Pre-LN Transformer 中输出方差指数增长导致深层退化为恒等映射的根本原因,提出无参数的 LayerNorm Scaling(LNS)策略——仅在 LayerNorm 后乘以 \(1/\sqrt{\ell}\),将方差从指数增长压缩为多项式增长,在 130M-7B 全规模上稳定改进困惑度 5-8%。
研究背景与动机¶
- 领域现状:近年研究发现现代 LLM(Llama、Mistral、DeepSeek、Qwen)中几乎一半的深层 Transformer 块效率低——移除深层对性能几乎无影响。
- 现有痛点:LLM 训练极其昂贵,大量无效层意味着计算资源严重浪费,但根本原因缺乏系统理论解释。
- 核心矛盾:Pre-LN 虽然解决了训练稳定性问题(vs Post-LN),但引入了新问题——残差连接使输出方差随深度指数增长,LayerNorm 的归一化效果被稀释。
- 本文要解决什么:① 为什么 Pre-LN 的深层无效?② 数学上如何刻画?③ 如何用最简方案修复?
- 切入角度:从方差传播和梯度流的角度分析,发现方差 \(\sigma_{x_L}^2\) 从 \(\Theta(L)\) 到 \(\Theta(\exp(L))\) 指数增长,导致梯度范数 \(\|\partial y_L / \partial x_1\| \leq M\)(常数界),深层沦为恒等映射。
- 核心 idea 一句话:在 LayerNorm 后乘以按层递减的因子 \(1/\sqrt{\ell}\),将方差增长从指数压为多项式,让深层重新有效学习。
方法详解¶
整体框架¶
通过理论分析定位 Pre-LN 的方差指数增长问题,提出 LayerNorm Scaling(LNS)——在每一层 LayerNorm 输出后乘以 \(1/\sqrt{\ell}\) 的确定性缩放因子。
关键设计¶
- 方差增长的理论诊断(Lemma 3.2 + Theorem 3.3):
- 做什么:证明 Pre-LN 输出方差指数增长及其后果
- 核心思路:\(\sigma_{x_\ell}^2 = \sigma_{x_1}^2 \cdot \Theta(\prod_{k=1}^{\ell-1}(1 + 1/\sigma_{x_k}))\),界为 \(\Theta(L) \leq \sigma_{x_L}^2 \leq \Theta(\exp(L))\)。梯度范数 \(\|\partial y_L/\partial x_1\| \leq M\)(常数),深层等效为恒等映射
-
设计动机:通过 LLaMA2-7B 的 Jacobian 矩阵可视化验证——深层呈对角占优,非对角项消失
-
LayerNorm Scaling(LNS):
- 做什么:\(\tilde{h}^{(\ell)} = \text{LayerNorm}(h^{(\ell)}) \times \frac{1}{\sqrt{\ell}}\)
- 核心思路:LNS 将方差增长从指数压缩为多项式 \(\Theta(L) \leq \sigma_{x_L}^2 \leq \Theta(L^{2-\epsilon})\),梯度范数从有界常数变为 \(\omega(1)\)(随深度增长),深层恢复有效学习
-
设计动机:\(1/\sqrt{\ell}\) 而非 \(1/\ell\)——太小会导致初层梯度爆炸,\(\sqrt{\ell}\) 实现亚线性增长的平衡点
-
与 Scaled Initialization 的关系:
- 做什么:分析为什么不能仅靠初始化解决
- 核心思路:Scaled Initialization 仅在初始化时调整权重,但训练过程中方差仍然指数增长。LNS 在训练全程持续控制方差
- 设计动机:实验显示 LNS + Scaled Init 组合反而更差,两者控制方差的机制冲突
损失函数 / 训练策略¶
标准语言模型损失,零额外参数,零超参数。仅需一行代码改动:output * (1 / sqrt(layer_index))。建议使用 LNS 时移除 Scaled Initialization。
实验关键数据¶
主实验(预训练困惑度↓)¶
| 方法 | LLaMA-130M | LLaMA-250M | LLaMA-350M | LLaMA-1B |
|---|---|---|---|---|
| Post-LN | 26.95 | 1409.79 | 1368.33 | 1390.75 |
| DeepNorm | 27.17 | 22.77 | 1362.59 | 1409.08 |
| Mix-LN | 26.07 | 21.39 | 1363.21 | 1414.78 |
| Pre-LN(基线) | 26.73 | 21.92 | 19.58 | 17.02 |
| Pre-LN + LNS | 25.76 | 20.35 | 18.20 | 15.71 |
DeepNorm/Mix-LN 在大规模时发散,LNS 始终稳定。
消融实验(微调下游任务准确率↑)¶
| 方法 | MMLU | BoolQ | ARC-e | PIQA | HellaSwag | 平均 |
|---|---|---|---|---|---|---|
| Pre-LN (250M) | 24.93 | 38.35 | 40.15 | 63.55 | 26.34 | 36.93 |
| LNS (250M) | 27.08 | 58.17 | 45.24 | 67.38 | 32.81 | 43.14 |
| Pre-LN (1B) | 26.54 | 62.20 | 45.70 | 67.79 | 30.96 | 43.01 |
| LNS (1B) | 28.69 | 61.80 | 48.85 | 67.92 | 33.94 | 44.87 |
关键发现¶
- 方差控制效果显著:Pre-LN 深层方差增至 175,LNS 控制在 25 以内(7 倍降幅)
- 深层重新有效:LNS 下层级剪枝性能下降变得均匀,Angular Distance 从接近 0 升至 >0.6
- 规模一致性:60M-7B 全规模趋势一致,OLMo-7B 损失从 2.69→2.50(+7.1%)
- 预训练→微调迁移:预训练收益完全迁移到下游任务
亮点与洞察¶
- 理论深度:从方差→梯度→恒等映射的完整因果链,Theorem 3.3 和 4.2 严格证明了问题和解法
- 极致简洁:一行代码、零参数、零超参数的解决方案,是罕见的"简单到难以拒绝"的改进
- 实用性:所有使用 Pre-LN 的 LLM(几乎所有主流模型)都可直接受益
- Jacobian 可视化:深层的对角占优现象直观揭示了恒等映射行为
局限性 / 可改进方向¶
- \(1/\sqrt{\ell}\) 的选择缺乏严格最优性证明,仅为启发式+实验验证
- 与 Scaled Initialization 的冲突未深入分析
- ViT 中 LNS 的最优位置不同(Attn/MLP 后 vs LayerNorm 后),通用性受限
- 长序列场景下方差与序列长度的关系未讨论
相关工作与启发¶
- vs Mix-LN: 混合 Pre-LN/Post-LN 引入超参 α,且在 350M+ 规模发散。LNS 更简洁稳定
- vs DeepNorm: 调整残差权重来稳定训练,但在 1B 规模发散(PPL 1409)
- vs LayerScale: 每层学习缩放因子,引入可学习参数,实验中反而性能下降
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 首次从方差增长根本视角分析 Pre-LN 深层低效,理论+实践价值兼备
- 实验充分度: ⭐⭐⭐⭐⭐ 多架构(LLaMA/OLMo/Qwen2.5/ViT)×多规模(130M-7B)×完整流程验证
- 写作质量: ⭐⭐⭐⭐ 问题→根因→解决方案→验证的逻辑清晰,Jacobian 可视化直观
- 价值: ⭐⭐⭐⭐⭐ 一行代码改动让所有 Pre-LN 模型受益,应成为 LLM 训练标准实践