跳转至

Leveraging Self-Attention for Input-Dependent Soft Prompting in LLMs

会议: ACL 2025
arXiv: 2506.05629
代码: 无
领域: LLM/NLP
关键词: 软提示, 参数高效微调, 自注意力, 输入依赖, PEFT, Prompt Tuning

一句话总结

提出 ID-SPAM,通过在输入 token 嵌入上施加可学习自注意力层并经瓶颈 MLP 生成输入依赖的软提示,仅在单层 Transformer 输入端拼接即可超越多种 Soft Prompt 基线,且具备优秀的零样本跨任务/跨领域迁移能力。

研究背景与动机

  1. 大语言模型(LLM)在领域特定任务中需要微调,但全参数微调计算成本极高(BERT 到 GPT-3 参数量从亿到千亿级别),参数高效微调(PEFT)因此成为研究热点。
  2. 软提示(Soft Prompting)是一类有前景的 PEFT 方法——在冻结 LM 参数的前提下,仅学习一组小规模的连续向量(软提示)来适配下游任务,避免修改模型核心架构。
  3. 现有软提示方法(Prompt Tuning、Prefix Tuning、P-Tuning 等)的提示向量与输入无关,即所有样本共享同一组提示参数。这限制了模型在推理时根据不同输入动态调整的能力,也增加了训练收敛难度。
  4. 已有输入依赖的软提示方法存在多方面不足:(a) 需要在 LM 的多个 Transformer 层拼接软提示(架构复杂);(b) 未显式地对输入中不同 token 赋予不同重要性权重;(c) 可训练参数量显著增加。
  5. 一个自然的思路是:既然任务样本包含多样化的词汇,那么在生成软提示时就应该差异化地关注不同输入 token——这正是自注意力机制的核心能力。
  6. 本文提出 ID-SPAM(Input Dependent Soft Prompting with self-Attention Mechanism),用一个可学习的自注意力层聚合输入信息,再经瓶颈 MLP 映射为软提示,仅在单个 Transformer 层拼接,参数量小且训练平滑。

方法详解

整体框架

给定一个下游任务 \(T\),训练数据 \(D_{train} = \{(x_i, y_i)\}_{i=1}^{K}\)。对于单句任务,输入表示为 \(x_i = \mathbf{E}(\texttt{[SEP]}S_1\texttt{[EOS]})\);对于句对任务,\(x_i = \mathbf{E}(\texttt{[SEP]}S_1\texttt{[SEP]}S_2\texttt{[EOS]})\),其中 \(\mathbf{E}(\cdot)\) 为 token 嵌入层。ID-SPAM 通过三个阶段生成输入依赖的软提示 \(\mathbf{S}_T \in \mathbb{R}^{n \times t}\)\(n\) 为隐藏维度,\(t\) 为提示 token 数),然后将其拼接到 LM 某一 Transformer 层的输入端,冻结 LM 全部参数仅训练软提示生成网络。

关键设计

模块一:自注意力聚合层

  • 做什么:对输入嵌入 \(\mathbf{E}\) 施加单头自注意力,再沿 token 维度取均值,得到上下文丰富的 \(n \times 1\) 维向量 \(A\)
  • 核心思路: $\(A = \text{mean}\left\{\text{softmax}\left(\frac{(\mathbf{E}W_Q)(\mathbf{E}W_K)^\top}{\sqrt{d_k}}\right)(\mathbf{E}W_V)\right\}\)$ 其中 \(W_Q, W_K, W_V\) 为可学习的查询/键/值投影矩阵,\(\frac{1}{\sqrt{d_k}}\) 为缩放因子。
  • 设计动机:不同 token 对任务的贡献不同(如情感分类中"excellent"比"the"更关键),自注意力可以自动学习 token 级别的重要性加权,使得生成的软提示能捕获输入中的关键语义信号

