跳转至

SoftCoT: Soft Chain-of-Thought for Efficient Reasoning with LLMs

会议: ACL 2025
arXiv: 2502.12134
代码: https://github.com/xuyige/SoftCoT
领域: LLM效率
关键词: 连续空间推理, 软思维token, 辅助小模型, 投影模块, 灾难性遗忘缓解

一句话总结

提出 SoftCoT,用一个冻结的小型辅助模型(如 LLaMA-3.2-1B)生成实例特定的"软思维 token"(连续隐状态),通过可训练的投影模块映射到主 LLM 的表示空间作为推理前缀,实现参数高效的连续空间 CoT 推理,避免了全模型微调导致的灾难性遗忘问题。

研究背景与动机

  1. 领域现状:Chain-of-Thought (CoT) 推理通过生成中间推理步骤提升 LLM 在复杂任务上的表现。传统 CoT 在离散 token 空间生成推理链,受限于词表空间。近期有 Coconut、CCoT 等方法探索连续空间推理,用潜在表示替代离散 token 序列。
  2. 现有痛点
  3. Coconut/CCoT 需要全模型微调(语言建模目标),在 GPT-2 上有效,但在 LLaMA-3.1-8B-Instruct 等已对齐的强模型上会导致灾难性遗忘——LoRA 微调后的表现甚至低于 zero-shot CoT
  4. 离散 CoT 的 token 空间受限,可能不是最优的推理表示
  5. 多路径采样(self-consistency、Tree-of-Thought)计算开销大
  6. 核心矛盾:连续空间推理有表达力优势,但现有方法需要修改 LLM 参数,导致强大的指令调优模型丧失已学到的推理能力。如何在不修改 LLM 的前提下引入连续空间推理?
  7. 本文要解决什么?
  8. 如何在冻结主 LLM 的条件下实现连续空间 CoT 推理
  9. 如何弥合辅助模型与主 LLM 之间的表示空间差异
  10. 如何通过参数高效训练获得优于 zero-shot CoT 的性能
  11. 切入角度:受 prompt tuning 和 speculative decoding 的启发——用一个小型冻结辅助模型生成与实例相关的软提示(连续思维 token),代替 CoT 中的离散推理前缀。主 LLM 完全冻结,只训练一个投影模块。
  12. 核心 idea 一句话:不修改 LLM,而是用辅助小模型的连续隐状态作为实例自适应的"软思维"前缀,通过投影层送入冻结的主 LLM 来增强推理。

方法详解

整体框架

给定问题 \(\mathcal{Q}\)辅助模型(冻结的 LLaMA-3.2-1B)处理 [指令, 问题, N个UNK token],提取最后 N 个位置的最后一层隐状态作为软思维 → 投影模块(可训练线性层)将软思维从辅助模型维度映射到主 LLM 维度 → 主 LLM(冻结的 LLaMA-3.1-8B)接收 [指令, 问题, 软思维],自回归生成推理步骤和答案

关键设计

  1. 软思维 token 生成:
  2. 做什么:用辅助小模型为每个问题生成实例特定的连续推理前缀
  3. 核心思路:
    • 输入构造:\(\mathbf{x}_{\text{assist}} = \text{concat}[\mathcal{I}_{\text{assist}}, \mathcal{Q}, \text{[UNK]}_{1:N}]\)
    • 辅助模型前向传播后,提取 N 个 [UNK] 位置的最后层隐状态作为 \(\mathbf{t}_{\text{assist}} \in \mathbb{R}^{N \times d_{\text{assist}}}\)
    • 在连续空间操作,避免了自回归解码的信息损失和梯度截断
  4. 设计动机:[UNK] token 作为"占位符",不携带特定语义,迫使模型将问题的理解压缩到这些位置的隐状态中。辅助模型冻结→不需要额外训练成本

  5. 投影模块:

  6. 做什么:弥合辅助模型和主 LLM 之间的表示空间&维度差距
  7. 核心思路:\(\mathcal{T}_{\text{soft}} = \text{Linear}_\theta(\mathbf{t}_{\text{assist}})\),将 \(\mathbb{R}^{d_{\text{assist}}}\) 映射到 \(\mathbb{R}^{d_{\text{LLM}}}\)
  8. 这是唯一的可训练组件——一个线性层,参数量极小
  9. 设计动机:类似 LLaVA 中视觉编码器到 LLM 的投影,用最小的训练开销桥接两个模型空间

  10. 冻结主 LLM 的推理:

  11. 做什么:利用软思维增强主 LLM 的 CoT 推理
  12. 输入:\(\mathbf{x}_{\text{LLM}} = \text{concat}[\mathcal{I}_{\text{LLM}}, \mathcal{Q}, \mathcal{T}_{\text{soft}}]\)
  13. 主 LLM 完全冻结,根据指令+问题+软思维自回归生成推理链和答案
  14. 训练时:用 NLL 损失监督推理步骤 \(\mathcal{R}\) 和答案 \(\mathcal{A}\),梯度仅回传到投影层

