跳转至

IBDR: Promoting Ensemble Diversity with Interactive Bayesian Distributional Robustness

会议: ICML 2025
arXiv: 2506.07247
领域: 贝叶斯推断 / 集成学习 / 模型微调
关键词: 贝叶斯推断, 粒子多样性, 分布鲁棒性, SAM, DPP, LoRA

一句话总结

提出IBDR贝叶斯推断框架,通过在乘积分布空间上引入交互式损失和Wasserstein分布鲁棒性优化,构建兼顾多样性与低锐度的粒子集成,在VTAB-1K上以ViT-B/16实现73.6%平均准确率超越所有基线。

研究背景与动机

领域现状:贝叶斯推断通过多模型粒子采样提供不确定性估计。现有方法包括SGLD、SGHMC、SVGD等MCMC方法和变分推断(VI)方法,但在大规模模型上面临计算和存储开销。

现有痛点:(1) 传统贝叶斯推断中粒子从后验独立采样(\(\theta_{1:K} \sim^{iid} Q\)),缺乏显式交互机制;(2) 独立采样易导致粒子坍缩到同一mode,集成多样性不足;(3) SA-BNN等方法虽引入了尖锐度感知,但仍在独立分布空间上操作,未建模粒子间关系。

核心矛盾:如何在保持每个粒子模型低损失+低锐度的同时,显式促进粒子间的多样性?

切入角度:在乘积分布空间 \(Q^K = Q \odot Q \odot \cdots \odot Q\) 上定义包含交互项的联合损失,通过Wasserstein DRO理论推导出可操作的上界。

方法详解

整体框架

IBDR框架包含三个核心组件:(1) 定义包含分类损失+多样性损失的联合损失 \(\ell(\boldsymbol{\theta};x,y) = \frac{1}{K}\sum_i l(\theta_i;x,y) + \alpha l_{div}(\theta_{1:K};x,y)\);(2) 通过Theorem 4.1推导population loss的上界,连接分布鲁棒性与SAM;(3) 交替优化粒子均值 \(\mu_{1:K}\) 和对偶变量 \(\lambda\)

