跳转至

3DTeethSAM: Taming SAM2 for 3D Teeth Segmentation

会议: AAAI 2026
arXiv: 2512.11557
代码: https://github.com/Crisitofy/3DTeethSAM
领域: 3D视觉 / 医学图像分割
关键词: 3D牙齿分割, SAM2适配, 多视角渲染, 可变形注意力, 基础模型迁移

一句话总结

将SAM2基础模型迁移到3D牙齿分割任务,通过多视角渲染将3D mesh转为2D图像、设计三个轻量适配器(Prompt生成器、Mask精化器、Mask分类器)和可变形全局注意力插件(DGAP)来解决自动提示、边界精化和语义分类问题,在Teeth3DS上以91.90% T-mIoU刷新SOTA。

背景与动机

3D牙齿分割是数字化口腔的基础任务,需要将3D牙齿模型中的每颗牙齿实例定位并分类。现有方法主要依赖直接处理3D点云/mesh的专用网络(如PointNet++、MeshSegNet、TSGCNet等),存在两个核心瓶颈:(1) 这些从头训练的网络难以扩展到高分辨率3D模型;(2) 无法利用大规模预训练模型的知识。与此同时,SAM2作为2D视觉基础模型在各种下游任务上展现了强大的零样本能力,但将其迁移到3D牙齿分割面临维度不匹配、需要手动提示、类别不可知三大挑战。

核心问题

如何将SAM2这一2D基础模型有效适配到3D牙齿分割任务?具体需要解决:(1) SAM2依赖手动点/框提示,无法自动化;(2) SAM2的原始分割结果边界粗糙;(3) SAM2是类别无关的,无法区分不同牙齿ID。这三个问题共同阻碍了直接使用SAM2进行高精度、全自动的3D牙齿分割。

方法详解

整体框架

整个pipeline分为三步:(1) 多视角渲染:将3D牙齿mesh归一化后,从正面、背面及多个侧面等固定视角渲染为512×512的2D RGB图像;(2) SAM2适配分割:冻结SAM2预训练权重,通过三个轻量适配器和DGAP对2D图像进行分割,生成16通道mask(每通道对应一颗牙齿);(3) 2D→3D提升:反投影将2D分割结果映射回3D mesh顶点,对多视角结果进行投票聚合,最后用Graph Cut后处理修正边界。

关键设计

  1. Prompt Embedding Generator (PEG):借鉴DETR的思路,用Transformer Decoder将16个随机初始化的query向量转化为prompt embedding。自注意力建模牙齿间的位置关系,交叉注意力对齐图像特征。额外学习一个置信度分数来处理缺牙情况(值越高表示该牙齿实例存在的概率越大)。这完全取代了SAM2对手动提示的依赖。

  2. Mask Refiner:基于UNet架构的卷积网络,接收三路输入——原始牙齿图像(提供低级纹理/形状细节)、SAM2生成的粗略mask(提供空间先验)、SAM2的图像embedding(提供高级语义)。在UNet的收缩路径中,每层有三个并行流分别处理三路输入,然后拼接传递。这种设计专门解决SAM2通用预训练导致的边界不精确问题。

  3. Mask Classifier:同样采用Transformer Decoder架构(与PEG共享设计但独立参数),将16个query向量转化为类别概率向量。末端用MLP+Softmax输出17类概率(16颗牙+背景)。这比简单的"通道绑定牙齿ID"策略更鲁棒,避免了缺牙场景下的通道-ID错配问题。

  4. Deformable Global Attention Plugin (DGAP):集成到SAM2图像编码器Hiera trunk第3阶段的全局注意力块中。利用offset network预测偏移量来形变采样网格,使注意力集中在牙齿区域。与标准可变形注意力不同的是,query/key/value都从形变特征图预测,且通过skip connection融合形变和非形变特征。DGAP是即插即用模块,不修改SAM2内部实现。

损失函数 / 训练策略

  • 训练策略:冻结SAM2预训练权重,仅训练三个适配器和DGAP。使用匈牙利算法进行预测query与真值的一对一匹配。AdamW优化器,学习率2e-4,余弦退火+5 epoch warmup,训练100 epoch,batch size 4,混合精度。
  • 总损失\(L_{\text{total}} = \lambda_{MC} L_{MC} + \lambda_{PEG} L_{PEG} + \lambda_{MR} L_{MR}\),权重分别为1.0、1.0、2.0。
  • \(L_{MC}\):17类交叉熵损失(Mask Classifier)
  • \(L_{PEG}\):BCE + Dice + 置信度损失(Prompt Embedding Generator)
  • \(L_{MR}\):多类CE + Dice + 边界损失(Mask Refiner,边界损失用Sobel滤波器计算梯度的L1距离)

