Ambiguity-aware Truncated Flow Matching for Ambiguous Medical Image Segmentation¶
会议: AAAI 2026
arXiv: 2511.06857v2
代码: https://github.com/PerceptionComputingLab/ATFM (有)
领域: medical_imaging
关键词: 模糊医学图像分割, 截断扩散模型, Flow Matching, 高斯截断表示, 语义一致性
一句话总结¶
提出 ATFM 框架,通过数据层级推理范式将预测精度和多样性解耦到分布级和样本级分别优化,结合高斯截断表示(GTR)和分割流匹配(SFM)两个模块,在模糊医学图像分割任务中同时提升预测的精度、保真度和多样性。
背景与动机¶
模糊医学图像分割(AMIS)旨在为一张医学图像生成多个合理的分割预测,以反映标注者之间的固有歧义。临床上,高多样性反映图像的固有模糊性,高精度则支撑可靠的诊断决策,两者缺一不可。
现有方法面临精度与多样性之间的固有权衡: - 随机方法(Prob. U-Net、PHiSeg):以牺牲精度换取多样性,导致低置信度诊断 - 多标注者感知方法:通过建模标注者风格提升两者,但会抑制低频模式,降低分割质量 - cVAE/扩散方法:注入随机性增强多样性,但单阶段推理使精度与多样性优化耦合
截断扩散概率模型(TDPMs)通过推理范式转变展现了潜力——在截断点 \(T_{\text{trunc}}\) 处截断扩散过程,用辅助网络估计截断点分布。但直接应用 TDPMs 仍有三个问题:(1) 统一推理目标无法解耦精度与多样性;(2) 基于采样的截断分布近似导致保真度不足;(3) 截断后缺乏语义引导影响预测合理性。
核心问题¶
如何在模糊医学图像分割中同时提升预测精度和多样性,打破现有方法中两者此消彼长的固有权衡?
方法详解¶
整体框架¶
ATFM 由三个核心组件构成,围绕重新定义的推理范式展开:
- Data-Hierarchical Inference(数据层级推理):重新定义 AMIS 专属推理范式,在数据分布级别提升精度,在数据样本级别增强多样性
- GTR(Gaussian Truncation Representation):在截断点显式建模高斯分布,提升预测保真度和截断分布可靠性
- SFM(Segmentation Flow Matching):引入语义感知的流变换,在增强多样性的同时保证预测合理性
关键设计¶
数据层级推理(Data-Hierarchical Inference)¶
核心思想是通过边缘化扩散过程中的随机性来解耦精度和多样性:
- 分布级别(截断步骤前):将多个标注样本 \(\{s_1, s_2, \dots, s_n\}\) 的随机性边缘化,监督截断点处的显式分布 \(P \sim \mathcal{N}(\mu, \Sigma)\) 逼近真实标注分布 \(Q \sim \mathcal{N}'(\mu', \Sigma')\),优化整体精度
- 样本级别(截断步骤后):从高保真的截断分布中采样,通过扩散过程生成多样化预测 \(\{pred_i\}_{i=1}^n\),并与真值 \(\{gt_i\}_{i=1}^n\) 逐样本监督,增强多样性
这种分层设计使得精度优化不会损害多样性,多样性增强又建立在全局对齐的分布基础上。
高斯截断表示(GTR)¶
GTR 用显式高斯建模替代传统 TDPMs 的采样近似:
- 通过分割骨干网络 \(f_\theta\) 提取图像语义特征 \(Z\)
- 用独立的卷积层 \(g_\phi\) 和 \(h_\psi\) 分别估计均值 \(\mu\) 和协方差 \(\Sigma\)(秩为 \(r=10\) 的低秩参数化)
- 截断点分布:\(X_{T_{\text{trunc}}} \sim \mathcal{N}(\mu, \Sigma)\)
理论支撑(Theorem 1 & 2): - 定理1:扩散任意时间步的边缘分布可参数化为 \(\mathcal{N}(\mu, DD^\top + L)\) - 定理2:任意高斯分布都存在某个时间步使扩散过程产生相同分布
因此 GTR 的高斯分布作为截断分布是合法且最优的选择。
网络结构:标准编码器-解码器架构,4 级分辨率,编码器滤波器大小 32→64→128→192,解码器使用转置卷积上采样和跳跃连接。
分割流匹配(SFM)¶
SFM 在截断后使用 Flow Matching 替代 DDPM,并引入语义一致性建模:
- Optimal Transport 调度:源分布 \(X_{T_{\text{trunc}}}\) 到目标 \(X_1\) 的最短路径线性插值
- \(X_t = t \times X_1 + (1-t) \times X_{T_{\text{trunc}}}\)
- ST-Net(Semantic-aware Transformation Network):时间条件 U-Net,预测流变换方向 \(g_\theta(X_t)\)
- 4 级编码器-解码器,15 个残差块
- 正弦时间步嵌入,通过 MLP 融入每个残差块
- 所有层使用线性注意力,瓶颈层使用全自注意力
- 中间预测:\(x_1^t = x_t + g_\theta(X_t) \times (1-t)\),通过解析几何在隐空间投影
- 语义一致性:在每个时间步 \(t\),计算中间预测 \(x_1^t\) 与所有真值标注之间的 Dice loss,显式建模状态-预测-真值之间的语义一致性
FM 相比 DDPM 的优势:避免高斯约束对细粒度预测的干扰。
损失函数 / 训练策略¶
两阶段训练:
- GTR 阶段:先训练 GTR 至收敛,然后冻结参数
- \(\mathcal{L}_{\text{Prior}} = -\log \int p(Y|X_{T_{\text{trunc}}}) p(X_{T_{\text{trunc}}}|X) dX_{T_{\text{trunc}}} \approx \frac{1}{M}\sum_{i=1}^M -\log p(Y|X_{T_{\text{trunc}}}^i)\)
-
Monte Carlo 采样 \(M=20\)
-
SFM 阶段:在冻结的 GTR 基础上训练 SFM
- \(\mathcal{L}_{\text{FM}}\):标准 Flow Matching 损失
- \(\mathcal{L}_{\text{SF}}\):语义一致性损失(Dice loss)
- \(\mathcal{L}_{\text{SFM}} = \mathcal{L}_{\text{FM}} + \alpha \cdot \mathcal{L}_{\text{SF}}\)
- \(\alpha\):LIDC 设为 \(10^{-3}\),ISIC3 设为 \(10^{-4}\)
训练配置: - 单张 RTX 3090 (24GB) - GTR:LIDC 1000 epochs / ISIC3 400 epochs - SFM:LIDC 200 epochs / ISIC3 120 epochs - Adam 优化器,学习率 \(10^{-4}\) - \(\lambda = 10^{-3}\)(\(T=1000\)),线性调度
实验关键数据¶
| 数据集 | 指标 | 本文 (ATFM) | 之前SOTA | 提升 |
|---|---|---|---|---|
| LIDC | GED₁₆↓ | 最佳 | CCDM / AB | 持续领先 |
| LIDC | GED₁₀₀↓ | 最佳 | runner-up | 11.5% |
| LIDC | HM-IoU₃₂↑ | 最佳 | runner-up | ≥7.3% |
| LIDC | MDM₃₂↑ | 最佳 | - | 领先 |
| ISIC3 | GED↓ | 最佳 | runner-up | 12% |
| ISIC3 | HM-IoU↑ | 最佳 | - | 最优 |
| ISIC3 | MDM↑ | 最佳 | - | 最优 |
推理效率(生成100个样本):
| 方法 | 步数 | 时间 |
|---|---|---|
| CIMD | 多步 | 420s |
| AB | 多步 | 1050s |
| CCDM | 多步 | 1100s |
| ATFM | GTR + 25步 | 113s |
ATFM 推理速度远快于其他扩散方法,仅需 25 步扩散 + 一次 GTR 估计。
消融实验要点¶
在 LIDC 数据集上的五种变体对比:
| 变体 | 说明 | 结论 |
|---|---|---|
| Act. GTR | 仅用 GTR+激活层 | 基线 |
| SFM w/o \(\mathcal{L}_{\text{SF}}\) | SFM 无语义损失 | 多样性不足 |
| SFM | 完整 SFM | 单独 SFM 效果好 |
| ATFM w/o \(\mathcal{L}_{\text{SF}}\) | 完整框架但无语义损失 | 合理性下降 |
| ATFM | 完整框架 | 最佳 |
- ATFM 比单独 Act. GTR 和 SFM 分别提升 ≥10% 和 ≥6%,验证了数据层级推理和两模块协同的有效性
- 有无 \(\mathcal{L}_{\text{SF}}\) 的平均差距达 11%,说明语义一致性建模至关重要
- 推理步数在 25 步时达到性能与效率的最佳平衡
- \(\alpha\) 过小限制 \(\mathcal{L}_{\text{SF}}\) 作用,过大削弱 \(\mathcal{L}_{\text{FM}}\)
亮点¶
- 推理范式创新:首次将 TDPMs 的推理范式重新定义为 AMIS 专属的数据层级推理,从根本上解耦精度与多样性
- 理论完备:Theorem 1 & 2 为 GTR 的高斯建模提供了严格的理论保证
- 三模块协同设计:Data-Hierarchical Inference + GTR + SFM 各解决一个特定问题,设计逻辑清晰
- 效率优势显著:113s vs 420-1100s,推理速度提升 3.7-9.7 倍,同时性能更优
- FM 替代 DDPM:避免高斯约束对细粒度分割的干扰,是合理的技术选择
局限性 / 可改进方向¶
- 数据集规模有限:仅在 LIDC(肺部CT)和 ISIC3 子集(皮肤病变,仅300张图片)上验证,未覆盖更多模态和器官
- 标注数量固定:LIDC 4个标注、ISIC3 3个标注,未探讨标注数量变化对方法的影响
- 两阶段训练:GTR 需先训练至收敛再冻结,流程较复杂,端到端联合训练可能更优
- 表格数值缺失:HTML论文中的量化表格数值未能完整显示,影响具体对比
- 3D 拓展:当前仅在 2D 切片上实验,3D 体积分割场景的适用性待验证
与相关工作的对比¶
| 方法类别 | 代表工作 | 局限 | ATFM 优势 |
|---|---|---|---|
| 模型集成 / 多头 | SSN, MoSE | 不改变推理过程,受限于模型选择 | 重新定义推理范式 |
| cVAE 方法 | Prob. U-Net, PHiSeg | 单阶段推理,精度多样性耦合 | 分层解耦两个目标 |
| 扩散方法 | CIMD, CCDM | 高斯约束干扰细粒度预测,推理慢 | FM 避免高斯限制,25步高效推理 |
| 多标注者感知 | c-Prob. U-Net, c-SSN | 抑制低频模式 | GTR 显式建模保留低频模式 |
| 传统 TDPMs | TDPM | 采样近似不精确,缺乏语义引导 | GTR 显式建模 + SFM 语义监督 |
启发与关联¶
- 推理范式重定义的思路:不改变模型结构而重新定义推理过程的思想可迁移到其他需要平衡多个目标的生成任务
- 截断扩散的通用性:TDPMs 的"截断+替换中间推理路径"揭示了扩散模型推理的结构灵活性,可应用于其他条件生成任务
- Flow Matching in 分割:将 FM 引入分割任务是新趋势,避免 DDPM 的高斯假设对 one-hot 标签空间的不适配
- 显式分布建模 vs 隐式采样:GTR 的显式高斯建模思路可推广到其他需要在中间表示处建模分布的场景
评分¶
- 新颖性: ⭐⭐⭐⭐ — 数据层级推理范式是有意义的创新,三模块设计逻辑清晰
- 技术深度: ⭐⭐⭐⭐ — 有理论保证(两个定理),方法推导完整
- 实验充分性: ⭐⭐⭐ — 仅两个数据集,但消融和超参分析全面
- 写作质量: ⭐⭐⭐⭐ — 结构清晰,各模块的"Summarized Advantage"有助理解
- 影响力: ⭐⭐⭐⭐ — 在 AMIS 领域提供了新的范式思路,FM 在医学分割中的应用有参考价值