跳转至

How Do Transformers Learn Implicit Reasoning?

会议: NeurIPS 2025
arXiv: 2505.23653
代码: GitHub
领域: 可解释性 / 多跳推理机制
关键词: implicit reasoning, multi-hop, grokking, cosine clustering, semantic patching, Transformer internals

一句话总结

在精细控制的符号环境中从零训练 Transformer,发现多跳隐式推理经历"记忆→分布内泛化→跨分布泛化"三个阶段,核心机制不是中间实体的可解码性,而是其在余弦空间中的聚类一致性——同一中间实体的表示在不同查询中形成紧密聚类时,推理能力才涌现。

研究背景与动机

领域现状:LLM 不仅能通过 Chain-of-Thought 显式推理,还能进行隐式推理(不显式化中间步骤直接给出答案)。但隐式推理的内部机制——模型是真正进行了逐步推理还是仅靠记忆?——仍不清楚。

现有痛点:(1) 基于预训练 LLM 的分析无法精确控制训练数据,难以区分"真正推理"与"记忆/捷径";(2) 现有符号数据集(wang2024grokking 等)虽从零训练,但缺乏查询级精细控制和行为粒度,无法隔离泛化的具体条件;(3) 主流分析工具(logit lens 解码、因果修补)存在限制——logit lens 揭示相关性但非因果关系,因果修补测试影响但不测试语义内容。

核心矛盾:隐式推理的正确答案可能来自两种截然不同的认知过程(逐步推理 vs 记忆),现有方法无法可靠区分。

本文目标 在可控环境中揭示 Transformer 如何获得和执行隐式多跳推理——包括发展轨迹、必要条件和内部机制。

切入角度:构建扩展的符号环境支持查询级消融,引入两个新诊断工具(跨查询语义修补 + 余弦表示透镜),将行为观察与内部机制连接起来。

核心 idea:隐式推理的关键不是中间实体能否被"显式解码",而是其表示是否在余弦空间中形成一致性聚类——聚类结构的涌现与推理能力的涌现精确对应。

方法详解

整体框架

构建符号环境(2000 实体 × 200 关系 → 40000 原子三元组 + 2-hop 组合查询)→ 从零训练 GPT-2 → 行为分析(三阶段发展轨迹 + 消融研究)→ 机制分析(跨查询语义修补定位 + logit lens 可解码性 + 余弦聚类透镜)→ 闭环解释(用机制解释行为现象)

关键设计

  1. 精细控制的符号推理环境

    • 原子三元组 \((e_1, r_1) \to e_2\) 分为 ID(分布内)和 OOD(分布外),共享实体和关系
    • 2-hop 查询 \((e_1, r_1, r_2) \to e_3\):模型需隐式推理中间实体 \(e_2\)
    • 训练集 Train-II(两跳均为 ID),测试集 Test-II/Test-OI/Test-IO/Test-OO
    • 支持多种消融配置:移除特定三元组、限制组合角色、移除子集
    • 数据规模:2000 实体、200 关系、38000 ID 三元组、2000 OOD 三元组、273600 Train-II 查询
    • 设计动机:查询级控制允许精确隔离"哪些训练信号对解决特定查询是必要的"
  2. 跨查询语义修补(Cross-Query Semantic Patching)

    • 比标准因果修补更强:不仅测试隐藏状态的因果影响,还测试其语义内容
    • 方法:从源查询 \((e_1, r_1, r_2)\) 提取候选位置的隐藏向量,插入结构相似的目标查询 \((e_5, r_6, r_7)\) 的同一位置
    • 成功标准:如果修补后模型预测从 \(r_7(r_6(e_5))\) 变为 \(r_7(r_1(e_1))\),说明插入的表示携带了可迁移的中间实体语义信息
    • 结果:有效修补主要发生在 \(r_1\) token 位置的中间层(8 层 GPT-2 的第 5 层)
    • 设计动机:超越"相关性"(线性探测)和"表面因果性"(标准修补),测试表示的语义可迁移性
  3. 余弦聚类透镜(Cosine-Based Representational Lens)

    • 核心洞察:不问"能否解码这个隐藏状态",而问"这个表示在不同上下文中如何组织?"
    • 方法:对于共享同一中间实体 \(e_2\) 的所有查询,提取 \(\mathbf{h}_{r_1}^5\),计算两两余弦距离,用 MDS 投影可视化
    • 定义两个量化指标:
      • ID Cohesion Score:ID 派生表示与其质心的平均余弦相似度(分布内一致性)
      • OOD Alignment Score:OOD 派生表示与 ID 质心的平均余弦相似度(跨分布对齐)
    • 设计动机:logit lens 解码成功率与推理能力涌现不对应("可解码 ≠ 可推理"),需要新的表示分析视角

