跳转至

ATPO: Adaptive Tree Policy Optimization for Multi-Turn Medical Dialogue

会议: ICLR 2026
arXiv: 2603.02216
代码: https://github.com/Quark-Medical/ATPO
领域: 医学对话 / 强化学习
关键词: 多轮医疗对话, 树搜索, 策略优化, 不确定性引导, 层级MDP, 值函数估计, LLM对齐

一句话总结

提出 ATPO(自适应树策略优化)算法,将多轮医疗对话建模为层级马尔可夫决策过程(H-MDP),通过不确定性感知的自适应树扩展机制动态分配rollout预算,结合Bellman误差和动作值方差的复合不确定性度量来引导探索,在三个医学对话基准上以Qwen3-8B超越GPT-4o。

研究背景与动机

  1. 领域现状:医学大语言模型在单轮问答(如医学考试、疾病诊断)中已达到SOTA水平,但现实医疗对话中用户初始信息通常不完整,需要模型主动追问以收集关键信息。
  2. 现有痛点
  3. Prompt工程方法(如MEDIQ)让模型主动提问反而降低准确率
  4. SFT方法仅模仿训练数据的表面模式,泛化能力差
  5. 轨迹级偏好优化依赖昂贵的偏好数据且对分布偏移敏感
  6. GRPO在长horizon任务中难以进行有效的信用分配
  7. PPO的值函数估计在多轮对话场景下不稳定
  8. 核心矛盾:多轮医疗对话本质上是长horizon序列决策问题,现有RL方法要么在信用分配上失效(GRPO将整条轨迹共享同一优势值),要么值估计不准确(PPO的单步critic在长对话中误差累积)。
  9. 本文要解决什么:如何在多轮医疗对话中实现高效且准确的策略优化——既能精确估计每轮对话的价值,又能高效探索对话空间。
  10. 切入角度:将问题建模为H-MDP,在对话轮次级别进行树搜索,用不确定性度量自适应地分配计算预算。
  11. 核心idea一句话:通过Bellman误差和Q值方差的复合度量识别高不确定性对话状态,选择性扩展树节点,同时提升采样多样性和critic准确性。

方法详解

整体框架

ATPO将多轮对话过程视为搜索树的扩展:初始用户查询为根节点,每个节点代表一个对话状态。在每个非终端节点处,助手模型生成N个候选宏动作(追问或最终回答),计算复合不确定性分数决定是否完全扩展(保留所有N个分支)或剪枝(仅保留1个分支)。收集的轨迹用于策略和critic模型的更新。

关键设计

  1. 层级MDP建模:
  2. 做什么:将多轮对话分为高层MDP和低层MDP两个层次
  3. 核心思路:高层MDP中,宏动作 \(y_k\) 定义为助手在第 \(k\) 轮的完整token序列;低层MDP中,微动作 \(y_{k,t}\) 对应单个token。状态 \(x_k\) 包含第 \(k\) 轮之前的交互历史和用户查询 \(q_k\)
  4. 为什么有效:轮次级别的信用分配比token级别更适合多轮对话——一轮对话内的所有token共享同一个宏动作优势,避免了token级别的稀疏奖励问题

  5. 复合不确定性度量:

  6. 做什么:为每个前沿节点计算不确定性分数,决定是否扩展
  7. 核心思路:对状态 \(x_k\) 采样N个候选宏动作 \(\{y_k^i\}_{i=1}^N\),通过一步前瞻计算动作值 \(\hat{Q}(x_k, y_k^i) = r(x_k, y_k^i) + \gamma V_\psi(x_{k+1}^i)\),然后计算:
    • Bellman误差 \(U_1\):critic当前值估计与经验一步前瞻值之差的绝对值,反映值函数估计的不准确性
    • Q值方差 \(U_2\):候选动作值估计的方差,反映策略的不确定性和环境的随机性
    • 复合分数 \(U = \alpha U_1 + (1-\alpha) U_2\)\(\alpha=0.3\) 平衡两个信号
  8. 为什么有效:\(U_1\) 识别critic不准确的状态(需要更多样本改善值估计),\(U_2\) 识别策略犹豫不决的状态(需要更多探索),两者互补——单独用 \(U_1\) 会导致激进的早期探索集中在浅层,叠加 \(U_2\) 实现更深更均匀的覆盖

  9. 阈值驱动的剪枝策略:

  10. 做什么:根据不确定性阈值 \(\tau\) 决定节点扩展方式
  11. 核心思路:\(U(x_k) > \tau\) 时保留全部N个分支;\(U(x_k) \leq \tau\) 时随机保留1个分支(但以10%概率绕过剪枝以维持基线多样性)。扩展持续到所有对话终止或叶节点数达到预算上限
  12. 为什么有效:自适应分配避免了TreePO固定二叉扩展导致的节点指数增长集中在早期轮次的问题

  13. 值回溯与树分解:

  14. 做什么:从叶节点递归计算所有节点的目标值和优势
  15. 核心思路:叶节点目标值 \(\hat{V}\) 等于即时奖励;非叶节点取所有子节点一步TD目标的平均值。优势采用标准一步TD公式 \(\hat{A} = r + \gamma V_\psi(x_{k+1}) - V_\psi(x_k)\),使用critic估计值而非目标值(因为剪枝节点只有一个分支时目标值会导致零优势)
  16. 为什么有效:树结构的值回溯提供比纯Monte Carlo(GRPO)更低方差的值估计,同时比单一critic(PPO)更准确

  17. PPO风格策略更新与访问计数归一化:

  18. 做什么:将树分解为独立轨迹进行策略优化
  19. 核心思路:每条根到叶路径构成一条轨迹,M个叶节点产生M条轨迹。策略更新中引入访问计数 \(C(x_k)\) 归一化,防止频繁访问的共享节点被过度优化。宏动作优势均匀分配到该轮所有token
  20. 为什么有效:消融实验证实不做访问计数归一化会导致熵不受控增长和策略崩溃

  21. 异步执行与KV缓存优化:

  22. 做什么:降低树搜索的计算开销
  23. 核心思路:助手模型生成、用户模型交互、critic值估计三个阶段完全异步执行;共享前缀复用KV缓存,树结构天然适合前缀共享——同一父节点的所有子节点共享相同的对话历史前缀
  24. 为什么有效:在1.7B模型上TreePO解码速度达到2,500 tokens/sec/GPU;ATPO虽然rollout阶段占比更高(45% vs 25%),但产出更高质量的训练数据,总训练时间反而最短

