跳转至

Incremental Sequence Classification with Temporal Consistency

会议: NeurIPS 2025
arXiv: 2505.16548
代码: 无
领域: Reinforcement Learning / NLP
关键词: incremental classification, temporal consistency, temporal-difference learning, LLM verification, sequence classification

一句话总结

将强化学习中时序差分(TD)学习的思想引入序列分类任务,提出 TC-\(\lambda\) 损失函数,通过要求相邻时间步的预测分布满足时序一致性条件来训练增量式序列分类器,在文本分类和 LLM 验证任务上均优于标准交叉熵方法。

研究背景与动机

领域现状:序列分类是机器学习的基础问题。传统方法只在完整序列上做预测,但许多场景要求在序列逐步展开的过程中持续更新预测——如在医疗、金融中等待完整序列代价高昂,以及近期 LLM 验证中需要尽早判断生成是否正确。

现有痛点:标准方法(Direct Cross-Entropy, DCE)直接将每个前缀的预测与最终标签对比训练。然而对于早期前缀 \(\mathbf{x}_{\leq t}\)\(t \ll T\)),最终标签是一个很"嘈杂"的训练信号——模型需要同时处理序列后续展开的不确定性和最终标签的不确定性。

核心矛盾:用远处的最终标签作为每个中间时刻的训练目标,忽略了序列的时序结构——相邻时间步之间的预测分布应该满足一致性关系。

本文目标:开发更好的损失函数来训练增量序列分类器,特别是提升前缀预测的数据效率和准确性。

切入角度:观察到一个关键恒等式:对任意校准的分类器,\(p(y|s_t) = \mathbb{E}_{p(s_{t+1}|s_t)}[p(y|s_{t+1})]\),即当前步的类别分布等于下一步类别分布的期望——这正是 TD 学习中 Bellman 方程的分类版本。

核心 idea:用下一时间步的预测分布作为当前步的"软目标",替代远处的硬标签(one-hot),实现时序一致性约束。

方法详解

整体框架

将序列分类建模为 Markov 链上的吸收概率估计问题。给定标注序列数据集,训练一个参数化分类器 \(p_\theta(y|s_t)\) 使其在每个中间状态都能准确预测最终类别。提出 TC-\(\lambda\) 损失函数族,通过调节 \(\lambda \in [0,1]\) 在纯时序一致性训练(\(\lambda=0\))和纯直接交叉熵(\(\lambda=1\))之间插值。

