跳转至

Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation

日期: 2026-03-07
arXiv: 2603.07233
代码: GitHub
领域: LLM/NLP
关键词: Retrieval-Augmented Generation, 单细胞扰动预测, 可微检索, Gumbel-Softmax, 细胞类型感知

一句话总结

提出 PT-RAG,首个将 RAG 范式引入单细胞基因扰动响应预测的框架,通过两阶段可微检索(语义检索 + Gumbel-Softmax 细胞类型感知选择)来为生成器提供相关扰动上下文,显著优于无检索和朴素 RAG 基线。

研究背景与动机

  1. 领域现状: 高通量 Perturb-seq 技术可以测量数千种基因扰动的单细胞转录组响应,但扰动与细胞类型的组合爆炸使得全面实验表征不可行,因此需要计算方法进行 in silico 预测。现有深度学习方法(scGen、CPA、GEARS、CellOT、STATE 等)已在扰动响应预测方面取得进展。
  2. 现有痛点: 现有方法仅基于控制细胞状态和扰动标识进行预测,不利用相关扰动的已知响应信息。这在预测新细胞类型中的扰动响应时尤为致命——模型缺乏该细胞类型对相关扰动的任何先验知识。
  3. 核心矛盾: RAG 在 NLP 中成功依赖现成的文本检索器和定义良好的相似度指标,但在细胞生物学中:(1) 不存在预训练的扰动检索器,(2) 基因间无公认的相似度度量,(3) 生成器需要输出高维细胞分布而非文本。朴素地套用 RAG 会适得其反。
  4. 切入角度: 功能相似的基因扰动应当诱发相似的细胞响应,因此用相关扰动的已观察响应来增强生成器上下文,应能改善预测。但关键是检索目标本身需要学习——什么上下文对当前细胞类型有帮助并非先验已知。
  5. 核心 idea: 设计两阶段可微检索:第一阶段用 GenePT embedding 的余弦相似度快速缩小候选池(从约 2009 个扰动到 K 个);第二阶段用 Gumbel-Softmax 进行条件化于细胞状态和扰动的可微离散选择,实现端到端优化检索与生成的联合训练。

方法详解

整体框架

PT-RAG 建立在 STATE 框架之上。输入为一组控制细胞的基因表达谱 \(\{x_i^{ctrl}\}_{i=1}^N \subseteq \mathbb{R}^G\)(G=2000 个高变异基因)和扰动标识 \(p^{pert}\),目标是预测扰动后的细胞群分布 \(\{\hat{x}_i^{pert}\}_{i=1}^N\)。框架包含三个核心模块:

  • 冻结的 Cell Encoder:预训练模型(SE-600M),将基因表达谱映射到 128 维隐表示 \(h^{ctrl}\)
  • 可训练的 Perturbation Encoder:单层 MLP,将 GenePT embedding(1536 维)映射到 128 维
  • 两阶段检索模块 + Transformer Generator(Llama backbone,序列长度=64 细胞,batch=64)

论文对比了三种架构:(1) Generation baseline(无检索),(2) Vanilla RAG(非可微检索 + Cross-Attention),(3) PT-RAG(两阶段可微检索)。

关键设计

  1. GenePT 扰动表示: 不同于以往工作使用 one-hot 编码扰动,PT-RAG 利用 GenePT(基于 GPT-3.5 编码的 NCBI 基因描述生成的 embedding)来表示基因。这使得语义相似的基因在 embedding 空间中接近,为检索提供了有意义的相似度计算基础。构建扰动数据库 \(\mathcal{D} = \{h_p^{gpt}; \forall p \in \mathcal{P}\}\),约 2009 个扰动。

  2. 第一阶段:语义检索(非可微): 通过 GenePT embedding 的余弦相似度,从 ~2009 个扰动中检索 Top-K(K=32)个语义最相近的候选扰动。这一步高效剪枝了搜索空间,但不区分细胞类型(与 Vanilla RAG 相同)。

  3. 第二阶段:可微检索(核心创新): 对 K 个候选扰动,构造三元组 \(c_k = [h^{ctrl}; h_{pert}; h_k^{cxt}]\)(384 维),同时编码了细胞状态、目标扰动和候选上下文的关系。通过评分 MLP + LayerNorm 生成二元 logits(include/exclude),再应用 Straight-Through Gumbel-Softmax 估计器进行硬二值决策(前向传播为离散 0/1,反向传播通过软概率传梯度)。