模块二:瓶颈 MLP(Down-Up Projection)

  • 做什么:将聚合向量 \(A\) 先下投影到低维空间 \(c\)\(c < n\)),经 ReLU 激活后再上投影到 \(n \cdot t\) 维,最后 reshape 为 \(\mathbf{S}_T \in \mathbb{R}^{n \times t}\)
  • 核心思路: $\(\mathbf{S}_T = \text{resize}\left(\sigma(W_{up} \cdot \sigma(W_{down} \cdot A))\right)\)$ 其中 \(W_{down} \in \mathbb{R}^{n \times c}\)\(W_{up} \in \mathbb{R}^{c \times (n \cdot t)}\)\(\sigma\) 为 ReLU。
  • 设计动机:瓶颈结构(类似 LoRA 的低秩思想)大幅压缩参数量,同时引入非线性变换提升表达能力。通过低维中间表示作为信息压缩,避免过拟合。

模块三:单层拼接策略

  • 做什么:将生成的软提示 \(\mathbf{S}_T\) 拼接到 LM 的单个 Transformer 层(第 \(m\) 层)的输入端,而非在多层或所有层拼接。
  • 核心思路:实验发现拼接在中间层(如第 6-8 层)效果最优;早期层效果也不错,因为软提示由输入嵌入生成,与早期层输出的兼容性更好。
  • 设计动机:(1) 减少架构复杂度,避免在每层都引入额外拼接操作;(2) 降低可训练参数量——与 Prefix Tuning(每层拼接)相比参数大幅减少;(3) 使训练过程更平滑,降低收敛难度。

损失函数/训练策略

  • 使用标准交叉熵损失进行训练,Adam 优化器。
  • 冻结基座 LM 全部参数,仅训练自注意力层参数(\(W_Q, W_K, W_V\))和瓶颈 MLP 参数(\(W_{down}, W_{up}\) 及偏置)。
  • 软提示 token 数 \(t = 10\),训练最多 30 个 epoch。
  • 实验使用 NVIDIA A100 80GB GPU。

实验关键数据

主实验:GLUE Benchmark(RoBERTa-LARGE 骨干)

方法 MNLI QNLI SST-2 MRPC RTE QQP 均值
Fine-tuning 87.6 94.7 95.4 92.1 88.4 90.7 91.5
LoRA 89.1 87.9 95.1 86.5 78.7 88.4 87.6
Prompt Tuning 83.4 88.2 92.6 73.9 60.8 81.2 80.0
P-Tuning 86.4 88.7 95.8 76.3 62.6 85.2 82.5
SMoP 86.7 88.4 95.8 79.6 76.3 86.7 85.6
LPT 84.2 86.1 93.4 87.3 74.2 85.3 85.1
DePT 83.3 88.8 91.2 77.7 73.2 82.2 82.7
ID-SPAM 87.4 91.1 94.6 86.1 81.1 88.4 88.1

消融实验:自注意力 vs 均值池化(RoBERTa-LARGE)

方法 MRPC RTE QQP
Mean-pooling 82.3 75.2 84.2
ID-SPAM 86.1 81.1 88.4

零样本跨任务/跨领域迁移(RoBERTa-LARGE)

方法 QQP→MRPC MRPC→QQP SST-2→IMDB IMDB→SST-2
Fine-tuning 64.0 68.3 87.1 88.8
LoRA 71.1 66.1 90.3 87.6
LPT 66.7 64.5 67.1 71.1
ID-SPAM 70.9 69.2 89.1 86.0

关键发现

  • ID-SPAM 在 GLUE 6 个任务中的 4 个超越所有 Soft Prompt 基线(RoBERTa-BASE 和 LARGE 骨干均如此),均值分数大幅领先。
  • SuperGLUE 4 个任务中,使用 RoBERTa-LARGE 骨干时 ID-SPAM 在 3/4 任务最优,均值 72.0(LPT 70.2、SMoP 70.4)。
  • 消融实验表明,自注意力层带来平均 5.82% 的性能提升(相比直接均值池化),验证了对不同 token 差异化加权的重要性。
  • 零样本迁移中 ID-SPAM 在 4 个迁移对中全部优于其他 Soft Prompt 方法,甚至在 3/4 对中超越全参数 Fine-tuning,显示出优秀的泛化能力。
  • 层选择分析:软提示拼接在中间层效果最好;ID-SPAM 在几乎所有层位置均显著优于 LPT,且对早期层更友好。
  • ID-SPAM 的可训练参数量和训练/推理时间均优于或持平 LPT 和 LoRA(详见论文附录 D)。