关键设计

  1. 时序一致性条件

    • 功能:利用 Markov 性质建立相邻时间步预测分布间的约束
    • 为什么:如果 \(p(y|s_t, \ldots, s_1) = p(y|s_t)\)(Markov 性),则必有 \(p(y|s_t) = \mathbb{E}_{p(s_{t+1}|s_t)}[p(y|s_{t+1})]\)
    • 关键意义:这意味着时间步 \(t\) 的类别分布等于时间步 \(t+1\) 的类别分布在期望意义下的值——校准的预测器不应在相邻步之间剧烈跳动
    • 注意:对文本序列,令 \(s_t = \mathbf{x}_{\leq t}\) 即可平凡满足 Markov 性
  2. TC 损失函数

    • 功能:用下一步的预测分布作为当前步的训练目标
    • 怎么做\(\ell_{\text{TC}}(\theta; \theta', \mathbf{s}, y) = H[\delta_y \| \mathbf{p}_\theta(\cdot|s_T)] + \sum_{t=1}^{T-1} H[\mathbf{p}_{\theta'}(\cdot|s_{t+1}) \| \mathbf{p}_\theta(\cdot|s_t)]\)
    • 解释:最后一步用真实标签训练(硬目标),中间步用参考模型 \(\theta'\) 在下一步的预测作为软目标
    • 优化:给定 \(\theta'\)(上一轮迭代的参数),优化 \(\theta\) 最小化 TC 损失;交替迭代
    • 区别:与 DCE 的本质区别在于将远处的硬标签 \(\delta_y\) 替换为近处的软目标 \(\mathbf{p}_{\theta'}(\cdot|s_{t+1})\)
  3. TC-\(\lambda\) 广义损失函数

    • 功能:在 TC 和 DCE 之间插值,控制"向前看"的距离
    • 怎么做\(\ell_{\text{TC-}\lambda}(\theta; \theta', \mathbf{s}, y) = \sum_{t=1}^{T} H[\mathbf{z}_t \| \mathbf{p}_\theta(\cdot|s_t)]\) 其中 \(\mathbf{z}_t = \lambda^{T-t}\delta_y + (1-\lambda)\sum_{k=1}^{T-t}\lambda^{k-1}\mathbf{p}_{\theta'}(\cdot|s_{t+k})\)
    • \(\lambda\) 的含义:几何分布的均值 \(\lambda/(1-\lambda)\) 为有效前瞻距离。\(\lambda=0\) 退化为 TC(只看下一步),\(\lambda=1\) 退化为 DCE(只用最终标签)
    • 最优 \(\lambda\):实验发现前瞻 5–50 个 token 效果最佳,对应 \(\lambda \in [0.8, 0.98]\)
  4. 理论分析(有限状态情形)

    • 功能:在表格模型中证明 TC 的收敛性、一致性和数据效率优势
    • 关键结果 (Proposition 3):在 \(T\)\(W\) 状态的 Markov 链中,TC(间接估计)的均方误差比 DCE(直接估计)小 \(W\)\(\mathbf{E}[(\hat{p}^{\text{ind}}_{mk} - p^*_{mk})^2] / \mathbf{E}[(\hat{p}^{\text{dir}}_{mk} - p^*_{mk})^2] \to 1/W\)
    • 直觉:TC 通过"数据池化"效应受益——优化相邻状态间的一致性时,可以利用从不同起点出发但经过相同中间状态的轨迹信息

损失函数 / 训练策略

  • 使用 decoder-only transformer(OPT 系列),在最后一层隐藏向量上接线性分类头
  • \(p_\theta(\cdot|\mathbf{x}_{\leq t}) = \text{softmax}(\mathbf{A}\mathbf{h}_t + \mathbf{b})\)
  • 联合优化分类头和所有 transformer 参数
  • 每个序列的损失对所有前缀取平均(而非求和),使不同长度序列贡献均等
  • 训练开销几乎与 DCE 相同(差异 < 1%)

实验关键数据

主实验:文本分类(OPT-125M)

方法 ohsumed 4tok ohsumed 16tok ohsumed all newsgroups 4tok newsgroups all imdb 4tok imdb all ag-news 4tok ag-news all
Most frequent 16.0 16.0 16.0 5.3 5.3 50.0 50.0 25.0 25.0
GPT-4o 31.5 54.0 57.5 7.5 80.4 58.0 94.3 77.4 88.3
Last token 16.7 45.0 80.6 6.5 87.9 56.6 94.7 54.4 94.8
DCE 30.5 65.5 81.1 27.7 89.0 63.5 94.4 80.0 94.8
LSTD(\(\lambda\)) 32.7 64.9 78.0 26.2 87.8 64.6 94.7 81.1 94.9
TC-\(\lambda\) (ours) 33.7 68.3 81.8 33.4 88.5 64.7 94.9 81.4 95.0

消融实验

消融维度 结论
模型规模 (125M→350M→1.3B) DCE 需要约 10× 大的模型才能匹配 TC-\(\lambda\) 的性能
\(\lambda\) 值选择 \(\lambda=1\)(即 DCE)从不最优;最优前瞻距离为 5–50 tokens
交叉熵 vs 平方损失 多类别任务中交叉熵显著优于平方损失;二分类时差异不大
增量训练 vs 仅训练完整序列 增量训练(DCE/TC-\(\lambda\))提升完整序列分类准确率
时序一致性度量 TC-\(\lambda\) 训练的模型相邻步 KL 散度显著更低

LLM 验证实验(GSM8K)

指标 DCE TC-\(\lambda\)
ROC AUC @ 8 tokens ~65% ~70%
ROC AUC @ 64 tokens ~78% ~82%
计算节省(boosted best-of-N) baseline 节省 23–33% tokens

关键发现

  • TC-\(\lambda\) 在 4/4 数据集的前缀分类和 3/4 数据集的完整序列分类上超越 DCE
  • 最重要的洞察:改善前缀预测的同时也改善了完整序列的分类准确率——时序一致性起到了正则化作用
  • 在 LLM 验证中,仅看 8 个 token 就能以约 70% 准确率区分正确/错误生成
  • Boosted best-of-N 用 TC-\(\lambda\) 验证器可以以更少的计算达到相同的答案准确率

亮点与洞察

  • 跨领域连接优美:将 RL 中 TD 学习的核心思想——用相邻状态的一致性代替最终奖励——转移到分类问题中,概念清晰且实现简单
  • 理论与实践统一:先在表格设置中严格证明数据效率优势,再在大规模 LLM 实验中验证
  • 实用价值高:对 LLM 验证/test-time scaling 有直接应用——更早判断生成好坏 = 节省计算
  • 实现极简:只需修改损失函数的目标分布,训练开销增加 < 1%

局限与展望

  • 理论分析仅在表格(有限状态)设定下严格成立,参数化模型下的保证是启发式的
  • 实验规模局限于 ≤ 1.3B 参数的模型,更大模型的效果不确定
  • LLM 验证实验使用了简单的 boosted best-of-N,未与 speculative rejection 等更先进的方法深度对比
  • \(\lambda\) 是需要调的超参数,最优值依赖于前缀长度
  • 未探索多模态序列(如视频帧级预测)的应用

相关工作与启发

  • TD Learning (Sutton, 1988):本文的直接灵感来源,将"用时序差分代替蒙特卡洛"的核心思想从标量值推广到多类别分布
  • Cobbe et al. (2021):training verifiers to solve math word problems,本文在此基础上引入更好的验证器训练方法
  • Mudgal et al. (2023):也用 TD 学习训练 LLM 验证器,但使用平方损失,不适合分类/二值结果
  • 启发:RL 中的"bootstrapping"思想(用自身预测作为目标)在监督学习中同样有效,关键是找到正确的不变性约束(此处为时序一致性)

评分

  • 新颖性: ⭐⭐⭐⭐ TD→分类的迁移新颖但自然;TC-\(\lambda\) 与 TD(\(\lambda\)) 的对应关系清晰
  • 实验充分度: ⭐⭐⭐⭐ 多数据集、多基线、多指标、含理论验证,但模型规模有限
  • 写作质量: ⭐⭐⭐⭐⭐ 理论推导清晰,与 RL 的类比直观,实验展示有说服力
  • 价值: ⭐⭐⭐⭐⭐ 方法简单通用,对 LLM verification/test-time scaling 有直接实用价值

相关论文