跳转至

Elastic Weight Consolidation Done Right for Continual Learning

会议: CVPR 2026 arXiv: 2603.18596 代码: https://github.com/scarlet0703/EWC-DR 领域: 其他 关键词: 持续学习, 灾难性遗忘, 弹性权重巩固, Fisher信息矩阵, 权重正则化

一句话总结

本文从梯度视角系统分析了 EWC 及其变体在权重重要性估计上的根本缺陷(EWC 的梯度消失和 MAS 的冗余保护),并提出了一个极其简单的 Logits Reversal 操作来修正 Fisher 信息矩阵的计算,在无样例类增量学习和多模态持续指令微调任务上大幅超越原始 EWC 及其所有变体。

研究背景与动机

持续学习(Continual Learning)要求模型按顺序学习多个任务,但神经网络在学习新任务时会灾难性地遗忘旧任务知识。解决这一问题的主流方法之一是权重正则化:评估每个参数对旧任务的重要性,训练新任务时惩罚对重要参数的修改。

EWC(Elastic Weight Consolidation)是这类方法的奠基之作,通过 Fisher 信息矩阵(FIM)来估计参数重要性,广泛应用于图像分类、指令微调、目标检测等场景。然而,EWC 在实际实验中一直表现不佳,已有多项研究指出其 FIM 近似不够准确,但没有人从根本上分析过 EWC 性能差的真正原因

本文的核心洞察是:EWC 的问题不仅仅是"FIM 近似不准",而是存在两个结构性缺陷——梯度消失导致重要参数被低估,以及 MAS 等变体引入的冗余保护导致不相关参数被过度约束。作者提出的修复方案——Logits Reversal——只需要在计算 FIM 时对 logits 取反,即可同时解决这两个问题。

方法详解

整体框架

EWC-DR 沿用标准 EWC 的学习流程:训练完任务 \(t-1\) 后,用训练数据计算参数重要性矩阵 \(\Omega^{t-1}\),学习新任务 \(t\) 时添加正则化损失 \(\mathcal{L}_{reg} = \frac{\lambda}{2} \sum_i \Omega_i^{t-1}(\theta_i^{t-1} - \theta_i^t)^2\)。本文的改进仅在于如何计算 \(\Omega\)

关键设计

  1. 梯度消失分析(EWC 的核心缺陷):
  2. 做什么:分析 EWC 中 FIM 值为何偏低
  3. 核心思路:EWC 的 FIM 计算基于交叉熵损失的梯度平方。对于 FC 层权重 \(w_k\),重要性为 \(\Omega_{w_k}^{EWC} = \mathbb{E}[(p_k - y_k)^2 \cdot (\frac{\partial z_k}{\partial w_k})^2]\)。当模型对正确类别 \(c\) 的预测概率 \(p_c \to 1\) 时,\((p_c - 1) \to 0\),导致对应梯度趋近于零;同时其他类的 \(p_k \to 0\),梯度同样消失
  4. 设计动机:模型在训练结束时通常对训练样本有很高的置信度,这恰恰是计算 FIM 的时刻。因此 EWC 系统性地低估了所有参数的重要性,无法有效保留旧任务知识

  5. 冗余保护分析(MAS 的核心缺陷):

  6. 做什么:分析 MAS 为何过度保护无关参数
  7. 核心思路:MAS 使用 \(\ell_2\) 范数损失代替交叉熵,重要性为 \(\Omega_{w_k}^{MAS} = \frac{|z_k|}{\|\mathbf{z}\|_2} \cdot |\frac{\partial z_k}{\partial w_k}|\)。由于 logits 无界,一个大的负 logit \(z_k\)(对应极低预测概率)会产生很高的重要性分数
  8. 设计动机:这些极端负 logits 对输出概率几乎没有影响,保护这些参数毫无必要,反而限制了模型学习新任务的可塑性

  9. Logits Reversal(本文核心方法):

  10. 做什么:在计算 FIM 时,将 logits \(z_k\) 取反为 \(\tilde{z}_k = -z_k\)
  11. 核心思路:取反后的 softmax 输出为 \(\tilde{p}_k = \frac{e^{-z_k}}{\sum_j e^{-z_j}}\),对应的重要性变为 \(\Omega_{w_k}^{LR} = \mathbb{E}[(y_k - \tilde{p}_k)^2 \cdot (\frac{\partial \tilde{z}_k}{\partial w_k})^2]\)。关键性质是 \(\frac{\partial \tilde{p}_k}{\partial z_k} < 0\):当 \(z_c\) 增大(高置信正确预测)时,\(\tilde{p}_c\) 减小,\((1 - \tilde{p}_c)\) 增大,从而放大正确类别的重要性
  12. 设计动机:LR 同时解决了两个问题——(1) 高置信预测下梯度不再消失,FIM 能准确反映参数重要性;(2) 错误类别的 \(\tilde{p}_k\) 很小,不会产生冗余的保护信号。整个操作只需一行代码改动

