Causality-Induced Positional Encoding for Transformer-Based Representation Learning of Non-Sequential Features¶
会议: NeurIPS 2025
arXiv: 2509.16629
代码: https://github.com/Catchxu/CAPE
领域: 因果推断 / Transformer
关键词: positional encoding, causal structure learning, hyperbolic embedding, rotary position encoding, multi-omics
一句话总结¶
CAPE 通过从表格数据中学习特征间的因果DAG结构,将其嵌入双曲空间生成因果感知的旋转位置编码(RoPE),使 Transformer 能处理非序列但因果相关的特征数据,在多组学数据的下游任务上显著提升性能。
研究背景与动机¶
- 领域现状:Transformer 中的位置编码(如 sinusoidal、RoPE 等)假设数据具有天然的序列顺序(词序、图像 patch 空间排列),在 NLP 和 CV 领域取得了巨大成功。
- 现有痛点:许多现实数据(如基因表达、蛋白质组学、经济指标)的特征之间没有预定义的序列顺序,但存在复杂的因果关系。现有位置编码方法无法捕捉这种非序列因果结构。
- 核心矛盾:现有方法要么忽略特征间的因果关系(如按表达量排序),要么使用静态的预训练嵌入作为伪位置编码,本质上都没有真正利用特征间的因果结构信息。
- 本文要解决什么? 如何为非序列但因果相关的特征生成位置编码,使 Transformer 的自注意力机制能感知因果关系?
- 切入角度:受狭义相对论启发——因果连接对应双曲时空中的相对位置——将因果图嵌入双曲空间,自然保留因果强度和因果特异性两个关键属性。
- 核心idea一句话:学习特征间因果DAG → 嵌入双曲空间 → 转换为旋转位置编码,使注意力分数随因果距离衰减。
方法详解¶
整体框架¶
CAPE 是一个三步框架:输入是 \(N \times M\) 的表格数据 \(\bm{X}\)(\(N\) 个观测、\(M\) 个非序列特征),输出是每个特征 \(v_j\) 的因果感知旋转位置编码 \(\bm{\varphi}_{v_j}\)。三个步骤分别是:(1) 因果结构学习 → 得到加权DAG;(2) 双曲空间嵌入 → 保留因果属性;(3) 转换为旋转形式 → 注入 Transformer。
关键设计¶
- 因果结构学习 (Step I):
- 做什么:从观测数据 \(\bm{X}\) 中学习特征间的因果结构,表示为加权邻接矩阵 \(\bm{A}\)
- 核心思路:使用非线性结构方程模型(SEM),将其表述为 VAE 形式——编码器 \(\bm{Z} = f(\bm{X})(\bm{I} - \bm{A})\),解码器 \(\bm{X} = f^{-1}(\bm{Z}(\bm{I}-\bm{A})^{-1})\)。通过 ELBO + 稀疏正则 \(\|\bm{A}\|_1\) + DAG 无环约束 \(\text{tr}(e^{\bm{A} \odot \bm{A}}) - M = 0\) 联合优化
-
设计动机:非线性 SEM 能捕捉复杂因果关系;连续优化的无环约束避免了组合搜索;阈值 \(\tau\) 裁剪噪声边
-
双曲空间嵌入 (Step II):
- 做什么:将因果DAG嵌入双曲空间(hyperboloid model),为每个节点生成 \(d+1\) 维嵌入 \(\bm{p}_{v_j}\)
- 核心思路:通过正则化图对比学习优化嵌入位置。对比损失 \(\mathcal{L}_{\text{con}}\) 拉近因果相连的节点(正样本为 \(k\)-hop 邻居),推远无关节点,权重为 \(|\bm{A}_{mn}|\)(因果强度)。正则项 \(\Omega = \bm{\pi}_{v_m} d_l(\bm{p}_{v_m}, \bm{p_o})\) 利用 PageRank 值 \(\pi\) 惩罚因果通用节点(出度大的根节点),迫使其靠近原点
-
设计动机:双曲空间天然适合树状/DAG结构建模;两个关键属性被保留——因果强度(距离近 = 因果关系强)和因果特异性(离原点远 = 更具体的叶节点)。使用 Riemannian SGD 在流形上优化
-
旋转位置编码转换 (Step III):
- 做什么:将双曲嵌入映射到 Poincaré 球,再转为旋转形式注入 Transformer
- 核心思路:先通过微分同胚 \(f_d: \mathcal{H}^d \to \mathcal{B}^d\) 映射到 Poincaré 球得到 \(\bm{e}_{v_j}\),然后 \(\bm{\varphi}_v = c \cdot \bm{e}_v\)(\(c=\pi/4\))作为旋转角度,构建块对角旋转矩阵 \(\bm{R}(\bm{\varphi}_v)\)。注意力计算变为 \(\mathcal{A} = (\bm{q}_{v_m}^i)^\top \bm{R}(\bm{\varphi}_{v_n} - \bm{\varphi}_{v_m}) \bm{k}_{v_n}^i\)
- 设计动机:Poincaré 球的球形几何更适合旋转编码;旋转形式兼容线性自注意力,且相对位置编码只依赖因果距离差 \(\bm{\varphi}_{v_n} - \bm{\varphi}_{v_m}\)
损失函数 / 训练策略¶
- Step I: \(\mathcal{L}_{\text{DAG}} = -\mathcal{L}_{\text{ELBO}} + \lambda_s \|\bm{A}\|_1 + \frac{\rho}{2}|h(\bm{A})|^2 + \alpha h(\bm{A})\),用增广拉格朗日法求解
- Step II: \(\mathcal{L}_{\mathcal{H}} = \frac{1}{M} \sum_j \mathcal{L}_{\text{con}}(\bm{p}_{v_j}) + \lambda_g \Omega(\bm{p}_{v_j})\),Riemannian SGD 优化
- Step III: 无额外训练参数,直接通过映射转换
实验关键数据¶
主实验¶
在单细胞 scRNA-seq 数据集上评估基因扰动预测(GPP)任务,使用 scBERT 和 scGPT 两种 Transformer 基础模型:
| 模型 | 位置编码 | 单基因扰动 MSE | 双基因扰动 MSE |
|---|---|---|---|
| scBERT | 静态绝对PE (默认) | 0.224 | 0.230 |
| scBERT | 可训练相对PE | 0.219 (-0.005) | 0.215 (-0.015) |
| scBERT | CAPE | 0.193 (-0.031) | 0.189 (-0.041) |
| scGPT | 可训练绝对PE (默认) | 0.202 | 0.201 |
| scGPT | 可训练相对PE | 0.195 (-0.007) | 0.204 (+0.003) |
| scGPT | CAPE | 0.182 (-0.020) | 0.176 (-0.025) |
CAPE 平均降低 MSE 11.1%,而因果无关的相对PE仅降低 2.7%。
消融实验¶
以 scGPT 为骨干网络:
| 配置 | 单基因扰动 MSE | 双基因扰动 MSE |
|---|---|---|
| CAPE (完整) | 0.182 (±0.005) | 0.176 (±0.008) |
| CAPE-null (无PE) | 0.234 (±0.014) | 0.238 (±0.017) |
| CAPE-w/o-CSL (无因果学习) | 0.209 (±0.010) | 0.213 (±0.011) |
| CAPE-w/o-hyperbolic (欧氏替代) | 0.192 (±0.008) | 0.196 (±0.008) |
| CAPE-w/o-rotary (加性PE) | 0.201 (±0.009) | 0.208 (±0.010) |
关键发现¶
- 因果结构学习(CSL)是最关键组件,去掉后双基因MSE从 0.176 升至 0.213(+21%)
- 旋转形式也很重要,加性PE相比旋转PE性能差距明显
- 双曲空间建模的贡献相对温和但仍然正向,说明曲率感知优化确实有助于更好反映因果图结构
- 在合成数据上验证了理论属性:注意力随因果距离衰减、随因果通用性衰减、对位置扰动鲁棒
亮点与洞察¶
- 因果图 → 双曲空间 → RoPE 的统一框架极为优雅,将因果发现、双曲几何和旋转注意力三个不同领域的方法有机结合。因果距离直接映射为注意力衰减这个设计非常自然
- 理论分析扎实:证明了三个关键性质(因果距离衰减、因果通用性衰减、鲁棒性),为方法的有效性提供了数学保证
- 可迁移的设计思路:任何涉及非序列特征(如推荐系统的特征交互、知识图谱节点、分子图中的原子)的 Transformer 应用都可以借鉴这种"结构→双曲嵌入→旋转编码"的范式
局限性 / 可改进方向¶
- 因果结构学习部分假设了无环(DAG),对存在反馈环的因果系统不适用
- 当前仅在生物组学数据上验证,缺少在经济、社会科学等其他非序列因果数据上的评估
- 因果学习的准确性受数据量和噪声影响较大,在小样本高维场景下DAG学习可能不可靠
- 计算复杂度:因果学习 + 双曲嵌入 + Transformer 三阶段训练可能较慢,但论文在附录中指出复杂度分析显示可接受
相关工作与启发¶
- vs RoPE/绝对PE等标准方法: 这些方法假设预定义顺序,CAPE 扩展到无序因果特征
- vs scGPT/scBERT 的默认PE: 它们用基于表达量排序或预训练表示作为伪PE,本质上忽略因果关系
- vs NOTEARS/DAG-GNN: CAPE 借鉴了这些因果发现方法,但创新在于将学到的DAG转化为位置编码而非终点
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 因果图→双曲嵌入→RoPE的组合非常新颖,跨领域的方法融合独具匠心
- 实验充分度: ⭐⭐⭐⭐ 合成+多组学实验全面,但领域覆盖面可更广
- 写作质量: ⭐⭐⭐⭐⭐ 数学推导严谨,叙述清晰,图示直观
- 价值: ⭐⭐⭐⭐ 为非序列数据的Transformer建模开辟了新方向,但实际应用场景目前偏窄