损失函数 / 训练策略

  • 标准语言建模目标(NLL loss),mask 掉指令和问题部分,仅在推理链+答案上计算损失
  • 梯度只更新投影模块的参数
  • 参数效率:主 LLM 和辅助模型均冻结

实验关键数据

主实验(LLaMA-3.1-8B-Instruct 为主 LLM)

方法 GSM8K ASDiv-Aug AQuA StrategyQA Date Und. Avg
Zero-Shot CoT 79.61 86.78 54.65 65.63 54.40 68.21
Zero-Shot CoT-UNK 79.95 86.90 55.28 66.16 54.16 68.49
Zero-Shot Assist-CoT 80.76 86.96 55.83 66.55 58.24 69.67
LoRA Fine-Tuning 75.66 86.67 52.36 - - -
Coconut (LoRA) 76.12 86.80 53.15 - - -
SoftCoT 81.03 87.19 56.30 69.04 59.04 70.52

关键对比:LoRA 微调和 Coconut 在 GSM8K 上都低于 Zero-Shot CoT(75.66/76.12 vs 79.61),证实了灾难性遗忘。SoftCoT 是唯一在所有任务上一致超越 Zero-Shot CoT 的方法。

消融实验

组件 GSM8K ASDiv-Aug 说明
SoftCoT (完整) 81.03 87.19 完整模型
w/o 辅助模型(随机UNK) 79.95 86.90 退化为CoT-UNK,-1.08
w/o 连续空间(硬token) 80.76 86.96 退化为Assist-CoT,-0.27
w/o 投影模块 维度不匹配无法运行

关键发现

  • 灾难性遗忘是真实问题:Coconut 在 GPT-2 上有效,但在 LLaMA-3.1-8B-Instruct 上使用 LoRA 微调反而比 zero-shot 差 3.5 个点
  • [UNK] token 本身就有微弱正面效果("pause token" 效应),增加模型计算容量,减少方差
  • 软思维优于硬思维:SoftCoT (70.52) > Assist-CoT (69.67),连续表示比离散 token 编码更丰富
  • 仅训练一个投影层(参数量极小)就能持续获得收益
  • 在 Qwen2.5-7B-Instruct 上也验证了一致的改进效果

亮点与洞察

  • "不改模型,改输入"的思路很实用:与 prompt tuning 的理念一脉相承,但用辅助模型生成实例自适应的软提示比固定的可学习 prompt 更强。这种设计保证了主 LLM 知识的完整性
  • 灾难性遗忘的系统验证是重要贡献:明确指出 Coconut 类方法不适用于已经很强的指令调优模型,为连续空间推理领域矫正了方向
  • 类似 speculative decoding 的架构设计:小模型"推测"推理方向,大模型"执行",是一种优雅的模型协作范式。可迁移到其他需要辅助推理的场景

局限性 / 可改进方向

  • 软思维的长度 N 是固定的超参数,未根据问题难度自适应调整
  • 投影模块仅用线性层,可能限制了表示映射的能力(MLP 或 cross-attention 可能更好)
  • 辅助模型完全冻结,未探索联合微调辅助模型(可能进一步提升但增加开销)
  • 实验仅在 7B-8B 规模验证,更大模型(70B+)的效果未知
  • 软思维的可解释性较差——无法像离散 CoT 那样阅读推理过程

相关工作与启发

  • vs Coconut:Coconut 需要全模型微调且在强模型上灾难性遗忘;SoftCoT 完全冻结主 LLM,一致有效
  • vs Prompt Tuning:传统 prompt tuning 的可学习 prompt 是静态的,SoftCoT 的软思维是实例自适应的
  • vs LoRA Fine-Tuning:LoRA 在 CoT 任务上表现差于 zero-shot;SoftCoT 通过外部模块避免了这个问题

评分

  • 新颖性: ⭐⭐⭐⭐ 辅助模型生成软思维的设计有创新性,解决了连续CoT在强模型上的灾难性遗忘
  • 实验充分度: ⭐⭐⭐⭐ 5个基准、3类推理、2个LLM、多个baseline对比和消融
  • 写作质量: ⭐⭐⭐⭐ 问题动机清晰,方法描述规范,图表简洁
  • 价值: ⭐⭐⭐⭐ 提供了一种实用的轻量级推理增强方案,参数效率极高