Compositional amortized inference for large-scale hierarchical Bayesian models¶
会议: ICLR2026
arXiv: 2505.14429
代码: 待确认
领域: 优化/理论
关键词: amortized Bayesian inference, hierarchical model, compositional score matching, diffusion model, scalability
一句话总结¶
将组合分数匹配(CSM)扩展到层次贝叶斯模型,通过新的误差衰减估计器和 mini-batch 策略解决大量数据组下的数值不稳定问题,首次实现超过 75 万参数(25 万+ 数据组)的大规模层次模型的摊销推断,并在荧光寿命成像的真实科学应用中验证有效性。
研究背景与动机¶
- 领域现状:摊销贝叶斯推断(ABI)用神经网络学习通用后验函数,训练后对新数据零延迟采样。但扩展到层次模型是主要障碍——层次模型的训练需要为每个 batch 模拟完整的"数据集套数据集",计算成本巨大。
- 现有痛点:(a) 直接 ABI 需要模拟 \(J\) 组数据(\(J\) 可达数十万),每个 batch 需要上万次模拟;(b) 现有 CSM 方法在组数 \(J > 100\) 时数值不稳定——组合分数累加导致误差复合化,采样发散;(c) MCMC(如 NUTS)对大规模层次模型也不可行。
- 核心矛盾:CSM 的分治策略让训练只需单组数据模拟(高效),但推断时需要组合 \(J\) 个分数估计(\(J\) 大时数值爆炸)。数据组越多→分数累加项越多→近似误差复合越严重。
- 本文要解决什么? 让 CSM 在数十万组数据规模下保持数值稳定,实现大规模层次模型的摊销推断。
- 切入角度:引入误差衰减桥接密度——在高噪声区间降低组合分数的影响力,在低噪声区间恢复完整累加,同时用 mini-batch 估计器解决内存问题。
- 核心idea一句话:通过时变衰减函数 \(d(t)\) 调制组合分数在扩散轨迹中的累积速度——高噪声时衰减(防发散)、低噪声时恢复(保正确性)。
方法详解¶
整体框架¶
层次模型:\(\mathbf{Y}_j \sim p(\mathbf{Y}_j | \boldsymbol{\theta}_j, \boldsymbol{\eta})\), \(\boldsymbol{\theta}_j \sim p(\boldsymbol{\theta} | \boldsymbol{\eta})\), \(\boldsymbol{\eta} \sim p(\boldsymbol{\eta})\)。训练两个分数网络:局部 \(s^{\text{local}}(\boldsymbol{\theta}_t, \boldsymbol{\eta}, \mathbf{Y}_j)\) 和全局 \(s^{\text{global}}(\boldsymbol{\eta}_t, \mathbf{Y}_j)\)。训练只需单组数据模拟。推断时用误差衰减 CSM 组合全局分数→采样 \(\boldsymbol{\eta}\)→条件采样各组 \(\boldsymbol{\theta}_j\)。
关键设计¶
- SDE 采样器替代 Langevin 采样:
- 做什么:用自适应步长 SDE 求解器替代固定步长的退火 Langevin 采样
- 核心思路:利用反向 SDE 公式 + 自适应求解器自动调整步长
-
设计动机:Langevin 采样需要大量步数且对步长敏感;自适应求解器在高噪声区自动缩小步长、低噪声区加大步长
-
误差衰减桥接密度:
- 做什么:引入时变衰减函数 \(d(t)\) 调制组合分数
- 公式:\(p_t(\boldsymbol{\eta}_t | \mathbf{Y}_{1:J}) \propto p(\boldsymbol{\eta}_t)^{(1-J)(1-t)d(t)} \prod_j p_t(\boldsymbol{\eta}_t | \mathbf{Y}_j)^{d(t)}\)
- 约束:\(d(0)=1\)(低噪声时恢复真实后验),\(d(1) \leq 1\)(高噪声时衰减)
- 衰减调度:指数衰减 \(d(t) = \exp(-\ln(1/d_1) \cdot t)\),\(d_1\) 为可调超参数
-
设计动机:观察到自适应求解器在高噪声区需要极小步长→分数误差在高噪声区最严重→衰减高噪声区的组合分数贡献
-
Mini-batch 组合分数估计器:
- 做什么:用随机子集替代全部 \(J\) 组的分数累加
- 公式:\(\hat{s}(\boldsymbol{\eta}_t) = (1-J)(1-t)\nabla \log p(\boldsymbol{\eta}_t) + \frac{J}{M}\sum_{i=1}^M s(\boldsymbol{\eta}_t, \mathbf{Y}_{j_i})\)
- 性质:无偏估计器(Proposition 3.1 证明)
-
设计动机:\(J > 10000\) 时全量累加的内存和计算不可行。Mini-batch 引入方差但与衰减结合后方差被控制
-
噪声调度调整:
- 做什么:推断时使用与训练不同的噪声调度,压缩高噪声区间
- 核心思路:增大 cosine 调度的 shift 参数 \(s\),减少在高噪声区的采样步数
- 设计动机:高噪声区是误差累积最严重的区间,减少在此区间的停留时间
损失函数 / 训练策略¶
联合训练全局和局部分数模型(Eq. 11),使用去噪分数匹配 + likelihood weighting。训练时只需模拟单组数据→simulation 效率极高。
实验关键数据¶
主实验(不同方法的收敛性)¶
| 方法 | N=10 | N=100 | N=10K | N=100K |
|---|---|---|---|---|
| Annealed Langevin | ✓ | ✗ | ✗ | ✗ |
| Euler-Maruyama | ✓ | ✗ | ✗ | ✗ |
| Probability ODE | ✓ | ✓ | ✗ | ✗ |
| GAUSS | ✓ | ✓ | ✗ | ✗ |
| Ours (damping) | ✓ | ✓ | ✓ | ✓ |
消融实验(层次 AR 模型)¶
| 配置 | 全局参数 RMSE | 局部参数 RMSE | 说明 |
|---|---|---|---|
| 直接 ABI(小规模) | 最优 | 最优 | 需要完整模拟 |
| CSM(无衰减) | 发散 | - | \(J > 100\) 失败 |
| CSM + 衰减 | 接近直接 ABI | 接近直接 ABI | 模拟效率 <<1 次完整模拟 |
关键发现¶
- 现有 CSM 方法全部在 \(N > 100\) 时失败:Langevin、Euler-Maruyama、ODE、GAUSS 无一能处理超过 10K 数据点
- 误差衰减是解锁大规模的关键:有衰减后 100K 数据点也能稳定收敛
- 真实应用验证:荧光寿命成像中 \(J > 250,000\) 组、\(> 750,000\) 参数——首次实现如此规模的摊销层次推断
- 训练效率极高:不需要模拟完整层次数据集,只需单组模拟→训练成本不随 \(J\) 增长
亮点与洞察¶
- 分治思想的贝叶斯推断实例:训练时分解为单组→推断时组合——避免了"数据集套数据集"的模拟爆炸
- 时变衰减的优雅设计:\(d(0)=1, d(1)\leq 1\) 确保低噪声时后验正确、高噪声时数值稳定——无偏性+稳定性的精妙平衡
- Mini-batch 无偏性:随机子集估计的无偏性证明简洁,与衰减组合后方差可控
- 首次突破 CSM 的规模瓶颈:从 100 数据点 → 100,000 数据点,三个数量级的突破
局限性 / 可改进方向¶
- 衰减参数 \(d_1\) 需要调整:虽然可在推断时调,但最优值依赖问题规模。自适应 \(d_1\) 选择是改进方向
- 分数网络对每组独立训练:不同组之间的信息共享有限。层间信息传递可能进一步提高效率
- 仅验证了两层层次模型:更深的层次结构(3 层+)需要递归组合,稳定性未验证
- mini-batch 引入方差:虽然无偏但方差随 \(J/M\) 增大。自适应 mini-batch 大小可能有帮助
相关工作与启发¶
- vs Geffner et al. (2023) CSM:原始 CSM 用退火 Langevin 采样,\(N > 10\) 就失败。本文用 SDE + 衰减突破到 \(100K+\)
- vs Linhart et al. (GAUSS):用二阶高斯近似,限制在 100 观测点。本文通过衰减 + mini-batch 实现三个数量级的扩展
- vs 直接 ABI(Habermann/Heinrich):直接 ABI 需要完整层次模拟。本文只需单组模拟,在大 \(J\) 下计算优势巨大
评分¶
- 新颖性: ⭐⭐⭐⭐ 误差衰减桥接密度 + mini-batch 组合估计器的设计新颖且有理论支撑
- 实验充分度: ⭐⭐⭐⭐ 从高斯玩具→层次 AR→真实荧光成像三层验证,但缺少与更多基线的对比
- 写作质量: ⭐⭐⭐⭐ 数学推导严谨,从不稳定性分析→解决方案的叙事清晰
- 价值: ⭐⭐⭐⭐⭐ 首次使层次贝叶斯的摊销推断扩展到实际科学应用规模(75 万参数)