跳转至

Compositional amortized inference for large-scale hierarchical Bayesian models

会议: ICLR2026
arXiv: 2505.14429
代码: 待确认
领域: 优化/理论
关键词: amortized Bayesian inference, hierarchical model, compositional score matching, diffusion model, scalability

一句话总结

将组合分数匹配(CSM)扩展到层次贝叶斯模型,通过新的误差衰减估计器和 mini-batch 策略解决大量数据组下的数值不稳定问题,首次实现超过 75 万参数(25 万+ 数据组)的大规模层次模型的摊销推断,并在荧光寿命成像的真实科学应用中验证有效性。

研究背景与动机

  1. 领域现状:摊销贝叶斯推断(ABI)用神经网络学习通用后验函数,训练后对新数据零延迟采样。但扩展到层次模型是主要障碍——层次模型的训练需要为每个 batch 模拟完整的"数据集套数据集",计算成本巨大。
  2. 现有痛点:(a) 直接 ABI 需要模拟 \(J\) 组数据(\(J\) 可达数十万),每个 batch 需要上万次模拟;(b) 现有 CSM 方法在组数 \(J > 100\) 时数值不稳定——组合分数累加导致误差复合化,采样发散;(c) MCMC(如 NUTS)对大规模层次模型也不可行。
  3. 核心矛盾:CSM 的分治策略让训练只需单组数据模拟(高效),但推断时需要组合 \(J\) 个分数估计(\(J\) 大时数值爆炸)。数据组越多→分数累加项越多→近似误差复合越严重。
  4. 本文要解决什么? 让 CSM 在数十万组数据规模下保持数值稳定,实现大规模层次模型的摊销推断。
  5. 切入角度:引入误差衰减桥接密度——在高噪声区间降低组合分数的影响力,在低噪声区间恢复完整累加,同时用 mini-batch 估计器解决内存问题。
  6. 核心idea一句话:通过时变衰减函数 \(d(t)\) 调制组合分数在扩散轨迹中的累积速度——高噪声时衰减(防发散)、低噪声时恢复(保正确性)。

方法详解

整体框架

层次模型:\(\mathbf{Y}_j \sim p(\mathbf{Y}_j | \boldsymbol{\theta}_j, \boldsymbol{\eta})\), \(\boldsymbol{\theta}_j \sim p(\boldsymbol{\theta} | \boldsymbol{\eta})\), \(\boldsymbol{\eta} \sim p(\boldsymbol{\eta})\)。训练两个分数网络:局部 \(s^{\text{local}}(\boldsymbol{\theta}_t, \boldsymbol{\eta}, \mathbf{Y}_j)\) 和全局 \(s^{\text{global}}(\boldsymbol{\eta}_t, \mathbf{Y}_j)\)。训练只需单组数据模拟。推断时用误差衰减 CSM 组合全局分数→采样 \(\boldsymbol{\eta}\)→条件采样各组 \(\boldsymbol{\theta}_j\)

关键设计

  1. SDE 采样器替代 Langevin 采样:
  2. 做什么:用自适应步长 SDE 求解器替代固定步长的退火 Langevin 采样
  3. 核心思路:利用反向 SDE 公式 + 自适应求解器自动调整步长
  4. 设计动机:Langevin 采样需要大量步数且对步长敏感;自适应求解器在高噪声区自动缩小步长、低噪声区加大步长

  5. 误差衰减桥接密度:

  6. 做什么:引入时变衰减函数 \(d(t)\) 调制组合分数
  7. 公式:\(p_t(\boldsymbol{\eta}_t | \mathbf{Y}_{1:J}) \propto p(\boldsymbol{\eta}_t)^{(1-J)(1-t)d(t)} \prod_j p_t(\boldsymbol{\eta}_t | \mathbf{Y}_j)^{d(t)}\)
  8. 约束:\(d(0)=1\)(低噪声时恢复真实后验),\(d(1) \leq 1\)(高噪声时衰减)
  9. 衰减调度:指数衰减 \(d(t) = \exp(-\ln(1/d_1) \cdot t)\)\(d_1\) 为可调超参数
  10. 设计动机:观察到自适应求解器在高噪声区需要极小步长→分数误差在高噪声区最严重→衰减高噪声区的组合分数贡献

  11. Mini-batch 组合分数估计器:

  12. 做什么:用随机子集替代全部 \(J\) 组的分数累加
  13. 公式:\(\hat{s}(\boldsymbol{\eta}_t) = (1-J)(1-t)\nabla \log p(\boldsymbol{\eta}_t) + \frac{J}{M}\sum_{i=1}^M s(\boldsymbol{\eta}_t, \mathbf{Y}_{j_i})\)
  14. 性质:无偏估计器(Proposition 3.1 证明)
  15. 设计动机:\(J > 10000\) 时全量累加的内存和计算不可行。Mini-batch 引入方差但与衰减结合后方差被控制

  16. 噪声调度调整:

  17. 做什么:推断时使用与训练不同的噪声调度,压缩高噪声区间
  18. 核心思路:增大 cosine 调度的 shift 参数 \(s\),减少在高噪声区的采样步数
  19. 设计动机:高噪声区是误差累积最严重的区间,减少在此区间的停留时间

