Multi-Prompting Decoder Helps Better Language Understanding¶
会议: ACL2025 arXiv: 2406.06279 代码: 待确认 领域: llm_nlp 关键词: Model-as-a-Service, prompt tuning, optimal transport, few-shot learning, output-side adaptation
一句话总结¶
提出 Multi-Prompting Decoder(MPD)框架,通过多提示查询 PLM 获取多组隐状态和类别分数,结合最优传输匹配和校准解码策略,在 MaaS(模型即服务)场景下的 few-shot 分类任务上显著超越现有方法。
研究背景与动机¶
-
MaaS 部署模式兴起:大规模预训练模型日益以 API 服务形式提供(如 GPT-3.5 Turbo、text-embedding 系列),用户只能获取输出(隐状态、类别分数、文本),无法访问模型参数和梯度。
-
输入端适配效率低:在 MaaS 设置下,无梯度优化连续/离散提示(如 BBT、RLPrompt)需要查询 PLM 数千至上万次,搜索空间巨大且优化困难,时间开销极大。
-
输出端适配的单提示瓶颈:现有输出端方法(如 DecT)只用单一提示查询 PLM,性能高度依赖提示质量。实验显示不同提示在 SST2 等数据集上的准确率波动可超过 8%。
-
Few-shot 数据稀缺:少样本设置下每类仅 1-16 个训练样本,单提示获得的表示信息极为有限,进一步放大了提示选择的风险。
-
PLM 预测偏差:模型倾向于预测预训练分布中常见的 token,导致类别分数存在系统性偏差,直接使用这些分数分类效果不佳。
-
多提示的潜力未被挖掘:多提示可以同时缓解单提示依赖、缓解数据稀缺(一个样本获得多组表示)、从不同角度提取 PLM 知识,但缺乏有效的解码机制来利用这些多源信息。
方法详解¶
整体框架¶
MPD 包含两个解码策略:(1) 基于最优传输的多提示隐状态解码;(2) 校准的多提示类别分数解码。最终通过联合解码输出预测结果。
关键设计¶
多提示查询:对每个样本使用 P 个不同模板包装后查询 PLM,获取 P 组隐状态(取最后一层 [MASK] 位置的表示),经线性层投影后得到文本表示矩阵 V_i ∈ R^{P×d}。
基于最优传输的分类:为每个类别 k 维护 Q 个可学习原型 R_{k,n},用最优传输(Sinkhorn 算法)求解文本表示和类别原型之间的最优匹配计划 T^{i,k},OT 分数为匹配计划与余弦相似度的加权和。这种设计让每个提示的表示能与最匹配的原型对齐,避免了粗暴平均或独立分类器的不足。
校准的类别分数解码:(1) 扩展标签词集合(基于 MLM 预测层中词向量的余弦相似度扩展 10 个同义词);(2) 用空输入的类别分数校准偏差;(3) 多提示的校准分数取平均。
联合解码:最终预测为 OT 分数和校准类别分数的加权和,β 为平衡超参数。
损失函数¶
标准交叉熵损失优化 OT 分数。可学习参数仅有线性层和类别原型,模型极为轻量(约 132K 参数)。
实验关键数据¶
主实验(Table 1 - 9 个 NLU 数据集,RoBERTa-Large)¶
| 设置 | 方法 | SST2 | AG | DBPedia | Yahoo | RTE | SNLI | Avg |
|---|---|---|---|---|---|---|---|---|
| 1-shot | DecT | 90.8 | 79.9 | 78.8 | 55.2 | 56.0 | 47.7 | 70.0 |
| 1-shot | MPD | 92.3 | 83.2 | 84.4 | 53.6 | 57.6 | 46.6 | 71.5 |
| 4-shot | DecT | 87.6 | 81.9 | 89.1 | 59.9 | 56.7 | 53.2 | 71.8 |
| 4-shot | MPD | 92.6 | 85.9 | 92.8 | 62.2 | 59.2 | 57.1 | 75.2 |
| 16-shot | DecT | 91.0 | 86.4 | 94.6 | 64.2 | 59.7 | 60.5 | 75.5 |
| 16-shot | MPD | 91.9 | 87.9 | 96.7 | 68.3 | 61.7 | 62.4 | 77.8 |
MPD 在所有 shot 设置下的 10 个数据集上几乎全面最优,1-shot 平均提升 1.5%,16-shot 平均提升 2.3%。
效率对比(Table 2 - 16-shot)¶
| 方法 | 可训练参数 | 查询次数 | SST2 Acc | 训练时间(s) |
|---|---|---|---|---|
| BBT | 0.5K | 8,000 | 89.6 | 1619 |
| RLPrompt | 3100K | 12,000 | 87.0 | 82286 |
| DecT | 130K | 1 | 91.0 | 1.4 |
| MPD | 132K | 3 | 91.9 | 3.5 |
MPD 仅需 3 次查询(P=3 个提示)和 3.5 秒训练,比 BBT 快 462 倍,比 RLPrompt 快 23,510 倍。
消融实验关键结论¶
- 提示数量:P=3 是最优选择,P=1 退化为单提示;P 过大时边际收益递减
- 原型数量:Q=3 最优,过多原型在 few-shot 设置下过拟合
- OT vs 平均:OT 匹配优于简单平均多提示表示(约 1-2% 提升)
- 校准解码贡献:移除校准分数后性能下降 1-3%,验证了先验知识的互补性
- 标签词扩展:扩展 10 个标签词比仅用 1 个标签词提升 2-4%(在情感/主题任务上)
亮点与洞察¶
- 极简但有效:核心思想直觉简单——用多个提示查询以获得更稳定的表示——但通过 OT 匹配机制将其充分利用
- 效率极高:相比输入端方法需数千次查询,MPD 仅需 P 次查询和数秒训练,非常适合实际 MaaS 场景
- OT 匹配的妙用:不是简单地融合多提示信息,而是为每个提示找到最匹配的原型,保留了提示特异性
- 提示鲁棒性:论文 Figure 1 显示单提示波动超 8%,而 MPD 的标准差显著降低(如 SST2 16-shot 从 0.5 降至 0.1)
局限性¶
- 仅在 RoBERTa-Large 上实验,未验证在更大模型或 decoder-only 架构(如 GPT 系列)上的效果
- 模板需手动设计,虽然结果对模板不敏感但仍非完全自动化
- NLI 任务上标签词扩展效果有限(标签词本身语义较抽象)
- 未探索当 PLM 提供文本输出(而非隐状态/logits)时的适用性
- β 超参数对 MNLI 需单独调优,自动化程度有提升空间
相关工作与启发¶
- 与 DecT 的关系:MPD 是 DecT 的自然扩展——从单提示到多提示,从超球原型到 OT 匹配
- 与 PromptBoosting 的差异:PromptBoosting 通过 Boosting 集成多个弱学习器,是"模型集成";MPD 是"表示融合",更高效
- OT 在 NLP 中的应用:将 OT 引入 MaaS 适配是新颖的,可启发其他需要多视角表示匹配的任务
- 启发:多提示思想可以推广到 LLM 的黑盒推理场景,如对 ChatGPT API 使用多提示查询后集成,可能提升稳定性
评分¶
- 新颖性: ⭐⭐⭐⭐ (多提示+OT解码的组合有新意)
- 实验充分度: ⭐⭐⭐⭐ (9个数据集、3个shot设置、充分消融)
- 写作质量: ⭐⭐⭐⭐ (方法阐述清晰,图表直观)
- 价值: ⭐⭐⭐⭐ (对MaaS实际应用有明确指导意义)