跳转至

Gradient-Guided Annealing for Domain Generalization

会议: CVPR 2025
arXiv: 2502.20162
代码: https://github.com/aristotelisballas/GGA
领域: 其他
关键词: 域泛化、梯度对齐、模拟退火、域偏移、早期训练

一句话总结

提出GGA方法,在训练早期通过模拟退火搜索参数空间中梯度跨域对齐的点(最小化域间梯度余弦相似度的最小值),引导模型在优化初期找到域不变特征的起始点,从而在无需数据增强的情况下提升域泛化性,可与现有DG方法组合获得显著提升。

研究背景与动机

领域现状:域泛化(DG)旨在训练能泛化到未见域的模型。现有方法包括数据增强、元学习、域对齐等,但大量实验表明,简单的ERM基线搭配好的训练策略就能超过很多DG方法。

现有痛点:训练初期的几步梯度更新对模型最终的泛化能力有决定性影响。域间的梯度冲突(不同域的损失梯度方向不一致)会将优化推向捕获域特定特征而非类特定特征的局部最优。一旦进入这种局部最优,后续训练难以逃出。

核心矛盾:标准SGD优化取所有域损失的平均梯度,当域间梯度方向冲突时,平均梯度指向的方向可能对所有域都不是最优的,导致模型学到域特定而非域不变的特征。

本文目标 在训练早期寻找一个参数空间中"所有域梯度方向一致"的起始点,使后续SGD优化自然趋向域不变特征的学习。

切入角度:借鉴模拟退火的随机搜索策略,在训练初期对参数施加均匀随机扰动,保留那些增加跨域梯度对齐度的扰动,拒绝降低对齐度的扰动。

核心 idea:训练早期用模拟退火搜索梯度跨域对齐的参数点,用该点作为后续标准SGD训练的起始状态,引导模型从一开始就学域不变特征。

方法详解

整体框架

三阶段:(1)预热阶段:标准SGD预热几步让模型有基本的梯度信号;(2)退火阶段:计算各源域的梯度对,对参数施加随机扰动\(\theta' \leftarrow \theta + \mathcal{U}(-\rho, \rho)\),如果新参数的最小域间梯度相似度\(>\theta\)的且损失增量可控(\(< 0.1\)),则接受扰动;迭代多次找到梯度对齐的参数点;(3)标准训练:从对齐点开始正常SGD训练。

关键设计

  1. 梯度对齐度量:

    • 功能:衡量当前参数点的跨域梯度一致性
    • 核心思路:计算所有源域对之间梯度的余弦相似度,取最小值作为整体对齐度:\(\text{grad\_sim} = \min_{i \neq j} \frac{g_i^T \cdot g_j}{\|g_i\| \cdot \|g_j\|}\)。最小值而非平均值确保了最差域对也要对齐,防止某域被忽略
    • 设计动机:如果最小的域对梯度相似度都很高,说明所有域的优化方向基本一致,此时SGD的平均梯度对每个域都有效
  2. 退火搜索策略:

    • 功能:在参数空间中搜索梯度对齐的点
    • 核心思路:对当前参数\(\theta\)加均匀随机扰动\(\mathcal{U}(-\rho, \rho)\)得到\(\theta'\),接受条件为:(a)\(\text{grad\_sim}(\theta') > \text{grad\_sim}(\theta)\)(对齐度增加)且(b)\(\mathcal{L}(\theta') - \mathcal{L}(\theta) < 0.1\)(损失增量可控)。不要求损失下降,允许暂时提高损失以换取更好的对齐
    • 设计动机:类似模拟退火允许暂时的"上坡"寻找全局最优,GGA允许损失暂时增加以找到梯度对齐的参数区域
  3. GGA-L轻量变体:

    • 功能:降低退火搜索的计算开销
    • 核心思路:借鉴SGLD(随机梯度Langevin动力学),在梯度更新时加入噪声来隐式搜索对齐点,避免显式的多次扰动-评估循环
    • 设计动机:完整GGA需要多次计算所有域的梯度,GGA-L通过噪声注入近似实现退火效果

损失函数 / 训练策略

退火阶段不改变损失函数,仅对参数施加随机扰动。退火窗口\([A_s, A_e]\)通常在前10-20%的训练迭代。超参数:\(\rho\)(扰动幅度)、\(n_a\)(每步退火迭代次数)。退火后切换回标准ERM/任何DG算法的训练。

实验关键数据

主实验

数据集 GGA (+ERM) ERM 其他SOTA 说明
PACS 提升显著 基线 竞争性 无需数据增强
VLCS 提升显著 基线 竞争性
OfficeHome 提升显著 基线 竞争性
TerraInc 提升显著 基线 竞争性

GGA的核心价值:可以即插即用地与任何DG方法组合,为其提供更好的初始化。

关键发现

  • 训练早期的域间梯度方向确实存在显著冲突,GGA能有效减少这种冲突
  • GGA不改变训练目标,仅改变出发点——所以能与任何DG方法正交组合
  • 损失容忍机制很关键——如果要求损失也下降,退火效果大幅减弱
  • GGA-L以更低计算成本实现了接近GGA的效果

亮点与洞察

  • "起点决定终点"的深刻洞察:训练初期几步的梯度方向决定了模型最终学到域不变还是域特定特征。这个发现本身就有重大理论意义
  • 方法的正交性:GGA只修改训练起始点,不修改损失/架构/增强策略,因此可以与任何DG方法叠加使用——这种"预处理"性质的方法实用性极高
  • 模拟退火在DL中的优雅应用:将经典优化方法用于解决梯度冲突的现代问题,简洁有效

局限与展望

  • 退火阶段需要对每个源域分别计算梯度,源域数量多时开销增加
  • 扰动幅度\(\rho\)和退火窗口\([A_s, A_e]\)需要调节
  • 目前只在分类DG上验证,回归/检测等DG任务未涉及
  • 理论分析不够充分——为什么梯度对齐的起点能保证收敛到域不变的解缺少严格证明

相关工作与启发

  • vs FISH/AND-Mask等梯度方法: 这些方法在整个训练过程中修改梯度方向;GGA仅在早期寻找好的起点,之后正常训练,更简洁
  • vs SAM(锐度感知最小化): SAM搜索平坦最优;GGA搜索梯度对齐区域,目标不同但搜索策略有相似性
  • vs 随机权重平均(SWA): SWA在训练后期平均参数;GGA在训练前期优化起始参数,时间窗口互补

评分

  • 新颖性: ⭐⭐⭐⭐ 梯度对齐搜索的思路新颖,模拟退火在DG中的应用独特
  • 实验充分度: ⭐⭐⭐ 五个标准DG基准,但具体数字需要更详细的报告
  • 写作质量: ⭐⭐⭐⭐ 理论动机分析深入,问题定义清晰
  • 价值: ⭐⭐⭐⭐ 即插即用的DG预处理方法,实用性强

相关论文