实验设置

环境

  • 用户模拟器:Qwen3-8B实现,严格根据原子事实回答问题,GPT-4o验证指令遵循准确率100%,幻觉率仅1.2%
  • 助手代理:需从选项中选择正确答案,可迭代查询用户模拟器获取更多信息
  • 奖励函数:仅基于最终答案正确性——正确+3,错误0,格式无效-1

数据集

  • MedicalExam:150样本,来自5个来源(MedQA/MedMCQA/MMLU/SelfExam/QMAX)
  • MedQA:1,268样本,来自MEDIQ测试集
  • MedMCQA:536样本,从MedMCQA验证集构建
  • 训练数据14,256样本(66% MEDIQ + 34% MedMCQA)

基线

  • Zero-shot:Direct单轮 / MEDIQ多轮提示
  • SFT:标准SFT / 动态微调DFT(Gemini-2.5-Pro自我对弈生成1,269条对话)
  • SFT+RL:PPO (MDP) / PPO (H-MDP) / GRPO / TreePO

关键超参数

  • 策略学习率 \(1 \times 10^{-6}\),critic学习率 \(1 \times 10^{-5}\)
  • KL惩罚 \(\beta=0.01\),折扣因子 \(\gamma=1\)
  • GRPO组大小32;ATPO扩展大小 \(N=4\),总扩展预算128
  • ATPO (\(U_1\)): \(\tau=0.5\);ATPO (\(U_1+U_2\)): \(\alpha=0.3\), \(\tau=1.5\)

实验结果

主要结果(Table 1)

模型 方法 MedicalExam MedQA MedMCQA
Qwen3-8B GRPO 60.93 57.92 51.12
Qwen3-8B TreePO 65.33 61.81 54.74
Qwen3-8B ATPO (\(U_1+U_2\)) 65.87 64.07 53.66
GPT-4o MEDIQ 64.00 63.15 53.03
  • ATPO (\(U_1+U_2\)) 在8B规模上MedQA超越GPT-4o +0.92%
  • 相比TreePO,ATPO在MedQA上绝对提升:1.7B +0.82%,4B +1.73%,8B +2.26%
  • MEDIQ提示策略反而比Direct单轮更差,与原论文发现一致
  • SFT(含从GPT-4o/Gemini蒸馏)仅提供有限准确率增益,RL训练不可或缺

采样效率

  • Qwen3-4B在MedQA上,ATPO (\(U_1+U_2\)) 仅用TreePO约55%的训练轮次即达~52.7%准确率
  • ATPO达到PPO最佳性能的时间最短(2.22小时 vs PPO 3.02小时 vs GRPO 4.86小时)

消融分析

  • 不确定性度量\(U_1+U_2\) 产生高方差样本回报(与GRPO相当),critic值损失显著低于PPO;单独 \(U_1\) 导致探索集中在浅层(3-4层),叠加 \(U_2\) 实现更深更均匀的覆盖
  • 访问计数归一化:不做归一化→熵不受控增长和策略崩溃;对值损失也做归一化→熵快速坍缩,模型退化为次优单轮策略
  • 用户模拟器泛化:将测试时模拟器从Qwen3-8B替换为Llama-3.3-70B-Instruct,性能几乎无变化,证明未过拟合特定模拟器

优点与局限

优点

  1. 不确定性引导的自适应树搜索兼顾了采样多样性(\(U_2\))和critic优化(\(U_1\)),比固定结构的TreePO更灵活
  2. 层级MDP建模+轮次级信用分配适合多轮对话的宏观决策特性
  3. KV缓存复用和异步执行使树搜索的额外计算开销可控,总训练时间反而最短
  4. 8B模型超越GPT-4o验证了方法的有效性

局限

  1. 扩展阈值τ和α为手动设定的固定超参数,不同任务/模型可能需要重新调优
  2. 宏动作优势在轮内所有token间均匀分配,未区分关键token和冗余token
  3. 用户模拟器基于预定义原子事实,与真实患者的自由表达存在差距
  4. 仅在MCQ格式的医学数据集上验证,未涉及开放式诊断场景

个人思考

  1. 不确定性度量的通用性:Bellman误差+Q值方差的复合度量可以推广到其他需要长horizon决策的多轮交互场景(如工具使用、多轮code generation),核心是在"哪里值得花更多计算资源探索"这个问题上提供了一个量化标准
  2. 树搜索与MCTS的关系:ATPO的树搜索与AlphaGo的MCTS有相似之处(都是选择性扩展),但ATPO的不确定性度量基于值函数而非UCB,更适合连续动作空间;未来可以考虑引入UCT准则或PUCT来进一步优化节点选择
  3. 对SFT局限性的验证:即使用GPT-4o/Gemini蒸馏也无法显著提升性能,再次证实了"模仿⊊学习"——SFT学到格式但学不到决策策略,这对医学AI的训练范式有重要启示