Enhancing Chain-of-Thought Reasoning with Critical Representation Fine-tuning¶
会议: ACL 2025
arXiv: 2507.10085
代码: 无
领域: LLM推理
关键词: 表征微调, 链式推理, 参数高效微调, 信息流分析, 关键表征
一句话总结¶
提出 CRFT 方法,通过信息流分析自动识别 Transformer 各层中对推理输出影响最大的"关键表征",并在低秩线性子空间中对这些表征进行有监督优化,在仅使用模型 0.016% 参数的情况下,将 LLaMA-2-7B 在 GSM8K 上的准确率提升了 18.2%。
研究背景与动机¶
领域现状:大语言模型在复杂推理任务中取得了显著进展,链式推理(Chain-of-Thought, CoT)是其中的核心技术,通过将推理过程分解为多步中间步骤来提升模型的推理能力。参数高效微调(PEFT)方法如 LoRA 已被广泛应用于适配 LLM 到下游任务。
现有痛点:表征微调(ReFT)作为一类新兴的 PEFT 方法,通过直接编辑模型表征空间来实现参数高效性。然而 ReFT 在复杂推理任务上表现不佳,因为它修改的是每层开头和末尾的固定位置表征,而这些固定位置的表征对输出的影响是不确定的——有些表征可能对推理毫无帮助,而真正关键的表征可能被忽略。
核心矛盾:ReFT 的"固定位置选择"策略与推理任务中"表征重要性是上下文相关"的事实之间存在根本矛盾。在复杂推理中,每一层都存在一些真正关键的表征——它们要么聚合了前续层的重要信息,要么调控着后续层的表征。但哪些表征是关键的,取决于具体的输入和推理链路,无法通过固定规则确定。
本文目标:(1) 自动识别 Transformer 每层中对推理输出影响最大的关键表征;(2) 以极少的可训练参数对这些表征进行针对性优化;(3) 在多种推理场景和模型上验证方法的通用性。
切入角度:作者观察到,对 LLaMA-2-7B 每层的随机表征添加微小高斯噪声(0.01)就会导致 GSM8K 准确率下降 1.4%,说明模型表现对特定表征极其敏感。进一步分析发现,信息流(attention score 和 saliency score)能有效揭示哪些表征承担关键角色。
核心 idea:利用信息流分析(注意力分数和显著性分数)动态识别每层的关键表征,然后在低秩子空间中学习自适应更新方向来优化这些表征,实现轻量级但精准的推理增强。
方法详解¶
整体框架¶
CRFT 的整体流程分为两个阶段:识别和优化。输入为 token 序列经过 embedding 后的表征序列,经过 \(L\) 层 Transformer 逐层计算。在每一层,CRFT 首先利用信息流分析(注意力分数或显著性分数)筛选出关键表征的集合 \(M(h)\),然后对集合中的表征施加低秩线性投影修正,冻结基础模型参数不变。最终模型利用修正后的最后一层表征生成推理答案。
关键设计¶
-
自引用过滤(Self-Referential Filtering):
- 功能:识别那些内部信息高度自聚合的关键表征
- 核心思路:如果表征 \(i\) 在第 \(l\) 层的信息主要回流到自身(即 \(\text{Info}(i,i)\) 值较大),说明该表征有效积累了重要信息。由于 softmax 归一化,自引用比例高意味着对外信息传播少,表征成为信息"汇聚点"。具体有两种度量方式:自引用注意力过滤(SAF),直接用注意力对角元素 \(A_i^{(l)}\) 衡量;自引用显著性过滤(SSF),用注意力与梯度的 Hadamard 积进行衡量,综合考虑信息流方向和对输出的敏感度
- 设计动机:在推理任务中,承载关键中间计算结果的表征往往会把信息聚合在自身位置,通过检测这种"信息自聚合"模式可以精准定位关键节点
-
多引用过滤(Multi-Referential Filtering):
- 功能:识别对下游多个表征产生大范围调控影响的关键表征
- 核心思路:如果表征 \(j\) 对多个其他表征有显著影响(即列平均值 \(\frac{1}{n+k-j+1}\sum_i \text{Info}(i,j)\) 超过阈值 \(\beta\)),则表征 \(j\) 是关键的"调控者"。同样有 MAF(注意力版本)和 MSF(显著性版本)两种实现。此外还可以将自引用和多引用两类关键表征取并集(Union 策略),避免遗漏
- 设计动机:与自引用互补——有些关键表征不是信息汇聚者,而是信息广播者,它们向后续层的大量位置传播信息,是推理链路中的"中枢"节点
-
低秩子空间优化:
- 功能:在冻结原始模型的前提下,学习对关键表征的自适应修正方向
- 核心思路:对每个被识别为关键的表征 \(h\),学习一个修正量 \(\Delta h = R^T(Wh + b - Rh)\),其中 \(R \in \mathbb{R}^{r \times d}\) 是行正交的投影矩阵,\(W\) 和 \(b\) 是可学习参数。修正被限制在 \(r\) 维低秩子空间中,保证参数量极小。非关键表征保持不变,只有 \(M(h)\) 内的表征被修改
- 设计动机:关键表征的修正方向因上下文而异,需要通过有监督学习来自适应确定。低秩约束既控制参数量(仅 0.016%),又起到正则化作用防止过拟合
损失函数 / 训练策略¶
训练使用标准的交叉熵损失,基于 CoT 推理步骤的监督学习框架。数学推理用 Math10K 数据集训练,常识推理用自建的 Commonsense60K 数据集(包含推理步骤)训练。所有实验使用 AdamW 优化器,默认 rank=8,每层 14 个干预表征,阈值 \(\alpha=\beta=0.05\),选择标准为"位置排序"。
实验关键数据¶
主实验¶
| 方法 | 可训练参数(%) | GSM8K准确率 |
|---|---|---|
| LLaMA-2-7B (无微调) | - | 14.6% |
| LoRA (r=64) | 0.826% | 38.5% |
| LoRA (r=8) | 0.103% | 36.7% |
| ReFT (p7+s7) | 0.031% | 29.0% |
| CRFT-Union(attn) | 0.016% | 32.8% |
| CRFT-MAF | 0.016% | 32.1% |
| CRFT-SAF | 0.016% | 30.4% |
跨模型跨数据集结果(算术推理 + 常识推理):
| 模型 | 方法 | AQuA | MAWPS | SVAMP | BoolQ | SocialIQA | WinoGrande | OpenBookQA |
|---|---|---|---|---|---|---|---|---|
| LLaMA-2-7B | ReFT | 21.7 | 80.7 | 52.2 | 50.7 | 61.2 | 51.7 | 58.6 |
| LLaMA-2-7B | CRFT-MAF | 27.6 | 81.1 | 53.4 | 60.5 | 52.8 | 68.4 | 66.4 |
| LLaMA-3-8B | ReFT | 46.9 | 87.0 | 74.2 | 62.1 | 60.2 | 56.0 | 66.0 |
| LLaMA-3-8B | CRFT-SSF | 50.0 | 86.6 | 78.1 | 66.6 | 74.7 | 62.0 | 77.0 |
| Mistral-7B | ReFT | 32.3 | 84.9 | 67.4 | 62.5 | 64.6 | 58.5 | 63.8 |
| Mistral-7B | CRFT-MSF | 41.3 | 87.4 | 66.9 | 65.0 | 71.8 | 62.3 | 72.8 |
消融实验¶
| 配置 | GSM8K准确率 | 说明 |
|---|---|---|
| 阈值 \(\alpha=1.0\) | 24.7% | 阈值过高,选中表征太少 |
| 阈值 \(\alpha=0.25\) | 30.0% | 中等阈值 |
| 阈值 \(\alpha=0.05\)(默认) | 29.6% | 平衡点 |
| 阈值 \(\alpha=0.01\) | 33.2% | 更多表征被选中,效果最优 |
| 位置排序选择 | 29.6% | 默认策略 |
| 分数排序选择 | 28.7% | 略低 |
| 随机选择 | 23.1% | 显著下降,证明非随机干预 |
| 仅干预 Layer 0 | 24.9% | 仅前端层有一定作用 |
| 仅干预 Layer 31 | 22.7% | 最后一层效果有限 |
| 干预全部层 | 29.6% | 完整干预效果最好 |
关键发现¶
- 关键表征识别至关重要:对关键表征添加 0.02 噪声后,准确率从正确样本的 100% 降至 21.1%,而非关键表征仅降至 74.1%,验证了关键表征的影响力
- 不同策略各有优势:SAF 和 MAF 从不同角度捕获关键表征,Union 策略不需要手工选择就能稳定获得较优结果
- 极限参数效率:CRFT 仅用 LoRA 1/6 的参数和 ReFT 1/2 的参数,就能在多数基准上达到可比甚至更优的性能
- Few-shot 扩展:one-shot 场景下准确率提升 16.4%,且 demonstration 和 question 使用独立更新向量效果更好
亮点与洞察¶
- 信息流驱动的关键性判断:不同于 ReFT 依赖经验规则选定位置,CRFT 用信息流分析为"哪些表征值得优化"提供了理论依据和自动化手段。这个思路可以迁移到任何需要定位关键中间状态的场景——例如在视觉 Transformer 中识别关键 patch 表征后做高效微调
- "注意力汇聚"与"信息广播"的双重视角:同时考虑自引用(信息聚合者)和多引用(信息传播者)两类角色,形成互补的关键表征集合。这个框架思想比单纯看 attention weight 更全面
- 噪声扰动验证策略:通过给表征添加微小噪声观察输出变化来验证选到的确实是关键表征,这是一种简洁有效的可解释性验证方法
局限与展望¶
- 只关注正面影响的表征:目前方法侧重于找到对输出影响大的表征,但未区分正面/负面影响。优先修正具有负面影响的表征可能更高效
- 优化空间限于线性:修正量被限制在低秩线性子空间,可能遗漏非线性方向的潜在优化空间
- 单 GPU 训练限制:few-shot 实验因显存限制只做到 two-shot,更长的 demonstration 场景有待探索
- 策略选择仍需实验:虽然 Union 策略较稳定,但最优策略因模型和任务而异,缺乏统一的自动选择机制
相关工作与启发¶
- vs ReFT:ReFT 修改每层首尾固定位置的表征,位置选择依赖在其他数据集上的试错,缺乏可解释性。CRFT 通过信息流动态识别关键位置,且仅需一半参数即可超越 ReFT
- vs LoRA:LoRA 修改权重矩阵,参数量为 CRFT 的 6 倍以上。CRFT 在表征层面直接干预,理论上更精准但适用场景可能更窄
- vs PASTA:PASTA 手动定义需要增强注意力的 token,CRFT 自动化了这一过程并扩展到整个表征空间的优化
评分¶
- 新颖性: ⭐⭐⭐⭐ 将信息流分析引入表征微调的关键位置选择,思路新颖但在 ReFT 基础上的增量性质明显
- 实验充分度: ⭐⭐⭐⭐⭐ 8 个数据集、4 个模型、多种消融、噪声验证、注意力可视化,实验非常全面
- 写作质量: ⭐⭐⭐⭐ 结构清晰,消融实验层层递进,但公式符号较多导致阅读门槛偏高
- 价值: ⭐⭐⭐⭐ 为参数高效微调在推理任务上提供了新视角,0.016% 参数量的效率值得关注
相关论文¶
- [ACL 2025] TRACT: Regression-Aware Fine-tuning Meets Chain-of-Thought Reasoning
- [ACL 2025] Fine-Tuning on Diverse Reasoning Chains Drives Within-Inference CoT Refinement in LLMs
- [ACL 2025] CoT-Valve: Length-Compressible Chain-of-Thought Tuning
- [ACL 2025] Unlocking General Long Chain-of-Thought Reasoning Capabilities of Large Language Models via Representation Engineering
- [ACL 2025] MM-Verify: Enhancing Multimodal Reasoning with Chain-of-Thought Verification