跳转至

SAFE: Finding Sparse and Flat Minima to Improve Pruning

会议: ICML2025
arXiv: 2506.06866
代码: JAX & PyTorch(论文附代码)
领域: 模型剪枝 / 模型压缩
关键词: 网络剪枝, 稀疏优化, 平坦极小值, SAM, ADMM, LLM剪枝

一句话总结

将剪枝问题建模为稀疏约束下的锐度感知优化问题,通过增广拉格朗日对偶法(ADMM)求解,同时实现稀疏性和平坦极小值,提升剪枝后网络的泛化性能和鲁棒性。

研究背景与动机

问题:神经网络剪枝后通常伴随不可避免的性能下降,即使经过大量近期工作的努力,恢复原始性能仍然具有挑战性。

关键观察

  • 泛化性能与损失面的平坦度密切相关(Keskar et al., 2017; Jiang et al., 2020)
  • SAM(Sharpness-Aware Minimization)通过显式正则化锐度来提升泛化,在多领域表现出色
  • 现有将锐度感知引入剪枝的工作(CrAM、Na et al.)仅是松散的启发式结合,尚未从优化角度系统地融合

动机:能否将锐度感知最小化和稀疏约束更有机地结合,通过严格的优化框架找到既稀疏又平坦的极小值

方法详解

问题建模

将剪枝形式化为带稀疏约束的 min-max 优化:

\[\min_{\|x\|_0 \leq d} \max_{\|\epsilon\|_2 \leq \rho} f(x + \epsilon)\]
  • 外层最小化:在稀疏约束下寻找最优权重 \(x\)
  • 内层最大化:在 \(\epsilon\)-邻域内找最大损失(即追求平坦性)
  • \(d\): 保留参数数量;\(\rho\): 扰动半径

ADMM 求解框架

变量分裂:引入辅助变量 \(z\),将稀疏约束与目标优化解耦:

\[\min_{x,z} \max_{\|\epsilon\|_2 \leq \rho} f(x+\epsilon) + I_{\|\cdot\|_0 \leq d}(z) \quad \text{s.t. } x = z\]

增广拉格朗日:添加惩罚项 \(\frac{\lambda}{2}\|x-z\|_2^2\),形成三步交替迭代:

  1. \(x\)-更新(锐度感知梯度下降): $\(x_k^{(t+1)} = x_k^{(t)} - \eta^{(t)} \left[ \nabla f\left(x_k^{(t)} + \rho \frac{\nabla f(x_k^{(t)})}{\|\nabla f(x_k^{(t)})\|_2}\right) + \lambda(x_k^{(t)} - z_k + u_k) \right]\)$

  2. \(z\)-更新(硬阈值投影): $\(z_{k+1} = \text{proj}_{\|\cdot\|_0 \leq d}(x_{k+1} + u_k)\)$

  3. \(u\)-更新(对偶变量上升): $\(u_{k+1} = u_k + x_{k+1} - z_{k+1}\)$

Safe⁺:广义投影扩展

引入正定对角矩阵 \(\mathbf{P}\) 替代欧氏距离,可整合多种高级显著性分数:

\(\mathbf{P}\) 选择 对应剪枝方法
\(\mathbf{I}\)(单位矩阵) 幅值剪枝(原始 Safe)
\(\text{diag}(\nabla^2 f(x))\) OBD(最优脑损伤)
\(\text{diag}(\nabla f \cdot \nabla f^\top)\) SNIP(梯度敏感度)
\(\text{diag}(\mathbf{A}^\top \mathbf{A})\) Wanda(激活感知)

收敛保证

论文证明了在标准假设(下界、\(\beta\)-光滑、\(\mu\)-弱凸)下: - \(x\)-更新序列收敛到增广拉格朗日的稳定点(Lemma 3.5) - Safe 整体收敛到稀疏约束优化问题的 \(\delta\)-稳定点(Corollary 3.6)

实用技巧

  • 惩罚参数 \(\lambda\) 使用余弦调度从 0 逐渐增大到目标值,减少训练初期的约束干扰

实验关键数据

