跳转至

Causality-Induced Positional Encoding for Transformer-Based Representation Learning of Non-Sequential Features

会议: NeurIPS 2025
arXiv: 2509.16629
代码: https://github.com/Catchxu/CAPE
领域: 因果推断 / Transformer
关键词: positional encoding, causal structure learning, hyperbolic embedding, rotary position encoding, multi-omics

一句话总结

CAPE 通过从表格数据中学习特征间的因果DAG结构,将其嵌入双曲空间生成因果感知的旋转位置编码(RoPE),使 Transformer 能处理非序列但因果相关的特征数据,在多组学数据的下游任务上显著提升性能。

研究背景与动机

  1. 领域现状:Transformer 中的位置编码(如 sinusoidal、RoPE 等)假设数据具有天然的序列顺序(词序、图像 patch 空间排列),在 NLP 和 CV 领域取得了巨大成功。
  2. 现有痛点:许多现实数据(如基因表达、蛋白质组学、经济指标)的特征之间没有预定义的序列顺序,但存在复杂的因果关系。现有位置编码方法无法捕捉这种非序列因果结构。
  3. 核心矛盾:现有方法要么忽略特征间的因果关系(如按表达量排序),要么使用静态的预训练嵌入作为伪位置编码,本质上都没有真正利用特征间的因果结构信息。
  4. 本文要解决什么? 如何为非序列但因果相关的特征生成位置编码,使 Transformer 的自注意力机制能感知因果关系?
  5. 切入角度:受狭义相对论启发——因果连接对应双曲时空中的相对位置——将因果图嵌入双曲空间,自然保留因果强度和因果特异性两个关键属性。
  6. 核心idea一句话:学习特征间因果DAG → 嵌入双曲空间 → 转换为旋转位置编码,使注意力分数随因果距离衰减。

方法详解

整体框架

CAPE 是一个三步框架:输入是 \(N \times M\) 的表格数据 \(\bm{X}\)\(N\) 个观测、\(M\) 个非序列特征),输出是每个特征 \(v_j\) 的因果感知旋转位置编码 \(\bm{\varphi}_{v_j}\)。三个步骤分别是:(1) 因果结构学习 → 得到加权DAG;(2) 双曲空间嵌入 → 保留因果属性;(3) 转换为旋转形式 → 注入 Transformer。

关键设计

  1. 因果结构学习 (Step I):
  2. 做什么:从观测数据 \(\bm{X}\) 中学习特征间的因果结构,表示为加权邻接矩阵 \(\bm{A}\)
  3. 核心思路:使用非线性结构方程模型(SEM),将其表述为 VAE 形式——编码器 \(\bm{Z} = f(\bm{X})(\bm{I} - \bm{A})\),解码器 \(\bm{X} = f^{-1}(\bm{Z}(\bm{I}-\bm{A})^{-1})\)。通过 ELBO + 稀疏正则 \(\|\bm{A}\|_1\) + DAG 无环约束 \(\text{tr}(e^{\bm{A} \odot \bm{A}}) - M = 0\) 联合优化
  4. 设计动机:非线性 SEM 能捕捉复杂因果关系;连续优化的无环约束避免了组合搜索;阈值 \(\tau\) 裁剪噪声边

  5. 双曲空间嵌入 (Step II):

  6. 做什么:将因果DAG嵌入双曲空间(hyperboloid model),为每个节点生成 \(d+1\) 维嵌入 \(\bm{p}_{v_j}\)
  7. 核心思路:通过正则化图对比学习优化嵌入位置。对比损失 \(\mathcal{L}_{\text{con}}\) 拉近因果相连的节点(正样本为 \(k\)-hop 邻居),推远无关节点,权重为 \(|\bm{A}_{mn}|\)(因果强度)。正则项 \(\Omega = \bm{\pi}_{v_m} d_l(\bm{p}_{v_m}, \bm{p_o})\) 利用 PageRank 值 \(\pi\) 惩罚因果通用节点(出度大的根节点),迫使其靠近原点
  8. 设计动机:双曲空间天然适合树状/DAG结构建模;两个关键属性被保留——因果强度(距离近 = 因果关系强)和因果特异性(离原点远 = 更具体的叶节点)。使用 Riemannian SGD 在流形上优化

  9. 旋转位置编码转换 (Step III):

  10. 做什么:将双曲嵌入映射到 Poincaré 球,再转为旋转形式注入 Transformer
  11. 核心思路:先通过微分同胚 \(f_d: \mathcal{H}^d \to \mathcal{B}^d\) 映射到 Poincaré 球得到 \(\bm{e}_{v_j}\),然后 \(\bm{\varphi}_v = c \cdot \bm{e}_v\)\(c=\pi/4\))作为旋转角度,构建块对角旋转矩阵 \(\bm{R}(\bm{\varphi}_v)\)。注意力计算变为 \(\mathcal{A} = (\bm{q}_{v_m}^i)^\top \bm{R}(\bm{\varphi}_{v_n} - \bm{\varphi}_{v_m}) \bm{k}_{v_n}^i\)
  12. 设计动机:Poincaré 球的球形几何更适合旋转编码;旋转形式兼容线性自注意力,且相对位置编码只依赖因果距离差 \(\bm{\varphi}_{v_n} - \bm{\varphi}_{v_m}\)

