跳转至

Brain-Semantoks: Learning Semantic Tokens of Brain Dynamics with a Self-Distilled Foundation Model

会议: ICLR2026
arXiv: 2512.11582
代码: https://github.com/SamGijsen/Brain-Semantoks
领域: medical_imaging
关键词: fMRI基础模型, 自蒸馏, 语义分词器, 脑动态表征学习, 线性探测

一句话总结

提出 Brain-Semantoks,一种基于语义分词器和自蒸馏目标的 fMRI 基础模型,将大脑功能网络聚合为鲁棒的语义 token,并通过跨时间视角的一致性学习抽象的脑动态表征,在线性探测设置下即可达到 SOTA 性能。

研究背景与动机

  1. 领域现状:fMRI 基础模型近年快速发展,BrainLM、Brain-JEPA、NeuroSTORM 等先驱工作均采用掩码-重建(mask-and-reconstruct)目标。这些方法专注于低层信号预测——BrainLM 直接在输入空间重建 BOLD 信号,Brain-JEPA 在潜空间做预测以避免噪声建模,NeuroSTORM 在 4D 体素上做时空重建。
  2. 核心矛盾:重建目标与下游任务目标之间存在根本性不匹配。下游任务(如疾病诊断、认知评估)需要的是稳定的、高层次的表型签名(phenotypic signature),而重建目标学到的表征对噪声和时间波动敏感,必须依赖大量微调才能适配。这种对微调的依赖削弱了基础模型的实用价值,尤其在 fMRI 领域——不同数据集在被试群体、硬件和采集协议上差异巨大。
  3. 关键假设:有效预测稳定表型需要从"重建"转向"抽象"——目标不是精确编码 BOLD 信号,而是从中提取底层的表型特征。
  4. 切入角度:(1) 单个 ROI 的时间序列噪声大、语义不明确,不适合作为 Transformer 的输入 token;(2) 大脑的功能组织(如默认模式网络等)提供了强有力的神经科学先验,可用于构建语义 token。
  5. 核心 idea:用功能网络级别的语义分词器将 noisy 的区域信号聚合为鲁棒 token,再通过自蒸馏目标学习跨时间稳定的抽象表征。

方法详解

整体框架

Brain-Semantoks 采用学生-教师架构。输入 fMRI 时间序列 \(X \in \mathbb{R}^{C \times T}\)(C=457 个脑区,T 为时间点数),首先裁剪为两个长时间片段作为不同视角,经语义分词器编码为功能网络级 token,再由 Transformer 编码器处理。教师网络权重为学生的指数移动平均(EMA),训练目标为跨视角的表征一致性。

关键设计

  1. 语义分词器(Semantic Tokenizer)
  2. 基于 Yeo 7-网络皮层分区 + 皮下 + 小脑共 9 个功能网络
  3. 每个网络有独立的分词模块 \(g_n\),处理该网络内所有 ROI 的时间序列
  4. 时间维度切分为 P 个较长的时间片,每个片段通过双分支卷积(标准卷积 + 结构化卷积)提取多尺度时间模式
  5. 输出 token 张量 \(Z \in \mathbb{R}^{N \times P \times D}\)(9 网络 × P 片段 × 768 维),最终序列长度仅 \(N \times P\)(远短于 457 个 ROI 的原始序列)
  6. 核心价值:将噪声大的 ROI 级信号聚合为语义丰富的网络级 token,为 Transformer 提供更好的输入

  7. 切片掩码(Slice Masking)

  8. 将 token 排列为 \(N \times P\) 的 2D 矩阵
  9. 随机选择两种策略之一:网络切片(掩码整行)或时间切片(掩码连续列块)
  10. 掩码比例高(\(\mathcal{U}[0.65, 0.85]\)),迫使模型学习网络间和跨时间的复杂关系
  11. 避免模型通过简单插值完成预测

  12. 三重损失函数

  13. \(\mathcal{L}_{CLS}\)(全局跨视角损失):学生和教师的 [CLS] token 在两个时间视角间做双向蒸馏,学习稳定的全局表征;使用 coding rate 正则化防止表征坍塌
  14. \(\mathcal{L}_{Tok}\)(网络 token 损失):在每个视角内,学生重建被掩码的网络 token 以匹配教师输出,学习时间敏感的局部特征
  15. \(\mathcal{L}_{TTR}\)(教师引导时间正则化):将每个网络的 P 个 patch 嵌入平均为单一 summary token,做跨视角蒸馏。在训练前 5% 步骤中激活后余弦衰减为零——引导模型先学习时间平均的网络签名,再建模更复杂的时间变化

  16. TTR 的设计动机

  17. 直接在低信噪比 fMRI 数据上应用自蒸馏目标容易导致训练不稳定和表征坍塌
  18. TTR 通过约束初始 token 空间(从 \(N \times P + 1\) 压缩到 \(N + 1\))帮助模型找到好的初始表征
  19. 仅在早期激活避免过度约束最终学到的解

训练细节

  • 时间裁剪长度 \(T_{crop}=100\),patch 长度 20,9 个功能网络,每网络 5 个 patch
  • Transformer:\(D_f=768\),8 层,投影头 2 隐层 \(D_h=1024\),输出 \(D_{proj}=128\)
  • 在 UKBioBank 39139 条静息态 fMRI 上预训练,单 GPU(<20GB 显存)不到 2 小时
  • Z-scoring 归一化替代 robust scaling,改善跨数据集迁移