图像分类(CIFAR-10/100,训练中剪枝)

在 VGG-19、ResNet-20/32 上,Safe 在绝大多数稀疏度设置下优于 PBW、GMP、LTH、ADMM、MLPrune 等基线。在极端稀疏度(99.5%)下优势更明显。无需额外重训练,仅用 BN 微调。

LLM 后训练剪枝(Perplexity ↓)

模型 稀疏度 SparseGPT Wanda ALPS Safe Safe⁺
LLaMA-2 7B 50% 6.99/9.20 6.92/9.23 6.87/8.98 6.78/8.93 6.56/8.71
LLaMA-2 7B 60% 10.19/12.86 10.75/13.87 9.55/11.24 9.20/11.51 8.30/10.59
LLaMA-2 13B 50% 6.06/8.20 5.98/8.28 5.96/8.09 5.76/7.85 5.67/7.74
LLaMA-3 8B 50% 9.36/13.96 9.71/14.88 9.05/13.40 9.59/14.60 8.62/13.26

Safe⁺ 在所有模型、所有稀疏度(50%/60%/4:8/2:4)上均超越 SOTA 基线。

噪声鲁棒性(ResNet-20 on CIFAR-10)

噪声比 稀疏度 ADMM Safe
25% 70% 77.00 90.58
50% 70% 59.18 86.51
75% 70% 32.62 67.01

Safe 在标签噪声下比 ADMM 高出 +10% ~ +30% 准确率。同时在 CIFAR-10C 常见腐蚀和 PGD 对抗攻击下也表现更优。

亮点与洞察

  1. 优化视角新颖:首次将锐度感知(SAM)与稀疏约束通过增广拉格朗日框架严格结合,而非启发式拼接
  2. 理论有保证:提供了完整的收敛性证明,相比多数剪枝方法仅基于直觉
  3. Safe⁺ 框架统一:广义投影矩阵 \(\mathbf{P}\) 将 OBD、SNIP、Wanda 等经典方法纳入统一框架
  4. 鲁棒性突出:对标签噪声、图像腐蚀、对抗攻击均展现出超越基线的鲁棒性
  5. 效率优势:比 ALPS(同为 ADMM-based)快 2.54 倍;图像分类无需重训练

局限与展望

  1. 图像分类仅在小模型验证:CIFAR-10/100 + VGG/ResNet,未在 ImageNet 或更大规模模型上测试训练中剪枝
  2. LLM 实验局限:Safe⁺ 的 LLM 剪枝依赖逐层重构误差最小化,未端到端优化
  3. 结构化稀疏有限:主要实验为非结构化剪枝(50%/60%)和半结构化(2:4/4:8),未涉及通道级剪枝
  4. \(\lambda\) 调度敏感:惩罚参数调度策略需额外调参,论文中使用余弦调度但未充分分析其他策略
  5. 二阶信息成本:Safe⁺ 使用 Hessian 对角等信息虽效果好但引入额外计算

相关工作与启发

  • SAM(Foret et al., 2021):锐度感知最小化的基础方法
  • ADMM 剪枝(Zhang et al., 2018):ADMM 框架剪枝的先驱,Safe 在此基础上加入平坦性目标
  • CrAM(Peste et al., 2022):压缩感知最小化,启发式结合 SAM 和剪枝
  • Wanda(Sun et al., 2024):激活感知 LLM 剪枝,被 Safe⁺ 统一到广义投影中
  • ALPS(Meng et al., 2024):同为 ADMM-based LLM 剪枝,Safe 在效果和效率上均有优势

评分

  • 新颖性: ⭐⭐⭐⭐ — 优化建模新颖,但各组件(SAM+ADMM)已有
  • 实验充分度: ⭐⭐⭐⭐ — 图像+LLM+鲁棒性全面,但大规模训练中剪枝实验缺乏
  • 写作质量: ⭐⭐⭐⭐⭐ — 推导清晰,理论完整,实验组织有条理
  • 价值: ⭐⭐⭐⭐ — 为剪枝提供了理论更扎实的优化框架,Safe⁺ 的统一视角有启发性

相关论文