实验关键数据

数据集:Teeth3DS(1800个高分辨率口腔内3D扫描,900名患者,官方1200/600划分)

数据集 指标 本文 之前SOTA (ToothGroupNet) 提升
Teeth3DS OA 95.48% 95.19% +0.29%
Teeth3DS T-mIoU 91.90% 90.16% +1.74%
Teeth3DS B-IoU 70.05% 69.30% +0.75%
Teeth3DS Dice 94.33%
Teeth3DS 智齿T-mIoU (T8/16) 83.29% 68.20% +15.09%

消融实验要点

  • PEG是最关键模块:移除后T-mIoU暴跌39.44%(91.90%→52.46%)。即使用真值中心点作为手动提示,性能也远不如学习到的prompt embedding,说明PEG捕获了复杂的空间关系和上下文信息。
  • DGAP:移除后T-mIoU降1.29%,B-IoU降3.41%,且显著减慢了训练收敛速度。
  • Mask Refiner:移除后T-mIoU降0.80%,B-IoU降1.62%,主要影响边界质量。
  • Mask Classifier:移除后T-mIoU降0.59%,B-IoU降2.49%,主要解决相邻牙齿的类别混淆。

亮点

  • "渲染→2D分割→反投影"范式:将3D分割优雅地转化为2D问题,从而可以直接利用强大的2D基础模型,是一个通用且可复用的思路。
  • PEG的DETR式设计:用Transformer Decoder自动生成prompt embedding,完全绕过SAM2对手动提示的依赖,且建模了牙齿间的空间关系。
  • DGAP即插即用:不修改SAM2内部实现,通过skip connection融合形变/非形变特征,同时提升精度和训练效率,可推广到其他基础模型适配场景。
  • 智齿分割大幅提升:在稀有类别(智齿)上获得15%+的提升,展示了基础模型在数据稀缺场景下的优势。

局限性 / 可改进方向

  • 多视角渲染引入额外计算开销,推理效率可能不如直接处理3D数据的方法。
  • 仅在Teeth3DS一个数据集上验证,泛化性未知(不同扫描仪、不同种族的牙齿形态差异)。
  • 固定视角渲染可能遗漏某些角度的细节(如严重拥挤的牙齿),自适应视角选择可能更优。
  • 2D→3D的投票策略较简单,更精细的多视角融合方案(如可学习的融合权重)可能进一步提升。
  • 论文未讨论实时性和临床部署场景的可行性。

与相关工作的对比

  • vs ToothGroupNet:ToothGroupNet是之前SOTA,直接在3D mesh上操作。3DTeethSAM通过2D基础模型迁移的方式超越它,尤其在稀有类别上优势巨大(智齿+15%),但引入了多视角渲染的额外开销。
  • vs MedSAM:MedSAM将SAM适配到医学2D图像,但不处理3D数据。3DTeethSAM通过渲染→分割→反投影的pipeline解决了2D-3D维度不匹配问题。
  • vs 传统3D网络(PointNet++, DGCNN等):这些方法从头训练、难以利用预训练知识、在高分辨率mesh上扩展性差。3DTeethSAM冻结SAM2权重、仅训练轻量适配器,参数效率更高。

启发与关联

  • 通用3D分割范式:渲染→2D基础模型→反投影的思路可推广到其他3D医学分割任务(如骨骼、器官等),甚至非医学3D分割(如室内场景、自动驾驶点云)。
  • 自适应视角选择:当前固定视角,可以设计一个可学习的视角选择模块,根据mesh复杂度动态确定渲染视角。
  • 多基础模型融合:SAM2负责分割,可以引入其他基础模型(如DINOv2)提供更丰富的语义特征。
  • 端到端3D基础模型:当前方案通过2D中转,未来可否直接在3D空间训练类SAM的基础模型?

评分

  • 新颖性: ⭐⭐⭐⭐ 渲染+SAM2适配的思路有创新,但各模块(DETR式query、UNet refiner、可变形注意力)均有先例
  • 实验充分度: ⭐⭐⭐⭐ 消融实验详尽,11种方法对比,但仅一个数据集
  • 写作质量: ⭐⭐⭐⭐ 结构清晰,方法描述到位,图示直观
  • 价值: ⭐⭐⭐⭐ 展示了2D基础模型→3D分割的可行路径,对口腔数字化有实际意义,范式可推广