实验关键数据

线性探测(冻结权重 + 单层线性层)

数据集/任务 BrainLM Brain-JEPA Brain-Semantoks
ABIDE (ASD) 53.84 52.92 65.13
HBN CELF 42.03 41.50 42.18
HBN WISC 38.26 38.34 40.87
UKB Sex 86.71 83.23 87.52
UKB Age 30.16 30.60 31.15
SRPBS SZ 57.61 57.63 69.26
SRPBS MDD 55.72 52.72 62.60
  • 在 9 个任务中 8 个取得最高分,临床诊断任务上优势尤为明显(ASD +12, SZ +12, MDD +7)

与监督方法的对比

  • 仅用线性探测即超越所有完全监督端到端训练的基线(FC、BNT、BolT、BrainMass)和微调后的 BrainLM/Brain-JEPA
  • 12 个任务上平均 52.72% vs 监督最优 50.68%,证明学到的表征无需微调即可广泛使用

任务态 fMRI 泛化(Hariri 情绪任务)

  • Brain-Semantoks 线性探测 93.84-96.50%,大幅超越 Brain-JEPA 的 81.06-82.29%
  • 利用掩码蒸馏框架加上 patch 构造策略解决预训练-推理时间尺度不匹配问题

Scaling Laws

  • 首次为 fMRI 基础模型提供详细的 scaling 分析
  • 线性探测性能随预训练数据量的对数呈幂律增长
  • OOD 任务上也观察到一致的 scaling 收益,且无性能平台
  • 即使 HBN 数据集与 UKB 有 >20 年的年龄差距,仍有持续的性能提升

消融实验关键发现

配置 平均分 说明
完整 Brain-Semantoks 52.39 语义分词器 + TTR + CLS + Tok
去掉 TTR(0%) 50.88 训练不稳定,性能下降 1.5
TTR 全程激活(100%) 49.60 过度约束,反而更差
去掉 CLS 损失 47.32 全局表征损失至关重要
线性投影替代语义分词器 差距大 导致部分坍塌,cosine 相似度迅速升到 0.95
随机掩码替代切片掩码 51.03 切片掩码减少插值学习

亮点与洞察

  • 从重建到抽象的范式转变:不同于此前所有 fMRI 基础模型的重建目标,Brain-Semantoks 明确以学习抽象表征为目标,这一范式转变带来了线性探测性能的巨大提升
  • 语义分词器的巧妙设计:将神经科学先验(功能网络)融入模型架构,既压缩了序列长度(457→45),又将噪声区域信号聚合为语义 token,类似于 NLP 中将字符聚合为词
  • TTR 课程学习:解决了低信噪比数据+自蒸馏目标组合的训练不稳定问题。只在前 5% 步骤中激活的设计精准平衡了稳定性和灵活性
  • Z-scoring > Robust Scaling:看似简单的归一化策略变化解决了跨数据集迁移时 DC offset 不一致的问题,是工程层面的重要贡献
  • 首个 fMRI scaling law 分析:证明了 OOD 性能随预训练数据量可靠提升,增强了社区对 fMRI 基础模型的信心

局限性 / 可改进方向

  • 功能网络划分依赖固定的 Yeo 7-网络分区,未来可探索从数据中学习 ROI 分组
  • 预训练仅使用静息态 fMRI(UKB),任务态数据的整合可能进一步提升表征质量
  • 下游评估中连续目标被离散化为多类标签,可能无法充分反映表征在回归任务上的能力
  • Scaling 分析虽然显示无平台,但受限于 UKB 数据量(~39K),更大规模的预训练效果未知
  • 单 GPU 训练效率很高(<2h),但 Transformer 8 层的容量是否限制了更复杂模式的学习尚未探讨

相关工作与启发

  • vs BrainLM/Brain-JEPA:两者采用 ROI 级掩码重建目标,Brain-Semantoks 在网络级做语义蒸馏,线性探测性能全面超越(尤其在 OOD 临床任务上差距超过 10%)
  • vs BrainMass:BrainMass 使用静态功能连接矩阵,忽略了时间动态;Brain-Semantoks 显式建模时间信息
  • vs DINO/iBOT/SimDINO:Brain-Semantoks 将视觉自蒸馏范式成功迁移到 fMRI,同时引入了领域特定的语义分词器和 TTR 稳定课程
  • 启发:功能网络作为语义分词器的思路可推广到其他脑成像模态(EEG、MEG);TTR 的"先学平均再学细节"策略可能对其他低 SNR 数据的自蒸馏有用

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 范式转变(重建→抽象),语义分词器和TTR都是原创设计
  • 实验充分度: ⭐⭐⭐⭐⭐ 11个下游任务×6个数据集,线性探测+微调+消融+scaling+可解释性,极为全面
  • 写作质量: ⭐⭐⭐⭐ 动机清晰,逻辑递进,消融系统化
  • 价值: ⭐⭐⭐⭐⭐ 为fMRI基础模型建立了新范式,线性探测即超监督方法,实用价值高