损失函数 / 训练策略

联合训练全局和局部分数模型(Eq. 11),使用去噪分数匹配 + likelihood weighting。训练时只需模拟单组数据→simulation 效率极高。

实验关键数据

主实验(不同方法的收敛性)

方法 N=10 N=100 N=10K N=100K
Annealed Langevin
Euler-Maruyama
Probability ODE
GAUSS
Ours (damping)

消融实验(层次 AR 模型)

配置 全局参数 RMSE 局部参数 RMSE 说明
直接 ABI(小规模) 最优 最优 需要完整模拟
CSM(无衰减) 发散 - \(J > 100\) 失败
CSM + 衰减 接近直接 ABI 接近直接 ABI 模拟效率 <<1 次完整模拟

关键发现

  • 现有 CSM 方法全部在 \(N > 100\) 时失败:Langevin、Euler-Maruyama、ODE、GAUSS 无一能处理超过 10K 数据点
  • 误差衰减是解锁大规模的关键:有衰减后 100K 数据点也能稳定收敛
  • 真实应用验证:荧光寿命成像中 \(J > 250,000\) 组、\(> 750,000\) 参数——首次实现如此规模的摊销层次推断
  • 训练效率极高:不需要模拟完整层次数据集,只需单组模拟→训练成本不随 \(J\) 增长

亮点与洞察

  • 分治思想的贝叶斯推断实例:训练时分解为单组→推断时组合——避免了"数据集套数据集"的模拟爆炸
  • 时变衰减的优雅设计\(d(0)=1, d(1)\leq 1\) 确保低噪声时后验正确、高噪声时数值稳定——无偏性+稳定性的精妙平衡
  • Mini-batch 无偏性:随机子集估计的无偏性证明简洁,与衰减组合后方差可控
  • 首次突破 CSM 的规模瓶颈:从 100 数据点 → 100,000 数据点,三个数量级的突破

局限性 / 可改进方向

  • 衰减参数 \(d_1\) 需要调整:虽然可在推断时调,但最优值依赖问题规模。自适应 \(d_1\) 选择是改进方向
  • 分数网络对每组独立训练:不同组之间的信息共享有限。层间信息传递可能进一步提高效率
  • 仅验证了两层层次模型:更深的层次结构(3 层+)需要递归组合,稳定性未验证
  • mini-batch 引入方差:虽然无偏但方差随 \(J/M\) 增大。自适应 mini-batch 大小可能有帮助

相关工作与启发

  • vs Geffner et al. (2023) CSM:原始 CSM 用退火 Langevin 采样,\(N > 10\) 就失败。本文用 SDE + 衰减突破到 \(100K+\)
  • vs Linhart et al. (GAUSS):用二阶高斯近似,限制在 100 观测点。本文通过衰减 + mini-batch 实现三个数量级的扩展
  • vs 直接 ABI(Habermann/Heinrich):直接 ABI 需要完整层次模拟。本文只需单组模拟,在大 \(J\) 下计算优势巨大

评分

  • 新颖性: ⭐⭐⭐⭐ 误差衰减桥接密度 + mini-batch 组合估计器的设计新颖且有理论支撑
  • 实验充分度: ⭐⭐⭐⭐ 从高斯玩具→层次 AR→真实荧光成像三层验证,但缺少与更多基线的对比
  • 写作质量: ⭐⭐⭐⭐ 数学推导严谨,从不稳定性分析→解决方案的叙事清晰
  • 价值: ⭐⭐⭐⭐⭐ 首次使层次贝叶斯的摊销推断扩展到实际科学应用规模(75 万参数)