关键设计

  1. 交互式联合损失设计:

    • 功能:在乘积分布空间上定义包含粒子间交互项的损失函数
    • 核心思路:总损失 \(\ell(\boldsymbol{\theta};x,y) = \frac{1}{K}\sum_{i=1}^K l(\theta_i;x,y) + \alpha \cdot l_{div}(\theta_{1:K};x,y)\),其中 \(l_{div}\) 基于DPP定义为归一化非极大预测的Gram矩阵行列式:\(l_{div} = \det([\tilde{f}_{-y}^i]^T [\tilde{f}_{-y}^i])\),即最大化非真实类预测向量的张成体积
    • 设计动机:粒子在正确类上一致同意(低CE损失),但在"犯错方式"上多样化(高DPP体积),从而促进互补性
  2. 分布鲁棒性泛化界 (Theorem 4.1):

    • 功能:推导乘积分布空间上带交互损失的population loss上界
    • 核心思路:\(\mathcal{L}_\mathcal{D}(Q^K) \leq L\sqrt{\frac{K \cdot D_{KL}(Q,P) + \log(1/\delta)}{2N}} + \min_{\lambda \geq 0}\{\lambda\rho + \mathbb{E}_{\boldsymbol{\theta}\sim Q^K}[\max_{\boldsymbol{\theta}'}{\mathcal{L}_S(\boldsymbol{\theta}') - \lambda c^K(\boldsymbol{\theta},\boldsymbol{\theta}')}]\}\)。该上界将SAM(锐度感知最小化)和分布鲁棒性统一为特例(Corollary 4.2),并扩展到包含多样性交互项的乘积分布空间
    • 设计动机:直接最小化population loss不可行(无法访问 \(\mathcal{D}\)),需要可计算的上界;理论保证同时获得低损失、低锐度、高多样性
  3. 双层对偶优化算法:

    • 功能:基于理论上界设计可实现的优化流程
    • 核心思路:后验 \(Q = \frac{1}{K}\sum_i \mathcal{N}(\mu_i, \sigma^2 I)\)\(\sigma=0.1\)固定),交替:(1) 重参数化 \(\theta_i = \mu_i + \sigma\epsilon_i\) 后做一步梯度上升得到对抗扰动 \(\theta_i'\);(2) 在 \(\theta_i'\) 处计算包含多样性项的损失梯度更新 \(\mu_i\);(3) 投影梯度下降更新 \(\lambda \geq 0\)
    • 设计动机:类似SAM的两步优化但扩展到多粒子交互场景,\(\lambda\)自适应平衡鲁棒性半径与损失

损失函数 / 训练策略

使用LoRA微调ViT-B/16和LLaMA-2,每个粒子独立初始化LoRA模块。超参数:\(K\)个粒子(默认3-5),多样性权重 \(\alpha\),正则化权重 \(\beta\),对偶学习率 \(\alpha_\lambda\)。推理时各粒子独立预测后集成。

实验关键数据

主实验(VTAB-1K, ViT-B/16)

方法 Natural(7) Specialized(4) Structured(8) 平均
LoRA 79.2 84.3 60.2 68.4
SAM 79.5 83.2 53.4 70.5
SA-BNN 80.1 85.6 49.1 68.2
SGLD 78.4 83.6 57.3 68.4
SVGD 79.8 84.6 57.5 70.9
BayesTune 79.5 84.9 57.2 68.5
IBDR 81.5 86.0 60.3 73.6

ECE校准实验

方法 Natural Specialized Structured 平均ECE↓
LoRA 0.19 0.11 0.21 0.17
SAM 0.17 0.14 0.17 0.16
IBDR 0.11 0.10 0.14 0.12

关键发现

  • IBDR在19个VTAB子任务中13个达到最佳,平均准确率73.6%超越次佳SVGD的70.9(+2.7%)
  • Structured类别提升最显著(60.3 vs 次佳SGLD 57.3),说明多样性对复杂结构化任务更关键
  • ECE校准从LoRA的0.17降至0.12,说明IBDR改善了不确定性估计
  • 该框架在LLaMA-2 commonsense reasoning上也展示了提升

亮点与洞察

  • 理论贡献扎实——Theorem 4.1将SAM、DRO、粒子多样性统一为一个框架的不同特例
  • DPP多样性损失的设计直觉清晰:"正确答案上一致,犯错方式上多样"
  • \(w_d(0)\)的零初始化设计妙——确保单粒子退化时等价于标准LoRA
  • 理论与实践紧密对接——从Corollary 4.3直接推导出Algorithm 1

局限与展望

  • 存储开销随粒子数K线性增长,K>5在大模型上不实际
  • \(\sigma=0.1\)固定不学习,可能限制后验的灵活性
  • 多样性损失的计算需要所有粒子同时前向传播,训练时间约为单模型的K倍
  • 仅验证了LoRA微调场景,未探索Adapter/Prompt等其他PEFT方式

相关工作与启发

  • vs SA-BNN (NeurIPS23): SAM+贝叶斯但独立采样。IBDR在此基础上加入粒子交互,VTAB-1K平均准确率68.2→73.6(+5.4%)
  • vs SVGD: 通过核函数隐式推动粒子远离。IBDR通过DPP显式建模预测空间多样性,70.9→73.6
  • vs DeepEnsemble: 简单独立训练集成。IBDR通过交互损失显式促进多样性,67.0→73.6
  • 启示:贝叶斯推断中粒子交互是一个被忽视的维度,DPP多样性+DRO鲁棒性的组合有推广潜力

评分

  • 新颖性: ⭐⭐⭐⭐ 理论框架统一了SAM和DRO,交互式贝叶斯推断概念新颖
  • 实验充分度: ⭐⭐⭐⭐ VTAB-1K全19任务+ECE校准+LLaMA commonsense
  • 写作质量: ⭐⭐⭐⭐ 理论推导清晰,但符号较重
  • 价值: ⭐⭐⭐⭐ 对贝叶斯微调领域有理论和实践贡献

相关论文