损失函数 / 训练策略

训练损失保持标准 EWC 形式不变:\(\mathcal{L}_{total} = \mathcal{L}_{CE} + \frac{\lambda}{2} \sum_i \Omega_i^{LR}(\theta_i^{t-1} - \theta_i^t)^2\)。唯一改变是 \(\Omega\) 的计算方式。

实验关键数据

主实验

数据集 设置 指标 EWC Online EWC MAS EWC-DR 提升(vs EWC)
CIFAR-100 Big-start T=5 \(A_{last}\) 14.61 29.70 35.37 50.23 +35.62
CIFAR-100 Big-start T=5 \(A_{avg}\) 32.82 45.65 48.32 63.75 +30.93
ImageNet-Sub Big-start T=5 \(A_{last}\) 11.44 23.56 21.06 66.18 +54.74
ImageNet-Sub Big-start T=5 \(A_{avg}\) 26.57 46.68 42.59 76.00 +49.43
Tiny-ImageNet Big-start T=5 \(A_{last}\) 9.74 27.02 25.53 38.24 +28.50
MCIT (VCR后) 增量精度 \(A_t\) 42.99 52.59 +9.60

消融实验

配置 关键指标 说明
EWC (原始FIM) FC层重要性矩阵几乎全黑 梯度消失导致所有类别重要性极低
MAS (ℓ2范数) GT类+非GT类均高亮 对class 0和class 4产生冗余保护
EWC-DR (LR) 仅GT类(class 2)高亮 选择性且判别性的重要性估计
MCIT: EWC遗忘率 NLVR2任务90.66% 严重灾难性遗忘
MCIT: EWC-DR遗忘率 NLVR2任务27.48% 遗忘显著降低,可塑性保持

关键发现

  • EWC-DR 在所有 18 个 EFCIL 设置中取得最佳结果,最大提升幅度达 \(A_{last}\) +53.18%、\(A_{avg}\) +55.47%
  • 临界差异(CD)分析证实 EWC-DR 的提升具有统计显著性(CD=1.438,显著性0.05)
  • 多模态持续指令微调中,EWC-DR 在不损失新任务学习能力的同时,大幅降低遗忘率

亮点与洞察

  • 极其优雅的分析框架:从梯度角度统一审视 EWC 家族的缺陷,发现了两个之前被忽略的根本问题
  • 修复方案极度简洁:只需一行代码(logits 取反)就能大幅提升性能,体现了"找对问题比设计复杂方案更重要"
  • 重要性矩阵的可视化分析非常直观:EWC 全黑、MAS 过度高亮、EWC-DR 精确聚焦,一目了然

局限性 / 可改进方向

  • 理论分析聚焦于 FC 层权重,对中间层参数的影响通过反向传播间接作用,缺乏直接分析
  • 仅与 EWC 家族(EWC、Online EWC、SI、MAS)比较,缺少与知识蒸馏、架构扩展等其他 CL 方法类别的系统对比
  • Logits Reversal 的理论最优性没有严格证明,可能存在更优的 logit 变换

相关工作与启发

  • 与 Online EWC 的对比表明,在线累积重要性权重并不能根本解决梯度消失问题
  • MAS 虽然避免了梯度消失,但引入了新问题(冗余保护),说明损失函数的选择需要更加审慎
  • 该工作提示我们:经典方法的性能不佳可能不是"方法不好",而是"实现有bug"——从基本原理重新审视可能找到简单而高效的改进

评分

  • 新颖性: ⭐⭐⭐⭐ 分析视角新颖,但解决方案(取反 logits)的技术贡献偏轻量
  • 实验充分度: ⭐⭐⭐⭐⭐ 三个数据集×三种任务划分×两种设置 + MCIT 实验 + 统计检验
  • 写作质量: ⭐⭐⭐⭐⭐ 逻辑清晰,从分析到方法到实验一气呵成,可视化优秀
  • 价值: ⭐⭐⭐⭐ 对 EWC 研究社区有重要参考价值,方法简单易用