跳转至

AutoMixer: Checkpoint Artifacts as Automatic Data Mixers

会议: ACL 2025
arXiv: 2506.21910
代码: 无
领域: 预训练数据优化
关键词: 数据混合、检查点利用、影响函数、预训练、推理基准

一句话总结

提出 AutoMixer 框架,利用训练过程中保存的检查点模型作为"数据混合器",通过聚合多检查点的一阶影响函数近似来重新分组和加权训练数据,在八个推理基准上取得最高 1.93% 的性能提升。

研究背景与动机

  • 领域现状:语言模型预训练的效果高度依赖训练数据的组成。如何为特定目标技能找到最优的数据混合比例是预训练中的核心问题。现有方法通常通过手动设定域权重或基于启发式规则来确定数据比例。
  • 现有痛点:数据与任务之间的关系难以直接建模——这是一个"鸡与蛋"的困境:要知道哪些数据有助于特定技能的习得,模型需要先展现该技能;但要展现技能,又需要合适的训练数据。暴力搜索所有可能的数据组合在计算上不可行。
  • 核心矛盾:影响函数可以估算训练样本对模型性能的一阶贡献,但仅基于单一检查点(通常是最终模型)计算影响分数,忽略了技能在训练过程中非单调涌现的特性——某个检查点擅长的技能可能在后续训练中消退。
  • 本文目标:设计一个框架,能够自动识别任务相关的训练数据并确定最优采样权重,充分利用训练过程中产生的检查点这一被低估的资源。
  • 切入角度:观察到不同检查点在不同任务上展现出不同的峰值能力(如表1所示,25M 参数模型在不同步数达到不同任务的最佳性能),将这些检查点作为任务特定的数据采样器。
  • 核心 idea:利用代理模型(proxy model)的模拟训练运行,选择在各任务上表现最好的检查点,计算每个检查点对训练样本的影响分数,聚合后用于数据重新分组和采样权重分配。

方法详解

整体框架

AutoMixer 通过两步过程优化预训练数据:(1) 数据重新分组(Data Regrouping)——通过模拟训练运行获取任务最优检查点,计算样本影响分数,将原始数据重新分配到任务对齐的数据组中;(2) 数据混合重加权(Datamix Reweighting)——基于每组的聚合影响密度确定采样概率,指导预训练时各数据组的加载比例。

关键设计

  1. 基于检查点的任务识别(Checkpoint-Based Task Identification):在代理模型(75M 或 350M 参数)的模拟训练中,每隔一定步数保存检查点,评估所有检查点在目标基准上的性能。对每个任务选择表现最好的检查点作为该任务的"采样器"。关键洞察是不同任务的最优检查点出现在训练的不同阶段(如 HellaSwag 通常在后期达到最佳,而某些任务在早期就收敛),说明技能习得是非单调的。
  2. 多检查点影响分数聚合(Multi-Checkpoint Influence Aggregation):对选中的 k 个检查点,分别计算训练样本的影响分数。采用 DataInf 方法高效近似 Hessian 逆,仅在嵌入层和最后一层计算梯度以减少计算开销并提升区分度。影响分数通过"混合因子"(blending factor)加权聚合,混合因子 \(\alpha_j\) 按检查点步数归一化,使较晚阶段达到最优的任务(学习更慢的技能)获得更高的权重。联合影响分数为 \(\mathcal{I}_{\text{joint}}(x_i) = \sum_{j=1}^{k} \alpha_j \cdot \mathcal{I}(x_i; \theta_j)\)
  3. 影响密度驱动的采样权重(Influence Density-Driven Sampling Weights):对每个数据组计算影响密度 \(\rho_g = \frac{1}{T_g} \sum_{x_i \in g} \mathcal{I}_{\text{joint}}(x_i) \cdot s_i\),其中 \(T_g\) 是组内总 token 数,\(s_i\) 是样本 token 数。采样权重 \(w_g = \rho_g / \sum_{g'} \rho_{g'}\),确保高影响密度的数据组在预训练中获得更多采样。

损失函数 / 训练策略

  • 预训练目标:标准的因果语言建模(CLM),使用交叉熵损失。
  • 影响函数近似:采用 DataInf 方法,通过层级正则化参数 \(\lambda_l = 0.1 \times (n \cdot d_l)^{-1} \sum_{i=1}^{n} \|\nabla_{\theta_l} \ell\|_2^2\) 绕过显式 Hessian 求逆。
  • 判别层选择:仅在嵌入层和输出层计算影响分数,避免中间层的抵消效应,提升计算效率和分数区分度。
  • 训练配置:基于 Llama-3 架构的 decoder-only 模型,在 32 GPU(4 节点×8 H100)上训练,模拟运行使用 6.4B token(100K 步 × batch 8 × 4 节点 × seq 2048)。

实验关键数据

主实验

在 FineWeb-Edu 数据集上训练,使用八个常识推理基准进行零样本评估。结果为相对于均匀采样基线的准确率提升百分比。

