跳转至

The Curse of Depth in Large Language Models

会议: NeurIPS 2025
arXiv: 2502.05795
代码: https://github.com/wenfang-sun/LayerNorm-Scaling
领域: LLM/NLP
关键词: Pre-Layer Normalization, 深度诅咒, 方差控制, LayerNorm Scaling, Transformer

一句话总结

揭示 Pre-LN Transformer 中输出方差指数增长导致深层退化为恒等映射的根本原因,提出无参数的 LayerNorm Scaling(LNS)策略——仅在 LayerNorm 后乘以 \(1/\sqrt{\ell}\),将方差从指数增长压缩为多项式增长,在 130M-7B 全规模上稳定改进困惑度 5-8%。

研究背景与动机

  1. 领域现状:近年研究发现现代 LLM(Llama、Mistral、DeepSeek、Qwen)中几乎一半的深层 Transformer 块效率低——移除深层对性能几乎无影响。
  2. 现有痛点:LLM 训练极其昂贵,大量无效层意味着计算资源严重浪费,但根本原因缺乏系统理论解释。
  3. 核心矛盾:Pre-LN 虽然解决了训练稳定性问题(vs Post-LN),但引入了新问题——残差连接使输出方差随深度指数增长,LayerNorm 的归一化效果被稀释。
  4. 本文要解决什么:① 为什么 Pre-LN 的深层无效?② 数学上如何刻画?③ 如何用最简方案修复?
  5. 切入角度:从方差传播和梯度流的角度分析,发现方差 \(\sigma_{x_L}^2\)\(\Theta(L)\)\(\Theta(\exp(L))\) 指数增长,导致梯度范数 \(\|\partial y_L / \partial x_1\| \leq M\)(常数界),深层沦为恒等映射。
  6. 核心 idea 一句话:在 LayerNorm 后乘以按层递减的因子 \(1/\sqrt{\ell}\),将方差增长从指数压为多项式,让深层重新有效学习。

方法详解

整体框架

通过理论分析定位 Pre-LN 的方差指数增长问题,提出 LayerNorm Scaling(LNS)——在每一层 LayerNorm 输出后乘以 \(1/\sqrt{\ell}\) 的确定性缩放因子。

关键设计

  1. 方差增长的理论诊断(Lemma 3.2 + Theorem 3.3):
  2. 做什么:证明 Pre-LN 输出方差指数增长及其后果
  3. 核心思路:\(\sigma_{x_\ell}^2 = \sigma_{x_1}^2 \cdot \Theta(\prod_{k=1}^{\ell-1}(1 + 1/\sigma_{x_k}))\),界为 \(\Theta(L) \leq \sigma_{x_L}^2 \leq \Theta(\exp(L))\)。梯度范数 \(\|\partial y_L/\partial x_1\| \leq M\)(常数),深层等效为恒等映射
  4. 设计动机:通过 LLaMA2-7B 的 Jacobian 矩阵可视化验证——深层呈对角占优,非对角项消失

  5. LayerNorm Scaling(LNS):

  6. 做什么:\(\tilde{h}^{(\ell)} = \text{LayerNorm}(h^{(\ell)}) \times \frac{1}{\sqrt{\ell}}\)
  7. 核心思路:LNS 将方差增长从指数压缩为多项式 \(\Theta(L) \leq \sigma_{x_L}^2 \leq \Theta(L^{2-\epsilon})\),梯度范数从有界常数变为 \(\omega(1)\)(随深度增长),深层恢复有效学习
  8. 设计动机:\(1/\sqrt{\ell}\) 而非 \(1/\ell\)——太小会导致初层梯度爆炸,\(\sqrt{\ell}\) 实现亚线性增长的平衡点

  9. 与 Scaled Initialization 的关系:

  10. 做什么:分析为什么不能仅靠初始化解决
  11. 核心思路:Scaled Initialization 仅在初始化时调整权重,但训练过程中方差仍然指数增长。LNS 在训练全程持续控制方差
  12. 设计动机:实验显示 LNS + Scaled Init 组合反而更差,两者控制方差的机制冲突

损失函数 / 训练策略

标准语言模型损失,零额外参数,零超参数。仅需一行代码改动:output * (1 / sqrt(layer_index))。建议使用 LNS 时移除 Scaled Initialization。

实验关键数据

主实验(预训练困惑度↓)

方法 LLaMA-130M LLaMA-250M LLaMA-350M LLaMA-1B
Post-LN 26.95 1409.79 1368.33 1390.75
DeepNorm 27.17 22.77 1362.59 1409.08
Mix-LN 26.07 21.39 1363.21 1414.78
Pre-LN(基线) 26.73 21.92 19.58 17.02
Pre-LN + LNS 25.76 20.35 18.20 15.71

DeepNorm/Mix-LN 在大规模时发散,LNS 始终稳定。

消融实验(微调下游任务准确率↑)

方法 MMLU BoolQ ARC-e PIQA HellaSwag 平均
Pre-LN (250M) 24.93 38.35 40.15 63.55 26.34 36.93
LNS (250M) 27.08 58.17 45.24 67.38 32.81 43.14
Pre-LN (1B) 26.54 62.20 45.70 67.79 30.96 43.01
LNS (1B) 28.69 61.80 48.85 67.92 33.94 44.87

关键发现

  • 方差控制效果显著:Pre-LN 深层方差增至 175,LNS 控制在 25 以内(7 倍降幅)
  • 深层重新有效:LNS 下层级剪枝性能下降变得均匀,Angular Distance 从接近 0 升至 >0.6
  • 规模一致性:60M-7B 全规模趋势一致,OLMo-7B 损失从 2.69→2.50(+7.1%)
  • 预训练→微调迁移:预训练收益完全迁移到下游任务

亮点与洞察

  • 理论深度:从方差→梯度→恒等映射的完整因果链,Theorem 3.3 和 4.2 严格证明了问题和解法
  • 极致简洁:一行代码、零参数、零超参数的解决方案,是罕见的"简单到难以拒绝"的改进
  • 实用性:所有使用 Pre-LN 的 LLM(几乎所有主流模型)都可直接受益
  • Jacobian 可视化:深层的对角占优现象直观揭示了恒等映射行为

局限性 / 可改进方向

  • \(1/\sqrt{\ell}\) 的选择缺乏严格最优性证明,仅为启发式+实验验证
  • 与 Scaled Initialization 的冲突未深入分析
  • ViT 中 LNS 的最优位置不同(Attn/MLP 后 vs LayerNorm 后),通用性受限
  • 长序列场景下方差与序列长度的关系未讨论

相关工作与启发

  • vs Mix-LN: 混合 Pre-LN/Post-LN 引入超参 α,且在 350M+ 规模发散。LNS 更简洁稳定
  • vs DeepNorm: 调整残差权重来稳定训练,但在 1B 规模发散(PPL 1409)
  • vs LayerScale: 每层学习缩放因子,引入可学习参数,实验中反而性能下降

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 首次从方差增长根本视角分析 Pre-LN 深层低效,理论+实践价值兼备
  • 实验充分度: ⭐⭐⭐⭐⭐ 多架构(LLaMA/OLMo/Qwen2.5/ViT)×多规模(130M-7B)×完整流程验证
  • 写作质量: ⭐⭐⭐⭐ 问题→根因→解决方案→验证的逻辑清晰,Jacobian 可视化直观
  • 价值: ⭐⭐⭐⭐⭐ 一行代码改动让所有 Pre-LN 模型受益,应成为 LLM 训练标准实践