三阶段发展轨迹

  • Phase I 记忆:快速拟合训练数据(原子事实 + 2-hop 组合),但对未见查询不泛化
  • Phase II 分布内泛化:对未见 ID-ID 组合(Test-II)开始泛化,类似 grokking 现象
  • Phase III 跨分布泛化:逐步将 OOD 三元组纳入第一跳推理(Test-OI),但当第二跳为 OOD 时始终失败

闭环机制解释

行为现象可被余弦聚类机制完整解释。ID三元组训练通过自回归因果掩码共享 \(r_1\) 位置的前缀结构——\((e_1, r_1)\) 在原子查询和两跳查询中产生完全相同的隐藏状态——约束中间实体表示落入支持解码的子空间,加速了聚类收敛。OOD表示被频繁的OOD原子三元组训练逐渐拉入ID聚类,使第一跳看似泛化,但实际上是ID锚点效应的副产品而非真正的跨分布推理泛化。第二跳由于缺乏因果掩码位共享带来的锚点效应,必须依赖直接的查询级监督。3跳推理实验进一步验证:仅Test-OII(OOD→ID→ID)成功,所有OOD出现在后续跳的配置均失败。实验还发现模型最初尝试显式解码中间实体但很快放弃这一策略,转而采用非显式但几何一致的内部表示方案。

核心发现

  • ID 三元组非必要但加速泛化:仅用 Train-II(无原子三元组)仍可泛化到 Test-II,但加入 ID 三元组显著加速(因为共享 \(r_1\) 位置的隐藏状态,约束表示空间)
  • 第二跳泛化需查询级匹配:必须在训练中遇到特定的第二跳组合才能泛化;且第二跳暴露频率越高,对应查询被正确回答的时间越早
  • 可解码性 ≠ 可推理性:ID 派生的中间实体表示在 Phase I 就有 97%+ 的解码成功率,但推理能力到 Phase II 才涌现;ID 和 OOD 表示解码成功率无显著差异,但推理表现差异巨大
  • 第一跳OOD泛化是伪泛化:本质上是ID锚点效应的副产品,移除ID三元组后OOD推理完全失败

实验关键数据

中间实体解码成功率 vs 推理阶段

来源 Immediate Probing Full-run Probing
Phase I Phase II Phase III Phase I Phase II Phase III
ID-derived 92.1% 98.8% 99.9% 97.1% 99.9% 99.9%
OOD-derived 67.7% 81.3% 99.8% 83.7% 98.6% 99.7%

行为消融

配置 Test-II 泛化 Test-OI 泛化 备注
Base(全部数据) ✅ Phase II ✅ Phase III 完整三阶段
仅 Train-II(无原子三元组) ✅(延迟) ID 三元组加速但非必要
移除特定第二跳原子三元组 ✅(其他查询) - 对应查询始终失败
ID/OOD=0.3/0.7 ID 主导地位对 Phase III 关键
ID/OOD=0.8/0.2 充足 ID 暴露才能驱动 OOD 对齐

余弦聚类与推理能力的对应

  • Phase I:ID 和 OOD 表示在余弦空间中散乱分布
  • Phase II:ID 表示形成紧密聚类(ID Cohesion Score 上升),Test-II 泛化同步涌现
  • Phase III:OOD 表示开始向 ID 聚类靠拢(OOD Alignment Score 上升),Test-OI 泛化同步涌现