方法 ARC-E ARC-H BoolQ PIQA SIQA HellaSwag OBQA WinoGrande 平均
350M 参数
PPL 采样 +0.35 +0.60 +0.44 +0.70 -0.10 +0.55 +0.40 +0.90 +0.66
N-gram 采样 +0.74 +1.22 +0.79 +1.03 +1.09 +0.62 +1.16 +0.85 +0.60
AutoMixer-75M -0.15 +0.12 -0.14 +0.01 -0.10 +0.05 -0.03 -0.05 -0.04
AutoMixer-350M +2.23 +0.55 +2.16 +2.05 +2.12 +2.33 +2.01 +2.14 +1.93
1.5B 参数
PPL 采样 +0.20 +0.52 +0.32 +0.40 +0.07 +0.18 +0.75 +0.68 +0.48
N-gram 采样 +0.88 +0.82 +1.02 +0.58 +0.45 +1.22 +0.54 +0.90 +0.79
AutoMixer-350M +1.26 +0.39 +1.35 +1.22 +1.38 +1.45 +1.33 +1.41 +1.22

消融实验

配置 平均提升(%) 说明
仅最终检查点 +0.7 单一检查点信息不足
全部10个检查点 +0.8 无差别聚合,效果有限
AutoMixer-350M(选择性检查点) +1.22 任务对齐的检查点选择最有效
AutoMixer-75M 代理 -0.04~-0.01 过小的代理模型无法提供有效信号
AutoMixer-350M 代理 +1.05~+1.93 匹配目标模型规模的代理最有效

关键发现

  • 代理模型规模至关重要:75M 的代理模型几乎无法提供有效的数据选择信号(平均提升为负),而 350M 代理模型带来了显著提升。代理模型与目标模型规模的对齐是性能的关键因素。
  • 选择性检查点优于全部聚合:仅使用最终检查点(+0.7%)或全部聚合(+0.8%)远不如 AutoMixer 的任务对齐选择策略(+1.22%~+1.93%)。
  • 提升随目标模型增大而递减:350M 目标模型获得 +1.93%,1.5B 获得 +1.22%,3B 获得 +1.05%。这可能因为代理模型固定为 350M,与更大目标模型的差距增加。
  • AutoMixer 在整个训练过程中保持优势:性能轨迹图显示 AutoMixer-350M 从训练初期就高于均匀采样,最终达到 56.45% vs 51.82% 的准确率。
  • 小代理模型倾向选择长句子:75M 代理模型偏好高影响分数的长文本样本,而 350M 模型能更好地区分高低影响分数的样本。

亮点与洞察

  • 检查点即信号源的新视角:将通常被丢弃或仅用于恢复训练的检查点重新利用为数据质量信号,这一观察非常巧妙且实用。
  • 多检查点聚合捕获技能涌现动态:通过在训练的不同时间点采样检查点并聚合影响分数,有效建模了技能习得的非单调过程。
  • 计算效率的工程设计:仅在嵌入层和最后层计算梯度、使用 DataInf 近似避免 Hessian 求逆,都是在保持效果的前提下大幅降低计算成本的实用设计。
  • 混合因子的设计直觉:按检查点步数分配权重,使得学习更慢的任务获得更高的数据优先级,这与直觉一致——难学的技能更需要针对性的数据支持。

局限与展望

  • 代理模型的影响分数计算仍有显著开销(约 120 小时/100 GPU 的检查点评估 + 48 小时的模拟运行),限制了在更大规模训练中的应用。
  • 仅在推理基准上验证,未探索生成任务、代码任务等其他能力维度。
  • 75M 代理模型的失败表明框架对代理模型规模敏感,如何选择最优代理规模缺乏理论指导。
  • 当前的数据重分组是一次性的,未探索动态调整(如在训练过程中迭代更新数据混合)。
  • 未讨论数据偏差问题——优化特定基准可能导致其他能力的退化。

相关工作与启发

  • vs Data Mixing Laws (Ye et al., 2024):Data Mixing Laws 通过预测模型拟合数据混合比例,AutoMixer 直接利用检查点的影响函数,不需要额外的预测模型。
  • vs TAGET (Chang et al., 2024):同一作者前作使用 n-gram 采样进行目标感知数据选择,AutoMixer 用影响函数替代 n-gram 匹配,理论基础更强但计算开销更大。
  • vs DSIR/D4 等数据选择方法:这些方法通常基于域匹配或困惑度,AutoMixer 通过多检查点影响函数提供更精细的样本级信号。

评分

  • 新颖性: ⭐⭐⭐⭐ 将检查点工件重新利用为数据混合信号的视角新颖且实用
  • 实验充分度: ⭐⭐⭐⭐ 三个模型规模、四种基线对比、详细消融和分析
  • 写作质量: ⭐⭐⭐⭐ 框架描述清晰,数学推导完整,图示直观
  • 价值: ⭐⭐⭐⭐ 为预训练数据优化提供了新范式,检查点利用的思路可推广

相关论文