这一设计的关键优势: - 细胞类型感知:选择决策同时依赖于 \(h^{ctrl}\)(细胞状态),同一基因在不同细胞类型中会检索到不同扰动 - 端到端可微:检索目标与生成目标联合优化 - 离散选择\(w_k \in \{0,1\}\),真正地选择/排除候选扰动

  1. 上下文聚合与生成: 被选中的三元组经 Projection MLP 投影后加权求和 \(z = \sum_{k=1}^K w_k \cdot h_k'\),送入 Transformer Generator 生成扰动后的细胞群。

损失函数 / 训练策略

\[\mathcal{L} = \mathcal{L}_{dist} + \lambda_{sparse} \mathcal{L}_{sparse}\]
  • 分布损失 \(\mathcal{L}_{dist}\):Energy Distance,衡量预测与真实扰动细胞群在分布层面的差异
  • 稀疏损失 \(\mathcal{L}_{sparse} = \frac{1}{K}\sum_{k=1}^K w_k\):L1 惩罚,防止模式坍塌(选择所有候选扰动),\(\lambda_{sparse} = 0.1\)
  • 训练配置:Adam 优化器,lr=\(10^{-3}\),weight decay=0.0005,最大 50000 steps,每 2000 steps 验证
  • Gumbel-Softmax 温度 \(\tau = 0.5\)

消融实验表明:\(\lambda_{sparse}=0\) 时模型平均检索 31.9/32 个扰动,性能严重退化(Pearson 仅 0.134);\(\lambda_{sparse} \in \{0.01, 0.10, 1.00\}\) 时性能稳定且优异,说明稀疏约束必要但不敏感。

实验关键数据

数据集与评估协议

  • Replogle-Nadig 数据集:2009 个单基因扰动,四种细胞类型(K562 慢性粒细胞白血病、Jurkat T 细胞、RPE1 视网膜色素上皮、HepG2 肝癌)
  • 跨细胞类型泛化:对每个目标细胞类型,用其余三种训练,提供目标类型 30% 扰动作为 few-shot,其余 70% 用于验证和测试
  • 共 1635 个测试扰动(HepG2 375, RPE1 416, Jurkat 443, K562 401)

主实验

指标 STATE STATE+GenePT Vanilla RAG PT-RAG
Pearson DEG ↑ 0.624 0.631 0.396 0.633
Spearman DEG ↑ 0.403 0.411 0.307 0.412
MSE ↓ 0.211 0.210 0.316 0.210
RMSE ↓ 0.458 0.458 0.562 0.457
MAE ↓ 0.298 0.296 0.429 0.295
MSE_PCA50 ↓ 8.43 8.42 12.64 8.39
W₁ ↓ 35.70 35.53 48.48 35.41
W₂ ↓ 646.1 638.7 1189.5 633.7
Energy ↓ 9.41 9.40 14.18 9.33

PT-RAG 在所有 9 个指标上取得最优(或并列最优),对 STATE 的改进在 Pearson、Spearman、MAE、W₁、W₂ 上均达到统计显著(FDR 校正 p<0.01)。Vanilla RAG 全面崩溃——这本身就是重要发现。

消融实验

消融设置 Pearson DEG ↑ Spearman DEG ↑ W₂ ↓ Energy ↓ 平均检索数
PT-RAG (λ=0, 无稀疏) 0.134 -0.025 651.7 11.97 31.95
PT-RAG (λ=0.01) 0.594 0.386 633.8 9.48 12.54
PT-RAG (λ=0.10) 0.604 0.401 631.8 9.48 6.61
PT-RAG (λ=1.00) 0.598 0.394 637.3 9.52 4.92
Vanilla RAG (K=2) 0.293 0.220 2
Vanilla RAG (K=32) 0.351 0.289 1198.5 14.25 32

消融实验清晰表明:(1) 无稀疏约束导致检索所有候选,性能崩溃;(2) 非零稀疏约束下性能稳健不敏感;(3) K=16 与 K=32 对 PT-RAG 影响极小;(4) 增大 Vanilla RAG 的 K 虽有微小改善,但远不及 PT-RAG。

关键发现

  1. Vanilla RAG 的戏剧性失败是核心发现:朴素 RAG 不仅没帮助,反而严重伤害性能(W₂: 1189.5 vs STATE 的 646.1),证明在缺乏定义良好的相似度指标的领域中,非可微、细胞类型无关的检索会引入噪声。
  2. 细胞类型特异性检索模式:对 33 个跨所有细胞类型共有的测试基因,PT-RAG 在不同细胞类型间的 Top-10 检索重叠仅约 19%(Jaccard 相似度 0.185-0.196),证实模型确实学会了细胞类型感知的检索。
  3. 生物学可解释性:以 WARS 基因为例,PT-RAG 在所有细胞类型中都检索到氨基酰-tRNA 合成酶家族成员(保持功能一致性),但具体选择的成员因细胞类型而异,符合不同细胞类型对 tRNA 充电通路的差异化依赖。
  4. 计算开销可控:PT-RAG 约 21M 参数,FLOPs/batch 为 2.86B(约为 baseline 的 1.7×),在 A100 上每种目标细胞类型训练约 8-10 小时。

亮点与洞察

  • RAG 范式的非文本领域推广:首次证明 RAG 可以超越语言模型,应用于细胞响应生成这种检索对象和生成输出都非文本的场景。关键洞察是:在没有预定义相似度指标的领域,检索目标本身必须被学习
  • 负面结果同样重要:Vanilla RAG 的失败不是次要发现,而是论文的核心贡献之一——它为"何时需要可微检索"提供了实证答案。
  • 三元组评分设计精巧:将 \([h^{ctrl}; h_{pert}; h_k^{cxt}]\) 拼接后评分,使选择同时依赖细胞状态、目标扰动和候选上下文,这种设计使同一查询基因在不同细胞背景下获得不同的检索结果。
  • Gumbel-Softmax 的实际应用范例:展示了如何用 Straight-Through Gumbel-Softmax 实现离散检索决策的端到端训练,对其他需要可微离散选择的场景有参考价值。

局限性 / 可改进方向

  1. 改进幅度相对有限:相比 STATE+GenePT,PT-RAG 的改进主要集中在 Wasserstein 距离上,基因级相关性和重构精度的提升较为温和。在 K562 细胞系上 STATE 甚至优于 PT-RAG。
  2. 仅限单基因扰动:未扩展到组合扰动(多基因同时敲除/激活),这在真实药物开发场景中更常见。
  3. 计算开销 1.7×:虽然绝对量可控,但在大规模部署时仍需考虑。
  4. 检索池受限:仅从训练集中约 2009 个扰动中检索,未探索更大规模的外部知识库。
  5. 未来方向:GraphRAG(利用基因调控网络结构)、多模态检索(序列 + 结构 + 功能注释结合)、化合物/CRISPR 激活/干扰的扩展。

相关工作与启发

  • STATE (Adduri et al., 2025):PT-RAG 的 backbone,将细胞群建模为序列用 Transformer 处理,使用分布损失训练。PT-RAG 在其基础上加入了检索增强。
  • Differentiable RAG (Zamani & Bendersky 2024; Gao et al. 2025):Stochastic RAG 和 D-RAG 在文本域展示了端到端优化检索的价值,PT-RAG 首次将此思路迁移到非文本域。
  • GenePT (Chen & Zou, 2023):基于 GPT-3.5 编码的基因功能描述提供基因 embedding,为 PT-RAG 的第一阶段检索提供语义基础。
  • 启发:这种"两阶段检索"的设计模式——先用低成本的固定检索缩小候选集,再用可微机制精选——可推广到其他缺乏现成相似度指标的领域(如分子设计、材料科学等)。

评分

⭐⭐⭐⭐ (4/5)

优点:问题提出深刻(RAG 在非文本域的适用性与挑战),方法设计合理(两阶段可微检索),实验全面(统计检验严格,消融丰富),负面结果(Vanilla RAG 失败)本身极有说明力,生物学可解释性分析(Jaccard 重叠、细胞类型特异检索模式)增加了说服力。

不足:相对 STATE+GenePT 的改进幅度有限,在 K562 上甚至略逊;仅限单基因扰动场景;计算开销增加但整体性能提升温和。整体是一篇扎实的方法论工作,在"RAG 超越语言模型"这一方向上迈出了有意义的一步。