亮点与洞察

  • "可解码 ≠ 可推理"是核心洞察——推翻了大量先前工作的隐含假设(即如果 logit lens 能解码中间实体,模型就在做推理)
  • 跨查询语义修补比标准因果修补更强:测试的是表示的语义可迁移性而非仅因果影响
  • 余弦聚类透镜提供了推理能力涌现的几何解释:中间实体需在表示空间中形成"抽象"(不同上下文的相同实体映射到相近向量)才能被复用
  • 第一跳 OOD 泛化的"虚假"本质——不是真正的泛化,而是 ID 监督驱动的表示对齐的附带效应——这一发现对理解 LLM 的泛化能力边界有深远意义
  • 第二跳泛化的硬性要求(查询级匹配)解释了为何单跳知识不能自动迁移到多跳推理

局限与展望

  • 实验在受控符号环境中进行(2000 实体 × 200 关系),与真实 LLM 的复杂知识库存在显著差距——LLM 的知识交互机制可能与符号环境中观察到的不同,发现应被视为"初步指导"而非完整解释
  • 主要使用 GPT-2(8 层),虽在 Qwen2.5-1.5B 上验证了基本一致性但观察到 Phase III 不稳定和 ID Cohesion 与 Test-II 准确率的脱钩——更大模型中聚类可能只是推理的必要条件而非充分条件
  • 余弦聚类现象的因果关系尚未严格建立——聚类与推理能力高度相关但聚类是否是推理的充分条件未证明,需要干预实验来建立因果方向
  • 仅考虑 2-hop(辅以 3-hop 验证),更长推理链(5-hop+)上的发展轨迹和机制是否保持一致有待探索
  • 未连接到实际 LLM 的知识编辑或 CoT 推理优化等实际应用场景,发现的实践指导价值有待挖掘

相关工作与启发

  • vs wang2024grokking:本文在其符号数据集基础上大幅扩展——增加 OOD 分区、查询级消融、更精细的行为分辨率。发现了 grokking 之外的 Phase III 跨分布泛化,且证明原子三元组并非泛化的必要条件
  • vs hopping-too-late / yang2024large2:这些工作基于预训练 LLM 研究隐式推理失败,但无法控制训练数据。本文从零训练提供了更干净的因果分析,且用查询级匹配需求解释了它们观察到的多跳推理失败现象
  • vs logit lens 系列工作(sakarvadia2024towards, li2024understanding, zhang2024locate):这些工作假设"中间实体可解码=模型在做推理",本文用三个阶段的解码-推理脱钩现象推翻了这一假设,提供了余弦聚类透镜作为替代性解释框架
  • vs Balesni et al. (2024):他们发现同一段落中的知识更容易被组合推理,本文从查询级匹配角度给出了更精确的机制解释——关键不在于知识的物理邻近性,而在于组合结构的直接暴露
  • 对 Latent CoT 的启发:结果表明隐式推理不需要中间步骤在 token 空间可解码,支持了在潜在空间进行推理(如 continuous thoughts)的研究方向
  • 方法论启发:跨查询语义修补可推广到其他需要测试表示语义内容(而非仅因果影响)的场景,例如知识编辑中测试编辑后的表示是否正确传播到下游推理

评分

  • 新颖性: ⭐⭐⭐⭐⭐ "可解码≠可推理"的洞察和余弦聚类透镜都是重要的原创贡献
  • 实验充分度: ⭐⭐⭐⭐ 行为消融+机制分析+闭环解释的三层结构非常完整
  • 写作质量: ⭐⭐⭐⭐⭐ 论文逻辑链极其清晰:行为观察→机制分析→闭环解释,环环相扣
  • 价值: ⭐⭐⭐⭐ 对理解 LLM 推理机制提供了深刻洞察,但从符号环境到真实 LLM 的迁移仍需验证

相关论文