Retrieval-Augmented Generation for Predicting Cellular Responses to Gene Perturbation¶
日期: 2026-03-07
arXiv: 2603.07233
代码: GitHub
领域: LLM/NLP
关键词: Retrieval-Augmented Generation, 单细胞扰动预测, 可微检索, Gumbel-Softmax, 细胞类型感知
一句话总结¶
提出 PT-RAG,首个将 RAG 范式引入单细胞基因扰动响应预测的框架,通过两阶段可微检索(语义检索 + Gumbel-Softmax 细胞类型感知选择)来为生成器提供相关扰动上下文,显著优于无检索和朴素 RAG 基线。
研究背景与动机¶
- 领域现状: 高通量 Perturb-seq 技术可以测量数千种基因扰动的单细胞转录组响应,但扰动与细胞类型的组合爆炸使得全面实验表征不可行,因此需要计算方法进行 in silico 预测。现有深度学习方法(scGen、CPA、GEARS、CellOT、STATE 等)已在扰动响应预测方面取得进展。
- 现有痛点: 现有方法仅基于控制细胞状态和扰动标识进行预测,不利用相关扰动的已知响应信息。这在预测新细胞类型中的扰动响应时尤为致命——模型缺乏该细胞类型对相关扰动的任何先验知识。
- 核心矛盾: RAG 在 NLP 中成功依赖现成的文本检索器和定义良好的相似度指标,但在细胞生物学中:(1) 不存在预训练的扰动检索器,(2) 基因间无公认的相似度度量,(3) 生成器需要输出高维细胞分布而非文本。朴素地套用 RAG 会适得其反。
- 切入角度: 功能相似的基因扰动应当诱发相似的细胞响应,因此用相关扰动的已观察响应来增强生成器上下文,应能改善预测。但关键是检索目标本身需要学习——什么上下文对当前细胞类型有帮助并非先验已知。
- 核心 idea: 设计两阶段可微检索:第一阶段用 GenePT embedding 的余弦相似度快速缩小候选池(从约 2009 个扰动到 K 个);第二阶段用 Gumbel-Softmax 进行条件化于细胞状态和扰动的可微离散选择,实现端到端优化检索与生成的联合训练。
方法详解¶
整体框架¶
PT-RAG 建立在 STATE 框架之上。输入为一组控制细胞的基因表达谱 \(\{x_i^{ctrl}\}_{i=1}^N \subseteq \mathbb{R}^G\)(G=2000 个高变异基因)和扰动标识 \(p^{pert}\),目标是预测扰动后的细胞群分布 \(\{\hat{x}_i^{pert}\}_{i=1}^N\)。框架包含三个核心模块:
- 冻结的 Cell Encoder:预训练模型(SE-600M),将基因表达谱映射到 128 维隐表示 \(h^{ctrl}\)
- 可训练的 Perturbation Encoder:单层 MLP,将 GenePT embedding(1536 维)映射到 128 维
- 两阶段检索模块 + Transformer Generator(Llama backbone,序列长度=64 细胞,batch=64)
论文对比了三种架构:(1) Generation baseline(无检索),(2) Vanilla RAG(非可微检索 + Cross-Attention),(3) PT-RAG(两阶段可微检索)。
关键设计¶
-
GenePT 扰动表示: 不同于以往工作使用 one-hot 编码扰动,PT-RAG 利用 GenePT(基于 GPT-3.5 编码的 NCBI 基因描述生成的 embedding)来表示基因。这使得语义相似的基因在 embedding 空间中接近,为检索提供了有意义的相似度计算基础。构建扰动数据库 \(\mathcal{D} = \{h_p^{gpt}; \forall p \in \mathcal{P}\}\),约 2009 个扰动。
-
第一阶段:语义检索(非可微): 通过 GenePT embedding 的余弦相似度,从 ~2009 个扰动中检索 Top-K(K=32)个语义最相近的候选扰动。这一步高效剪枝了搜索空间,但不区分细胞类型(与 Vanilla RAG 相同)。
-
第二阶段:可微检索(核心创新): 对 K 个候选扰动,构造三元组 \(c_k = [h^{ctrl}; h_{pert}; h_k^{cxt}]\)(384 维),同时编码了细胞状态、目标扰动和候选上下文的关系。通过评分 MLP + LayerNorm 生成二元 logits(include/exclude),再应用 Straight-Through Gumbel-Softmax 估计器进行硬二值决策(前向传播为离散 0/1,反向传播通过软概率传梯度)。
这一设计的关键优势: - 细胞类型感知:选择决策同时依赖于 \(h^{ctrl}\)(细胞状态),同一基因在不同细胞类型中会检索到不同扰动 - 端到端可微:检索目标与生成目标联合优化 - 离散选择:\(w_k \in \{0,1\}\),真正地选择/排除候选扰动
- 上下文聚合与生成: 被选中的三元组经 Projection MLP 投影后加权求和 \(z = \sum_{k=1}^K w_k \cdot h_k'\),送入 Transformer Generator 生成扰动后的细胞群。
损失函数 / 训练策略¶
- 分布损失 \(\mathcal{L}_{dist}\):Energy Distance,衡量预测与真实扰动细胞群在分布层面的差异
- 稀疏损失 \(\mathcal{L}_{sparse} = \frac{1}{K}\sum_{k=1}^K w_k\):L1 惩罚,防止模式坍塌(选择所有候选扰动),\(\lambda_{sparse} = 0.1\)
- 训练配置:Adam 优化器,lr=\(10^{-3}\),weight decay=0.0005,最大 50000 steps,每 2000 steps 验证
- Gumbel-Softmax 温度 \(\tau = 0.5\)
消融实验表明:\(\lambda_{sparse}=0\) 时模型平均检索 31.9/32 个扰动,性能严重退化(Pearson 仅 0.134);\(\lambda_{sparse} \in \{0.01, 0.10, 1.00\}\) 时性能稳定且优异,说明稀疏约束必要但不敏感。
实验关键数据¶
数据集与评估协议¶
- Replogle-Nadig 数据集:2009 个单基因扰动,四种细胞类型(K562 慢性粒细胞白血病、Jurkat T 细胞、RPE1 视网膜色素上皮、HepG2 肝癌)
- 跨细胞类型泛化:对每个目标细胞类型,用其余三种训练,提供目标类型 30% 扰动作为 few-shot,其余 70% 用于验证和测试
- 共 1635 个测试扰动(HepG2 375, RPE1 416, Jurkat 443, K562 401)
主实验¶
| 指标 | STATE | STATE+GenePT | Vanilla RAG | PT-RAG |
|---|---|---|---|---|
| Pearson DEG ↑ | 0.624 | 0.631 | 0.396 | 0.633 |
| Spearman DEG ↑ | 0.403 | 0.411 | 0.307 | 0.412 |
| MSE ↓ | 0.211 | 0.210 | 0.316 | 0.210 |
| RMSE ↓ | 0.458 | 0.458 | 0.562 | 0.457 |
| MAE ↓ | 0.298 | 0.296 | 0.429 | 0.295 |
| MSE_PCA50 ↓ | 8.43 | 8.42 | 12.64 | 8.39 |
| W₁ ↓ | 35.70 | 35.53 | 48.48 | 35.41 |
| W₂ ↓ | 646.1 | 638.7 | 1189.5 | 633.7 |
| Energy ↓ | 9.41 | 9.40 | 14.18 | 9.33 |
PT-RAG 在所有 9 个指标上取得最优(或并列最优),对 STATE 的改进在 Pearson、Spearman、MAE、W₁、W₂ 上均达到统计显著(FDR 校正 p<0.01)。Vanilla RAG 全面崩溃——这本身就是重要发现。
消融实验¶
| 消融设置 | Pearson DEG ↑ | Spearman DEG ↑ | W₂ ↓ | Energy ↓ | 平均检索数 |
|---|---|---|---|---|---|
| PT-RAG (λ=0, 无稀疏) | 0.134 | -0.025 | 651.7 | 11.97 | 31.95 |
| PT-RAG (λ=0.01) | 0.594 | 0.386 | 633.8 | 9.48 | 12.54 |
| PT-RAG (λ=0.10) | 0.604 | 0.401 | 631.8 | 9.48 | 6.61 |
| PT-RAG (λ=1.00) | 0.598 | 0.394 | 637.3 | 9.52 | 4.92 |
| Vanilla RAG (K=2) | 0.293 | 0.220 | — | — | 2 |
| Vanilla RAG (K=32) | 0.351 | 0.289 | 1198.5 | 14.25 | 32 |
消融实验清晰表明:(1) 无稀疏约束导致检索所有候选,性能崩溃;(2) 非零稀疏约束下性能稳健不敏感;(3) K=16 与 K=32 对 PT-RAG 影响极小;(4) 增大 Vanilla RAG 的 K 虽有微小改善,但远不及 PT-RAG。
关键发现¶
- Vanilla RAG 的戏剧性失败是核心发现:朴素 RAG 不仅没帮助,反而严重伤害性能(W₂: 1189.5 vs STATE 的 646.1),证明在缺乏定义良好的相似度指标的领域中,非可微、细胞类型无关的检索会引入噪声。
- 细胞类型特异性检索模式:对 33 个跨所有细胞类型共有的测试基因,PT-RAG 在不同细胞类型间的 Top-10 检索重叠仅约 19%(Jaccard 相似度 0.185-0.196),证实模型确实学会了细胞类型感知的检索。
- 生物学可解释性:以 WARS 基因为例,PT-RAG 在所有细胞类型中都检索到氨基酰-tRNA 合成酶家族成员(保持功能一致性),但具体选择的成员因细胞类型而异,符合不同细胞类型对 tRNA 充电通路的差异化依赖。
- 计算开销可控:PT-RAG 约 21M 参数,FLOPs/batch 为 2.86B(约为 baseline 的 1.7×),在 A100 上每种目标细胞类型训练约 8-10 小时。
亮点与洞察¶
- RAG 范式的非文本领域推广:首次证明 RAG 可以超越语言模型,应用于细胞响应生成这种检索对象和生成输出都非文本的场景。关键洞察是:在没有预定义相似度指标的领域,检索目标本身必须被学习。
- 负面结果同样重要:Vanilla RAG 的失败不是次要发现,而是论文的核心贡献之一——它为"何时需要可微检索"提供了实证答案。
- 三元组评分设计精巧:将 \([h^{ctrl}; h_{pert}; h_k^{cxt}]\) 拼接后评分,使选择同时依赖细胞状态、目标扰动和候选上下文,这种设计使同一查询基因在不同细胞背景下获得不同的检索结果。
- Gumbel-Softmax 的实际应用范例:展示了如何用 Straight-Through Gumbel-Softmax 实现离散检索决策的端到端训练,对其他需要可微离散选择的场景有参考价值。
局限性 / 可改进方向¶
- 改进幅度相对有限:相比 STATE+GenePT,PT-RAG 的改进主要集中在 Wasserstein 距离上,基因级相关性和重构精度的提升较为温和。在 K562 细胞系上 STATE 甚至优于 PT-RAG。
- 仅限单基因扰动:未扩展到组合扰动(多基因同时敲除/激活),这在真实药物开发场景中更常见。
- 计算开销 1.7×:虽然绝对量可控,但在大规模部署时仍需考虑。
- 检索池受限:仅从训练集中约 2009 个扰动中检索,未探索更大规模的外部知识库。
- 未来方向:GraphRAG(利用基因调控网络结构)、多模态检索(序列 + 结构 + 功能注释结合)、化合物/CRISPR 激活/干扰的扩展。
相关工作与启发¶
- STATE (Adduri et al., 2025):PT-RAG 的 backbone,将细胞群建模为序列用 Transformer 处理,使用分布损失训练。PT-RAG 在其基础上加入了检索增强。
- Differentiable RAG (Zamani & Bendersky 2024; Gao et al. 2025):Stochastic RAG 和 D-RAG 在文本域展示了端到端优化检索的价值,PT-RAG 首次将此思路迁移到非文本域。
- GenePT (Chen & Zou, 2023):基于 GPT-3.5 编码的基因功能描述提供基因 embedding,为 PT-RAG 的第一阶段检索提供语义基础。
- 启发:这种"两阶段检索"的设计模式——先用低成本的固定检索缩小候选集,再用可微机制精选——可推广到其他缺乏现成相似度指标的领域(如分子设计、材料科学等)。
评分¶
⭐⭐⭐⭐ (4/5)
优点:问题提出深刻(RAG 在非文本域的适用性与挑战),方法设计合理(两阶段可微检索),实验全面(统计检验严格,消融丰富),负面结果(Vanilla RAG 失败)本身极有说明力,生物学可解释性分析(Jaccard 重叠、细胞类型特异检索模式)增加了说服力。
不足:相对 STATE+GenePT 的改进幅度有限,在 K562 上甚至略逊;仅限单基因扰动场景;计算开销增加但整体性能提升温和。整体是一篇扎实的方法论工作,在"RAG 超越语言模型"这一方向上迈出了有意义的一步。