亮点与洞察

  1. 简洁有效的设计哲学:仅用一个自注意力层 + 瓶颈 MLP + 单层拼接,就实现了输入依赖的软提示生成——既避免了多层拼接的复杂性,又保持了极低的参数量。
  2. 自注意力赋予 token 级别选择性:不同于以往将所有 token 等权处理的方法,ID-SPAM 能自动识别对任务关键的 token 并给予更高权重,这是其在情感分类、自然语言推理等多样任务上表现稳定的根本原因。
  3. 零样本迁移能力强劲:输入依赖的提示生成天然具备泛化性——提示随输入变化,因此在分布偏移场景下能更灵活地适应,而固定提示方法则容易过拟合训练域分布。
  4. 与 LoRA 形成互补视角:LoRA 通过低秩适配权重矩阵,ID-SPAM 通过输入依赖的软提示——两者都追求参数效率但路径不同,ID-SPAM 在多数任务上可与 LoRA 匹敌甚至更优。

局限性/可改进方向

  1. 骨干模型规模有限:实验仅在 RoBERTa-BASE/LARGE(125M/355M)和 GPT-2 上验证,未能在 LLaMA-3.1-70B、Mixtral 8×22B 等大规模模型上测试——无法确定该方法在真正的大模型上是否仍有优势。
  2. 层选择为手动超参:拼接到哪一层 Transformer 需要人工搜索,缺乏自动选择最优层的机制。未来可考虑引入可微的层路由(如 Gumbel-Softmax 选层)或同时在多层加权融合。
  3. 仅限 NLU 任务:评估集中在分类/推理任务(GLUE/SuperGLUE),未涉及生成任务(摘要、翻译、对话等),方法在生成场景下的表现未知。
  4. 单头注意力:当前仅使用单头自注意力,多头注意力可能捕获更丰富的 token 交互模式,值得探索。
  5. 与其他 PEFT 方法的组合:ID-SPAM 与 LoRA、Adapter 等方法是正交的,组合使用可能带来额外增益,但论文未探索。

相关工作与启发

  • Prompt Tuning (Lester et al., 2021) 和 Prefix Tuning (Li & Liang, 2021) 是软提示方法的基石工作,分别在嵌入层和每层拼接固定软提示。
  • LPT (Liu et al., 2022a) 提出"迟到的提示"——仅在中间层注入,是 ID-SPAM 的直接对比对象。
  • SMoP (Choi et al., 2023) 使用多个短提示 + 门控路由,思路是"为不同数据子集匹配不同提示",与 ID-SPAM 的"每个样本生成专属提示"形成对比。
  • DePT (Shi & Lipani, 2024) 通过低秩分解压缩软提示参数,与 ID-SPAM 的瓶颈 MLP 有异曲同工之处。
  • 启发:输入依赖 + 注意力加权是一个通用的设计模式,可推广到 Adapter、LoRA 等其他 PEFT 方法中——例如根据输入动态生成 LoRA 的低秩矩阵。

评分

  • 新颖性: ⭐⭐⭐ — 自注意力 + 瓶颈 MLP 生成软提示的思路直观清晰,但整体架构创新幅度不大,核心组件均为已有模块的组合。
  • 技术质量: ⭐⭐⭐⭐ — 实验覆盖 GLUE、SuperGLUE、零样本迁移,基线全面,消融实验清晰验证了自注意力的作用。
  • 实用价值: ⭐⭐⭐ — 方法简单易实现,参数效率高,但仅在中小模型上验证,对当下主流大模型场景的适用性存疑。
  • 表达清晰度: ⭐⭐⭐⭐ — 论文结构清晰,公式推导完整,图示直观,实验表格规范。