ATPO: Adaptive Tree Policy Optimization for Multi-Turn Medical Dialogue¶
会议: ICLR 2026
arXiv: 2603.02216
代码: https://github.com/Quark-Medical/ATPO
领域: 医学对话 / 强化学习
关键词: 多轮医疗对话, 树搜索, 策略优化, 不确定性引导, 层级MDP, 值函数估计, LLM对齐
一句话总结¶
提出 ATPO(自适应树策略优化)算法,将多轮医疗对话建模为层级马尔可夫决策过程(H-MDP),通过不确定性感知的自适应树扩展机制动态分配rollout预算,结合Bellman误差和动作值方差的复合不确定性度量来引导探索,在三个医学对话基准上以Qwen3-8B超越GPT-4o。
研究背景与动机¶
- 领域现状:医学大语言模型在单轮问答(如医学考试、疾病诊断)中已达到SOTA水平,但现实医疗对话中用户初始信息通常不完整,需要模型主动追问以收集关键信息。
- 现有痛点:
- Prompt工程方法(如MEDIQ)让模型主动提问反而降低准确率
- SFT方法仅模仿训练数据的表面模式,泛化能力差
- 轨迹级偏好优化依赖昂贵的偏好数据且对分布偏移敏感
- GRPO在长horizon任务中难以进行有效的信用分配
- PPO的值函数估计在多轮对话场景下不稳定
- 核心矛盾:多轮医疗对话本质上是长horizon序列决策问题,现有RL方法要么在信用分配上失效(GRPO将整条轨迹共享同一优势值),要么值估计不准确(PPO的单步critic在长对话中误差累积)。
- 本文要解决什么:如何在多轮医疗对话中实现高效且准确的策略优化——既能精确估计每轮对话的价值,又能高效探索对话空间。
- 切入角度:将问题建模为H-MDP,在对话轮次级别进行树搜索,用不确定性度量自适应地分配计算预算。
- 核心idea一句话:通过Bellman误差和Q值方差的复合度量识别高不确定性对话状态,选择性扩展树节点,同时提升采样多样性和critic准确性。
方法详解¶
整体框架¶
ATPO将多轮对话过程视为搜索树的扩展:初始用户查询为根节点,每个节点代表一个对话状态。在每个非终端节点处,助手模型生成N个候选宏动作(追问或最终回答),计算复合不确定性分数决定是否完全扩展(保留所有N个分支)或剪枝(仅保留1个分支)。收集的轨迹用于策略和critic模型的更新。
关键设计¶
- 层级MDP建模:
- 做什么:将多轮对话分为高层MDP和低层MDP两个层次
- 核心思路:高层MDP中,宏动作 \(y_k\) 定义为助手在第 \(k\) 轮的完整token序列;低层MDP中,微动作 \(y_{k,t}\) 对应单个token。状态 \(x_k\) 包含第 \(k\) 轮之前的交互历史和用户查询 \(q_k\)
-
为什么有效:轮次级别的信用分配比token级别更适合多轮对话——一轮对话内的所有token共享同一个宏动作优势,避免了token级别的稀疏奖励问题
-
复合不确定性度量:
- 做什么:为每个前沿节点计算不确定性分数,决定是否扩展
- 核心思路:对状态 \(x_k\) 采样N个候选宏动作 \(\{y_k^i\}_{i=1}^N\),通过一步前瞻计算动作值 \(\hat{Q}(x_k, y_k^i) = r(x_k, y_k^i) + \gamma V_\psi(x_{k+1}^i)\),然后计算:
- Bellman误差 \(U_1\):critic当前值估计与经验一步前瞻值之差的绝对值,反映值函数估计的不准确性
- Q值方差 \(U_2\):候选动作值估计的方差,反映策略的不确定性和环境的随机性
- 复合分数 \(U = \alpha U_1 + (1-\alpha) U_2\),\(\alpha=0.3\) 平衡两个信号
-
为什么有效:\(U_1\) 识别critic不准确的状态(需要更多样本改善值估计),\(U_2\) 识别策略犹豫不决的状态(需要更多探索),两者互补——单独用 \(U_1\) 会导致激进的早期探索集中在浅层,叠加 \(U_2\) 实现更深更均匀的覆盖
-
阈值驱动的剪枝策略:
- 做什么:根据不确定性阈值 \(\tau\) 决定节点扩展方式
- 核心思路:\(U(x_k) > \tau\) 时保留全部N个分支;\(U(x_k) \leq \tau\) 时随机保留1个分支(但以10%概率绕过剪枝以维持基线多样性)。扩展持续到所有对话终止或叶节点数达到预算上限
-
为什么有效:自适应分配避免了TreePO固定二叉扩展导致的节点指数增长集中在早期轮次的问题
-
值回溯与树分解:
- 做什么:从叶节点递归计算所有节点的目标值和优势
- 核心思路:叶节点目标值 \(\hat{V}\) 等于即时奖励;非叶节点取所有子节点一步TD目标的平均值。优势采用标准一步TD公式 \(\hat{A} = r + \gamma V_\psi(x_{k+1}) - V_\psi(x_k)\),使用critic估计值而非目标值(因为剪枝节点只有一个分支时目标值会导致零优势)
-
为什么有效:树结构的值回溯提供比纯Monte Carlo(GRPO)更低方差的值估计,同时比单一critic(PPO)更准确
-
PPO风格策略更新与访问计数归一化:
- 做什么:将树分解为独立轨迹进行策略优化
- 核心思路:每条根到叶路径构成一条轨迹,M个叶节点产生M条轨迹。策略更新中引入访问计数 \(C(x_k)\) 归一化,防止频繁访问的共享节点被过度优化。宏动作优势均匀分配到该轮所有token
-
为什么有效:消融实验证实不做访问计数归一化会导致熵不受控增长和策略崩溃
-
异步执行与KV缓存优化:
- 做什么:降低树搜索的计算开销
- 核心思路:助手模型生成、用户模型交互、critic值估计三个阶段完全异步执行;共享前缀复用KV缓存,树结构天然适合前缀共享——同一父节点的所有子节点共享相同的对话历史前缀
- 为什么有效:在1.7B模型上TreePO解码速度达到2,500 tokens/sec/GPU;ATPO虽然rollout阶段占比更高(45% vs 25%),但产出更高质量的训练数据,总训练时间反而最短
实验设置¶
环境¶
- 用户模拟器:Qwen3-8B实现,严格根据原子事实回答问题,GPT-4o验证指令遵循准确率100%,幻觉率仅1.2%
- 助手代理:需从选项中选择正确答案,可迭代查询用户模拟器获取更多信息
- 奖励函数:仅基于最终答案正确性——正确+3,错误0,格式无效-1
数据集¶
- MedicalExam:150样本,来自5个来源(MedQA/MedMCQA/MMLU/SelfExam/QMAX)
- MedQA:1,268样本,来自MEDIQ测试集
- MedMCQA:536样本,从MedMCQA验证集构建
- 训练数据14,256样本(66% MEDIQ + 34% MedMCQA)
基线¶
- Zero-shot:Direct单轮 / MEDIQ多轮提示
- SFT:标准SFT / 动态微调DFT(Gemini-2.5-Pro自我对弈生成1,269条对话)
- SFT+RL:PPO (MDP) / PPO (H-MDP) / GRPO / TreePO
关键超参数¶
- 策略学习率 \(1 \times 10^{-6}\),critic学习率 \(1 \times 10^{-5}\)
- KL惩罚 \(\beta=0.01\),折扣因子 \(\gamma=1\)
- GRPO组大小32;ATPO扩展大小 \(N=4\),总扩展预算128
- ATPO (\(U_1\)): \(\tau=0.5\);ATPO (\(U_1+U_2\)): \(\alpha=0.3\), \(\tau=1.5\)
实验结果¶
主要结果(Table 1)¶
| 模型 | 方法 | MedicalExam | MedQA | MedMCQA |
|---|---|---|---|---|
| Qwen3-8B | GRPO | 60.93 | 57.92 | 51.12 |
| Qwen3-8B | TreePO | 65.33 | 61.81 | 54.74 |
| Qwen3-8B | ATPO (\(U_1+U_2\)) | 65.87 | 64.07 | 53.66 |
| GPT-4o | MEDIQ | 64.00 | 63.15 | 53.03 |
- ATPO (\(U_1+U_2\)) 在8B规模上MedQA超越GPT-4o +0.92%
- 相比TreePO,ATPO在MedQA上绝对提升:1.7B +0.82%,4B +1.73%,8B +2.26%
- MEDIQ提示策略反而比Direct单轮更差,与原论文发现一致
- SFT(含从GPT-4o/Gemini蒸馏)仅提供有限准确率增益,RL训练不可或缺
采样效率¶
- Qwen3-4B在MedQA上,ATPO (\(U_1+U_2\)) 仅用TreePO约55%的训练轮次即达~52.7%准确率
- ATPO达到PPO最佳性能的时间最短(2.22小时 vs PPO 3.02小时 vs GRPO 4.86小时)
消融分析¶
- 不确定性度量:\(U_1+U_2\) 产生高方差样本回报(与GRPO相当),critic值损失显著低于PPO;单独 \(U_1\) 导致探索集中在浅层(3-4层),叠加 \(U_2\) 实现更深更均匀的覆盖
- 访问计数归一化:不做归一化→熵不受控增长和策略崩溃;对值损失也做归一化→熵快速坍缩,模型退化为次优单轮策略
- 用户模拟器泛化:将测试时模拟器从Qwen3-8B替换为Llama-3.3-70B-Instruct,性能几乎无变化,证明未过拟合特定模拟器
优点与局限¶
优点¶
- 不确定性引导的自适应树搜索兼顾了采样多样性(\(U_2\))和critic优化(\(U_1\)),比固定结构的TreePO更灵活
- 层级MDP建模+轮次级信用分配适合多轮对话的宏观决策特性
- KV缓存复用和异步执行使树搜索的额外计算开销可控,总训练时间反而最短
- 8B模型超越GPT-4o验证了方法的有效性
局限¶
- 扩展阈值τ和α为手动设定的固定超参数,不同任务/模型可能需要重新调优
- 宏动作优势在轮内所有token间均匀分配,未区分关键token和冗余token
- 用户模拟器基于预定义原子事实,与真实患者的自由表达存在差距
- 仅在MCQ格式的医学数据集上验证,未涉及开放式诊断场景
个人思考¶
- 不确定性度量的通用性:Bellman误差+Q值方差的复合度量可以推广到其他需要长horizon决策的多轮交互场景(如工具使用、多轮code generation),核心是在"哪里值得花更多计算资源探索"这个问题上提供了一个量化标准
- 树搜索与MCTS的关系:ATPO的树搜索与AlphaGo的MCTS有相似之处(都是选择性扩展),但ATPO的不确定性度量基于值函数而非UCB,更适合连续动作空间;未来可以考虑引入UCT准则或PUCT来进一步优化节点选择
- 对SFT局限性的验证:即使用GPT-4o/Gemini蒸馏也无法显著提升性能,再次证实了"模仿⊊学习"——SFT学到格式但学不到决策策略,这对医学AI的训练范式有重要启示