Universal Cross-Tokenizer Distillation via Approximate Likelihood Matching¶
会议: NeurIPS 2025
arXiv: 2503.20083
代码: https://github.com/bminixhofer/tokenkit
领域: 模型压缩 / 知识蒸馏
关键词: 跨分词器蒸馏, 近似似然匹配, 分词器迁移, f-散度, LLM蒸馏
一句话总结¶
本文提出 Approximate Likelihood Matching (ALM),一种基于二值化 f-散度的原则性跨分词器蒸馏方法,首次实现了跨根本不同分词器(如子词→字节级)的有效蒸馏和纯蒸馏。
研究背景与动机¶
知识蒸馏是创建高效语言模型的重要范式,但现有蒸馏方法要求教师模型和学生模型使用相同或相似的分词器,这极大限制了可用的教师-学生模型对。当前 LLM 领域的分词器异常多样化:不同模型(GPT、Llama、Gemma、Qwen)使用不同的词表和分词函数,且近年来出现了从子词向字符/字节级分词的趋势。
核心矛盾:标准蒸馏通过 KL 散度比较教师和学生的 token 级概率分布,但这要求两者共享相同的 token 空间。当分词器不同时,教师和学生将同一文本分割为不同的 token 序列,无法直接比较。已有的跨分词器方法(ULD、MinED、DSKD)使用启发式方法融入教师信息,只能作为辅助目标与主目标(如 next-token prediction)联合使用,无法实现纯蒸馏。
本文切入角度:将问题形式化为比较对齐的 token chunk 的似然值,通过二值化 f-散度近似来避免枚举无穷多可能的字节序列结果。
方法详解¶
整体框架¶
给定文本 x,分别用教师和学生分词器分词,计算所有 token 的 next-token 概率。找到两个序列之间的对齐 chunk(编码相同文本片段的 token 子序列),然后最小化对齐 chunk 之间似然的差异。
关键设计¶
-
Chunk级概率对齐:核心思想是找到教师和学生序列中编码相同文本的 token chunk,定义 chunk 级概率为 \(p(\mathbf{x}, i:j) = p(T(\mathbf{x})_{i:j} | T(\mathbf{x})_{:i})\)。理想情况下应计算所有可能 chunk 结果的 f-散度,但由于可能的字节序列无穷多,作者转而计算二值化 f-散度——只考虑"该 chunk 出现"和"该 chunk 不出现"两种情况。这构成了真实 f-散度的上界,并保留了关键性质(当且仅当 \(p_S = p_T\) 时最小)。
-
Outcome Chunk Debiasing(结果chunk去偏):子词分词器存在分词偏差(tokenization bias),例如"Hello_Wor"隐含了后面不是"ld"的信息。通过将 chunk 概率乘以一个去偏概率(下一个 token 以预分词边界字节开头的概率)来消除这种偏差。设置阈值 γ 过滤去偏概率过低的 chunk,避免概率被"擦除"到极低值。
-
隐藏状态蒸馏:由于 chunk 级概率仅提供 \(|A_c| \times 32\) 比特的信号(相比同分词器蒸馏的 \(|T(\mathbf{x})| \times |\mathcal{V}| \times 32\) 比特),额外添加隐藏状态对齐损失来丰富信号。对齐教师和学生模型中对应位置的隐藏状态,通过学习的投影函数 proj 映射维度。
-
GradMag 损失组合:针对多目标优化中不同损失梯度量级差异大的问题,提出简单的梯度幅值归一化方法:计算每个损失对最后一层的梯度,将权重设为 \(1/\|G_W^i\|\),直接求解等梯度幅值条件,比 GradNorm 更简单且效果相当。
损失函数 / 训练策略¶
ALM 目标函数: $\(\mathcal{L}^{ALM}_{S,T}(\mathbf{x}) = \sum_{i,j,k,l \in A_c(\mathbf{x})} f(p_T^{1/\tau} \| p_S^{1/\tau}) + f(1-p_T^{1/\tau} \| 1-p_S^{1/\tau})\)$
其中 f 是 f-散度的生成函数,τ 是温度超参数。可选择纯蒸馏模式(仅 ALM 损失)或混合蒸馏模式(ALM + SFT next-token prediction 损失),使用 GradMag 自动平衡。
实验关键数据¶
主实验:分词器迁移(Use Case 1)¶
| 模型 | 迁移目标 | 方法 | Avg | MMLU | BoolQ | IFEval |
|---|---|---|---|---|---|---|
| Gemma2 2B IT | 原始 | - | 58.0 | 56.9 | 83.8 | 62.5 |
| Gemma2 2B IT | →Qwen2 | SFT | 51.6 | 49.8 | 77.7 | 54.2 |
| Gemma2 2B IT | →Qwen2 | MinED | 53.0 | 51.8 | 79.6 | 57.1 |
| Gemma2 2B IT | →Qwen2 | ALM | 55.1 | 53.6 | 82.7 | 53.2 |
| Gemma2 2B IT | →Byte | SFT | 46.5 | 43.1 | 67.9 | 51.5 |
| Gemma2 2B IT | →Byte | ALM+SFT | 51.3 | 51.0 | 80.5 | 51.9 |
| Llama3.2 3B IT | →Qwen2 | ALM | 58.6 | 61.6 | 79.0 | 76.3 |
消融实验¶
| 配置 | 关键效果 | 说明 |
|---|---|---|
| Outcome Chunk Debiasing | 显著提升性能 | 去除分词偏差,阈值 γ 进一步改善 |
| GradMag vs GradNorm | 持平或更优 | 更简单的损失平衡策略 |
| 纯ALM vs ALM+SFT | 子词迁移纯ALM更优 | 纯蒸馏更好保留原模型行为 |
| 字节迁移: ALM+SFT | 字节迁移ALM+SFT更优 | 极端分词变化下需保留SFT信号 |
Use Case 2: 大→小跨分词器蒸馏¶
| 方法 | GSM8K | MATH | Avg |
|---|---|---|---|
| 教师 (OpenMath2-Llama3.1-8B) | 88.9 | 60.2 | 74.6 |
| SFT | 67.2 | 36.2 | 51.7 |
| DSKD | 65.7 | 34.9 | 50.3 |
| ALM+SFT | 70.2 | 36.4 | 53.3 |
关键发现¶
- ALM 首次实现了子词→字节级的有效蒸馏迁移(此前方法完全无效或不如 SFT)
- 将不同模型迁移到相同分词器后可进行 token 级集成,集成后性能优于单个模型
- 在效率方面,ALM 不需要 DSKD 的交叉注意力计算,也不需要 MinED 的大型 logit 矩阵对齐
- ALM 比最佳先前方法额外缩小了 34% 的教师-学生差距
亮点与洞察¶
- 原则性方法:不同于先前启发式方法,ALM 提供了数学上有根据的蒸馏目标——二值化 f-散度是真实散度的上界且保留最优性条件
- 通用性强:同一方法适用于子词→子词、子词→字节、大→小蒸馏、超网络训练等多种场景
- 分词器迁移即自蒸馏:将分词器迁移重新定义为跨分词器自蒸馏的视角非常优雅
- 模型集成新可能:通过迁移到统一分词器实现不同模型族的 token 级集成
局限与展望¶
- 二值化 f-散度是对真实散度的粗略近似,可能损失信息
- 字节级迁移仍与原始模型有较大差距,需要额外技术(如 hourglass 架构、多字节预测)
- Outcome Chunk Debiasing 只处理了结果端的偏差,条件端偏差仍未解决
- 隐藏状态对齐需要层级对应关系的先验知识
相关工作与启发¶
本文将跨分词器蒸馏从启发式方法提升为原则性框架。与 DSKD(基于交叉注意力的 token 对齐)和 MinED(基于编辑距离的最小匹配)相比,ALM 更高效且不需要预计算步骤。"分词器迁移即自蒸馏"的洞察为创建字节级模型提供了比从头训练更经济的路径。其打开了不同模型族之间知识组合的新可能。
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ 首次原则性地解决跨根本不同分词器的蒸馏问题,视角新颖
- 实验充分度: ⭐⭐⭐⭐⭐ 三个Use Case全面验证,消融详尽,效率比较充分
- 写作质量: ⭐⭐⭐⭐⭐ 数学推导严谨,方法描述清晰,图表直观
- 价值: ⭐⭐⭐⭐⭐ 解决了LLM蒸馏的核心瓶颈问题,实际意义重大
相关论文¶
- [ICCV 2025] Cross-Architecture Distillation Made Simple with Redundancy Suppression
- [NeurIPS 2025] Learning to Factorize and Adapt: A Versatile Approach Toward Universal Spatio-Temporal Foundation Models
- [ECCV 2024] UNIC: Universal Classification Models via Multi-teacher Distillation
- [ICLR 2026] Draft-based Approximate Inference for LLMs
- [ICLR 2026] Distillation of Large Language Models via Concrete Score Matching