跳转至

Multi-Prompting Decoder Helps Better Language Understanding

会议: ACL2025 arXiv: 2406.06279 代码: 待确认 领域: llm_nlp 关键词: Model-as-a-Service, prompt tuning, optimal transport, few-shot learning, output-side adaptation

一句话总结

提出 Multi-Prompting Decoder(MPD)框架,通过多提示查询 PLM 获取多组隐状态和类别分数,结合最优传输匹配和校准解码策略,在 MaaS(模型即服务)场景下的 few-shot 分类任务上显著超越现有方法。

研究背景与动机

  1. MaaS 部署模式兴起:大规模预训练模型日益以 API 服务形式提供(如 GPT-3.5 Turbo、text-embedding 系列),用户只能获取输出(隐状态、类别分数、文本),无法访问模型参数和梯度。

  2. 输入端适配效率低:在 MaaS 设置下,无梯度优化连续/离散提示(如 BBT、RLPrompt)需要查询 PLM 数千至上万次,搜索空间巨大且优化困难,时间开销极大。

  3. 输出端适配的单提示瓶颈:现有输出端方法(如 DecT)只用单一提示查询 PLM,性能高度依赖提示质量。实验显示不同提示在 SST2 等数据集上的准确率波动可超过 8%。

  4. Few-shot 数据稀缺:少样本设置下每类仅 1-16 个训练样本,单提示获得的表示信息极为有限,进一步放大了提示选择的风险。

  5. PLM 预测偏差:模型倾向于预测预训练分布中常见的 token,导致类别分数存在系统性偏差,直接使用这些分数分类效果不佳。

  6. 多提示的潜力未被挖掘:多提示可以同时缓解单提示依赖、缓解数据稀缺(一个样本获得多组表示)、从不同角度提取 PLM 知识,但缺乏有效的解码机制来利用这些多源信息。

方法详解

整体框架

MPD 包含两个解码策略:(1) 基于最优传输的多提示隐状态解码;(2) 校准的多提示类别分数解码。最终通过联合解码输出预测结果。

关键设计

多提示查询:对每个样本使用 P 个不同模板包装后查询 PLM,获取 P 组隐状态(取最后一层 [MASK] 位置的表示),经线性层投影后得到文本表示矩阵 V_i ∈ R^{P×d}。

基于最优传输的分类:为每个类别 k 维护 Q 个可学习原型 R_{k,n},用最优传输(Sinkhorn 算法)求解文本表示和类别原型之间的最优匹配计划 T^{i,k},OT 分数为匹配计划与余弦相似度的加权和。这种设计让每个提示的表示能与最匹配的原型对齐,避免了粗暴平均或独立分类器的不足。

校准的类别分数解码:(1) 扩展标签词集合(基于 MLM 预测层中词向量的余弦相似度扩展 10 个同义词);(2) 用空输入的类别分数校准偏差;(3) 多提示的校准分数取平均。

联合解码:最终预测为 OT 分数和校准类别分数的加权和,β 为平衡超参数。

损失函数

标准交叉熵损失优化 OT 分数。可学习参数仅有线性层和类别原型,模型极为轻量(约 132K 参数)。

实验关键数据

主实验(Table 1 - 9 个 NLU 数据集,RoBERTa-Large)

设置 方法 SST2 AG DBPedia Yahoo RTE SNLI Avg
1-shot DecT 90.8 79.9 78.8 55.2 56.0 47.7 70.0
1-shot MPD 92.3 83.2 84.4 53.6 57.6 46.6 71.5
4-shot DecT 87.6 81.9 89.1 59.9 56.7 53.2 71.8
4-shot MPD 92.6 85.9 92.8 62.2 59.2 57.1 75.2
16-shot DecT 91.0 86.4 94.6 64.2 59.7 60.5 75.5
16-shot MPD 91.9 87.9 96.7 68.3 61.7 62.4 77.8

MPD 在所有 shot 设置下的 10 个数据集上几乎全面最优,1-shot 平均提升 1.5%,16-shot 平均提升 2.3%。

效率对比(Table 2 - 16-shot)

方法 可训练参数 查询次数 SST2 Acc 训练时间(s)
BBT 0.5K 8,000 89.6 1619
RLPrompt 3100K 12,000 87.0 82286
DecT 130K 1 91.0 1.4
MPD 132K 3 91.9 3.5

MPD 仅需 3 次查询(P=3 个提示)和 3.5 秒训练,比 BBT 快 462 倍,比 RLPrompt 快 23,510 倍。

消融实验关键结论

  • 提示数量:P=3 是最优选择,P=1 退化为单提示;P 过大时边际收益递减
  • 原型数量:Q=3 最优,过多原型在 few-shot 设置下过拟合
  • OT vs 平均:OT 匹配优于简单平均多提示表示(约 1-2% 提升)
  • 校准解码贡献:移除校准分数后性能下降 1-3%,验证了先验知识的互补性
  • 标签词扩展:扩展 10 个标签词比仅用 1 个标签词提升 2-4%(在情感/主题任务上)

亮点与洞察

  1. 极简但有效:核心思想直觉简单——用多个提示查询以获得更稳定的表示——但通过 OT 匹配机制将其充分利用
  2. 效率极高:相比输入端方法需数千次查询,MPD 仅需 P 次查询和数秒训练,非常适合实际 MaaS 场景
  3. OT 匹配的妙用:不是简单地融合多提示信息,而是为每个提示找到最匹配的原型,保留了提示特异性
  4. 提示鲁棒性:论文 Figure 1 显示单提示波动超 8%,而 MPD 的标准差显著降低(如 SST2 16-shot 从 0.5 降至 0.1)

局限性

  1. 仅在 RoBERTa-Large 上实验,未验证在更大模型或 decoder-only 架构(如 GPT 系列)上的效果
  2. 模板需手动设计,虽然结果对模板不敏感但仍非完全自动化
  3. NLI 任务上标签词扩展效果有限(标签词本身语义较抽象)
  4. 未探索当 PLM 提供文本输出(而非隐状态/logits)时的适用性
  5. β 超参数对 MNLI 需单独调优,自动化程度有提升空间

相关工作与启发

  • 与 DecT 的关系:MPD 是 DecT 的自然扩展——从单提示到多提示,从超球原型到 OT 匹配
  • 与 PromptBoosting 的差异:PromptBoosting 通过 Boosting 集成多个弱学习器,是"模型集成";MPD 是"表示融合",更高效
  • OT 在 NLP 中的应用:将 OT 引入 MaaS 适配是新颖的,可启发其他需要多视角表示匹配的任务
  • 启发:多提示思想可以推广到 LLM 的黑盒推理场景,如对 ChatGPT API 使用多提示查询后集成,可能提升稳定性

评分

  • 新颖性: ⭐⭐⭐⭐ (多提示+OT解码的组合有新意)
  • 实验充分度: ⭐⭐⭐⭐ (9个数据集、3个shot设置、充分消融)
  • 写作质量: ⭐⭐⭐⭐ (方法阐述清晰,图表直观)
  • 价值: ⭐⭐⭐⭐ (对MaaS实际应用有明确指导意义)