损失函数 / 训练策略

  • Step I: \(\mathcal{L}_{\text{DAG}} = -\mathcal{L}_{\text{ELBO}} + \lambda_s \|\bm{A}\|_1 + \frac{\rho}{2}|h(\bm{A})|^2 + \alpha h(\bm{A})\),用增广拉格朗日法求解
  • Step II: \(\mathcal{L}_{\mathcal{H}} = \frac{1}{M} \sum_j \mathcal{L}_{\text{con}}(\bm{p}_{v_j}) + \lambda_g \Omega(\bm{p}_{v_j})\),Riemannian SGD 优化
  • Step III: 无额外训练参数,直接通过映射转换

实验关键数据

主实验

在单细胞 scRNA-seq 数据集上评估基因扰动预测(GPP)任务,使用 scBERT 和 scGPT 两种 Transformer 基础模型:

模型 位置编码 单基因扰动 MSE 双基因扰动 MSE
scBERT 静态绝对PE (默认) 0.224 0.230
scBERT 可训练相对PE 0.219 (-0.005) 0.215 (-0.015)
scBERT CAPE 0.193 (-0.031) 0.189 (-0.041)
scGPT 可训练绝对PE (默认) 0.202 0.201
scGPT 可训练相对PE 0.195 (-0.007) 0.204 (+0.003)
scGPT CAPE 0.182 (-0.020) 0.176 (-0.025)

CAPE 平均降低 MSE 11.1%,而因果无关的相对PE仅降低 2.7%。

消融实验

以 scGPT 为骨干网络:

配置 单基因扰动 MSE 双基因扰动 MSE
CAPE (完整) 0.182 (±0.005) 0.176 (±0.008)
CAPE-null (无PE) 0.234 (±0.014) 0.238 (±0.017)
CAPE-w/o-CSL (无因果学习) 0.209 (±0.010) 0.213 (±0.011)
CAPE-w/o-hyperbolic (欧氏替代) 0.192 (±0.008) 0.196 (±0.008)
CAPE-w/o-rotary (加性PE) 0.201 (±0.009) 0.208 (±0.010)

关键发现

  • 因果结构学习(CSL)是最关键组件,去掉后双基因MSE从 0.176 升至 0.213(+21%)
  • 旋转形式也很重要,加性PE相比旋转PE性能差距明显
  • 双曲空间建模的贡献相对温和但仍然正向,说明曲率感知优化确实有助于更好反映因果图结构
  • 在合成数据上验证了理论属性:注意力随因果距离衰减、随因果通用性衰减、对位置扰动鲁棒

亮点与洞察

  • 因果图 → 双曲空间 → RoPE 的统一框架极为优雅,将因果发现、双曲几何和旋转注意力三个不同领域的方法有机结合。因果距离直接映射为注意力衰减这个设计非常自然
  • 理论分析扎实:证明了三个关键性质(因果距离衰减、因果通用性衰减、鲁棒性),为方法的有效性提供了数学保证
  • 可迁移的设计思路:任何涉及非序列特征(如推荐系统的特征交互、知识图谱节点、分子图中的原子)的 Transformer 应用都可以借鉴这种"结构→双曲嵌入→旋转编码"的范式

局限性 / 可改进方向

  • 因果结构学习部分假设了无环(DAG),对存在反馈环的因果系统不适用
  • 当前仅在生物组学数据上验证,缺少在经济、社会科学等其他非序列因果数据上的评估
  • 因果学习的准确性受数据量和噪声影响较大,在小样本高维场景下DAG学习可能不可靠
  • 计算复杂度:因果学习 + 双曲嵌入 + Transformer 三阶段训练可能较慢,但论文在附录中指出复杂度分析显示可接受

相关工作与启发

  • vs RoPE/绝对PE等标准方法: 这些方法假设预定义顺序,CAPE 扩展到无序因果特征
  • vs scGPT/scBERT 的默认PE: 它们用基于表达量排序或预训练表示作为伪PE,本质上忽略因果关系
  • vs NOTEARS/DAG-GNN: CAPE 借鉴了这些因果发现方法,但创新在于将学到的DAG转化为位置编码而非终点

评分

  • 新颖性: ⭐⭐⭐⭐⭐ 因果图→双曲嵌入→RoPE的组合非常新颖,跨领域的方法融合独具匠心
  • 实验充分度: ⭐⭐⭐⭐ 合成+多组学实验全面,但领域覆盖面可更广
  • 写作质量: ⭐⭐⭐⭐⭐ 数学推导严谨,叙述清晰,图示直观
  • 价值: ⭐⭐⭐⭐ 为非序列数据的Transformer建模开辟了新方向,但实际应用场景目前偏窄