SAFE: Finding Sparse and Flat Minima to Improve Pruning¶
会议: ICML2025
arXiv: 2506.06866
代码: JAX & PyTorch(论文附代码)
领域: 模型剪枝 / 模型压缩
关键词: 网络剪枝, 稀疏优化, 平坦极小值, SAM, ADMM, LLM剪枝
一句话总结¶
将剪枝问题建模为稀疏约束下的锐度感知优化问题,通过增广拉格朗日对偶法(ADMM)求解,同时实现稀疏性和平坦极小值,提升剪枝后网络的泛化性能和鲁棒性。
研究背景与动机¶
问题:神经网络剪枝后通常伴随不可避免的性能下降,即使经过大量近期工作的努力,恢复原始性能仍然具有挑战性。
关键观察:
- 泛化性能与损失面的平坦度密切相关(Keskar et al., 2017; Jiang et al., 2020)
- SAM(Sharpness-Aware Minimization)通过显式正则化锐度来提升泛化,在多领域表现出色
- 现有将锐度感知引入剪枝的工作(CrAM、Na et al.)仅是松散的启发式结合,尚未从优化角度系统地融合
动机:能否将锐度感知最小化和稀疏约束更有机地结合,通过严格的优化框架找到既稀疏又平坦的极小值?
方法详解¶
问题建模¶
将剪枝形式化为带稀疏约束的 min-max 优化:
- 外层最小化:在稀疏约束下寻找最优权重 \(x\)
- 内层最大化:在 \(\epsilon\)-邻域内找最大损失(即追求平坦性)
- \(d\): 保留参数数量;\(\rho\): 扰动半径
ADMM 求解框架¶
变量分裂:引入辅助变量 \(z\),将稀疏约束与目标优化解耦:
增广拉格朗日:添加惩罚项 \(\frac{\lambda}{2}\|x-z\|_2^2\),形成三步交替迭代:
-
\(x\)-更新(锐度感知梯度下降): $\(x_k^{(t+1)} = x_k^{(t)} - \eta^{(t)} \left[ \nabla f\left(x_k^{(t)} + \rho \frac{\nabla f(x_k^{(t)})}{\|\nabla f(x_k^{(t)})\|_2}\right) + \lambda(x_k^{(t)} - z_k + u_k) \right]\)$
-
\(z\)-更新(硬阈值投影): $\(z_{k+1} = \text{proj}_{\|\cdot\|_0 \leq d}(x_{k+1} + u_k)\)$
-
\(u\)-更新(对偶变量上升): $\(u_{k+1} = u_k + x_{k+1} - z_{k+1}\)$
Safe⁺:广义投影扩展¶
引入正定对角矩阵 \(\mathbf{P}\) 替代欧氏距离,可整合多种高级显著性分数:
| \(\mathbf{P}\) 选择 | 对应剪枝方法 |
|---|---|
| \(\mathbf{I}\)(单位矩阵) | 幅值剪枝(原始 Safe) |
| \(\text{diag}(\nabla^2 f(x))\) | OBD(最优脑损伤) |
| \(\text{diag}(\nabla f \cdot \nabla f^\top)\) | SNIP(梯度敏感度) |
| \(\text{diag}(\mathbf{A}^\top \mathbf{A})\) | Wanda(激活感知) |
收敛保证¶
论文证明了在标准假设(下界、\(\beta\)-光滑、\(\mu\)-弱凸)下: - \(x\)-更新序列收敛到增广拉格朗日的稳定点(Lemma 3.5) - Safe 整体收敛到稀疏约束优化问题的 \(\delta\)-稳定点(Corollary 3.6)
实用技巧¶
- 惩罚参数 \(\lambda\) 使用余弦调度从 0 逐渐增大到目标值,减少训练初期的约束干扰
实验关键数据¶
图像分类(CIFAR-10/100,训练中剪枝)¶
在 VGG-19、ResNet-20/32 上,Safe 在绝大多数稀疏度设置下优于 PBW、GMP、LTH、ADMM、MLPrune 等基线。在极端稀疏度(99.5%)下优势更明显。无需额外重训练,仅用 BN 微调。
LLM 后训练剪枝(Perplexity ↓)¶
| 模型 | 稀疏度 | SparseGPT | Wanda | ALPS | Safe | Safe⁺ |
|---|---|---|---|---|---|---|
| LLaMA-2 7B | 50% | 6.99/9.20 | 6.92/9.23 | 6.87/8.98 | 6.78/8.93 | 6.56/8.71 |
| LLaMA-2 7B | 60% | 10.19/12.86 | 10.75/13.87 | 9.55/11.24 | 9.20/11.51 | 8.30/10.59 |
| LLaMA-2 13B | 50% | 6.06/8.20 | 5.98/8.28 | 5.96/8.09 | 5.76/7.85 | 5.67/7.74 |
| LLaMA-3 8B | 50% | 9.36/13.96 | 9.71/14.88 | 9.05/13.40 | 9.59/14.60 | 8.62/13.26 |
Safe⁺ 在所有模型、所有稀疏度(50%/60%/4:8/2:4)上均超越 SOTA 基线。
噪声鲁棒性(ResNet-20 on CIFAR-10)¶
| 噪声比 | 稀疏度 | ADMM | Safe |
|---|---|---|---|
| 25% | 70% | 77.00 | 90.58 |
| 50% | 70% | 59.18 | 86.51 |
| 75% | 70% | 32.62 | 67.01 |
Safe 在标签噪声下比 ADMM 高出 +10% ~ +30% 准确率。同时在 CIFAR-10C 常见腐蚀和 PGD 对抗攻击下也表现更优。
亮点与洞察¶
- 优化视角新颖:首次将锐度感知(SAM)与稀疏约束通过增广拉格朗日框架严格结合,而非启发式拼接
- 理论有保证:提供了完整的收敛性证明,相比多数剪枝方法仅基于直觉
- Safe⁺ 框架统一:广义投影矩阵 \(\mathbf{P}\) 将 OBD、SNIP、Wanda 等经典方法纳入统一框架
- 鲁棒性突出:对标签噪声、图像腐蚀、对抗攻击均展现出超越基线的鲁棒性
- 效率优势:比 ALPS(同为 ADMM-based)快 2.54 倍;图像分类无需重训练
局限与展望¶
- 图像分类仅在小模型验证:CIFAR-10/100 + VGG/ResNet,未在 ImageNet 或更大规模模型上测试训练中剪枝
- LLM 实验局限:Safe⁺ 的 LLM 剪枝依赖逐层重构误差最小化,未端到端优化
- 结构化稀疏有限:主要实验为非结构化剪枝(50%/60%)和半结构化(2:4/4:8),未涉及通道级剪枝
- \(\lambda\) 调度敏感:惩罚参数调度策略需额外调参,论文中使用余弦调度但未充分分析其他策略
- 二阶信息成本:Safe⁺ 使用 Hessian 对角等信息虽效果好但引入额外计算
相关工作与启发¶
- SAM(Foret et al., 2021):锐度感知最小化的基础方法
- ADMM 剪枝(Zhang et al., 2018):ADMM 框架剪枝的先驱,Safe 在此基础上加入平坦性目标
- CrAM(Peste et al., 2022):压缩感知最小化,启发式结合 SAM 和剪枝
- Wanda(Sun et al., 2024):激活感知 LLM 剪枝,被 Safe⁺ 统一到广义投影中
- ALPS(Meng et al., 2024):同为 ADMM-based LLM 剪枝,Safe 在效果和效率上均有优势
评分¶
- 新颖性: ⭐⭐⭐⭐ — 优化建模新颖,但各组件(SAM+ADMM)已有
- 实验充分度: ⭐⭐⭐⭐ — 图像+LLM+鲁棒性全面,但大规模训练中剪枝实验缺乏
- 写作质量: ⭐⭐⭐⭐⭐ — 推导清晰,理论完整,实验组织有条理
- 价值: ⭐⭐⭐⭐ — 为剪枝提供了理论更扎实的优化框架,Safe⁺ 的统一视角有启发性
相关论文¶
- [ICML 2025] Random Initialization of Gated Sparse Adapters (RIGSA)
- [ACL 2025] Outlier-Safe Pre-Training for Robust 4-Bit Quantization of Large Language Models
- [ICML 2025] Instruction-Following Pruning for Large Language Models
- [ICML 2025] DLP: Dynamic Layerwise Pruning in Large Language Models
- [ICML 2025] FGFP: A Fractional Gaussian Filter and Pruning for DNN Compression