跳转至

MetaAug: Meta-Data Augmentation for Post-Training Quantization

会议: ECCV 2024
arXiv: 2407.14726
代码: 有
领域: 模型压缩
关键词: 训练后量化, 元学习, 双层优化, 数据增强, 过拟合缓解

一句话总结

提出 MetaAug,一种基于元学习的训练后量化(PTQ)方法,通过可学习的变换网络对校准数据进行增强,并以双层优化框架同时优化变换网络和量化模型,有效缓解 PTQ 在小校准集上的过拟合问题。

研究背景与动机

PTQ 的核心困境:过拟合

深度神经网络在资源受限设备上部署时,量化是降低计算和存储开销的关键技术。量化方法分为两类: - 量化感知训练(QAT):需要大规模训练数据重训,精度好但实际中常受数据访问限制 - 训练后量化(PTQ):仅需少量校准数据(如 1024 张图像)即可量化预训练模型,更加实用

然而 PTQ 的核心问题在于:校准数据集太小,量化模型极易过拟合

现有方法的不足

已有工作尝试缓解 PTQ 的过拟合问题: - QDrop:随机丢弃量化激活值 - PD-Quant:利用全精度模型 BN 层统计量来修正激活分布 - 激活正则化:最小化全精度和量化模型中间特征的差异

但这些方法都存在同一个根本缺陷:它们全部依赖原始校准数据来训练量化模型,且在量化过程中没有验证集来监控过拟合。训练和评估使用同一组数据,过拟合几乎不可避免。

核心思路

如果能将校准数据"一分为二"——用变换后的校准数据作为训练集,原始校准数据作为验证集——就能在量化过程中检测并防止过拟合。关键挑战在于如何让变换网络生成的数据既保留原始数据的语义信息,又与原始数据足够不同以避免退化为恒等映射。

方法详解

整体框架

MetaAug 的核心是一个双层优化框架: - 内层优化:用变换后的数据 \(T(x_i)\) 训练量化模型 \(\theta_Q\) - 外层优化:用原始校准数据 \(x_i\) 验证量化模型,根据验证损失更新变换网络 \(T\)

这形成了一个元学习范式:变换网络学习如何修改数据使得量化模型在原始数据上泛化更好。

关键设计

  1. 元学习双层优化

    • 功能:联合优化变换网络 \(T\) 和量化模型 \(\theta_Q\)
    • 核心思路\(T^* = \arg\min_T \frac{1}{N}\sum_{i=1}^N \mathcal{L}_{\text{val}}(\hat{\theta}_Q, x_i^v)\) \(\text{s.t. } \hat{\theta}_Q = \arg\min_{\theta_Q} \frac{1}{N}\sum_{i=1}^N \mathcal{L}_Q(\theta_Q, T(x_i))\) 内层用变换数据训练量化模型,外层用原始数据验证并更新变换网络
    • 设计动机:模仿 MAML 的元学习思想——好的数据变换应使得在变换数据上训练的模型泛化到原始数据。通过二阶梯度实现端到端优化,利用 Facebook 的 higher 库计算高阶梯度
  2. 变换网络 \(T\)(基于 UNet)

    • 功能:将原始校准图像变换为保留语义但外观不同的增强图像
    • 核心思路:使用 UNet 作为图像到图像的变换网络,利用其编码-解码结构和跳跃连接来保留输入的精细特征信息
    • 设计动机:UNet 的残差连接天然有助于保留原始图像信息,同时编码器-解码器结构给予了足够的变换灵活性
  3. 分布保持损失(Distribution Preservation Loss)

    • 功能:确保变换后的数据在特征空间中与原始数据保持相同的分布结构
    • 核心思路:基于概率知识迁移(PKT),估计特征空间中任意两点的条件概率密度,使变换图像与原始图像共享相同的概率分布: \(\mathcal{L}_{DP}(T, S) = \frac{1}{N}\sum_{i=1}^N \text{KL}[\mathcal{P}_i \| \mathcal{P}_i^{(g)}]\) 其中 \(\mathcal{P}_{i|j} = \frac{K(f_{\theta_{FP}}(x_i), f_{\theta_{FP}}(x_j))}{\sum_{k \neq j} K(f_{\theta_{FP}}(x_k), f_{\theta_{FP}}(x_j))}\)\(K\) 为余弦相似度核
    • 设计动机:MSE 和 KL 散度仅考虑点对点的特征距离,忽略了样本间的整体结构关系。分布保持损失捕获数据集的全局分布信息,消融实验证实其优于 MSE(+0.45%)和 KL(+0.20%)

损失函数 / 训练策略

量化损失(内层):采用逐块量化,最小化全精度模型和量化模型第 \(l\) 块输出的 MSE:

\[\mathcal{L}_Q(\theta_Q, T(S)) = \frac{1}{N}\sum_{i=1}^N \|A_{FP}^l(T(x_i)) - A_Q^l(T(x_i))\|^2\]

验证损失:用 KL 散度衡量量化模型与全精度模型输出的一致性:

\[\mathcal{L}_{\text{val}}(\hat{\theta}_Q, S) = \frac{1}{N}\sum_{i=1}^N \text{KL}[\sigma(f_{\theta_{FP}}(x_i)) \| \sigma(f_{\hat{\theta}_Q}(x_i))]\]

Margin 损失(防止恒等退化)

\[\mathcal{L}_{\text{margin}}(T,S) = \frac{1}{N}\sum_{i=1}^N \max(0, \epsilon - \frac{1}{M}\|x_i - T(x_i)\|^2)\]

确保变换图像与原始图像的像素差异不小于阈值 \(\epsilon\)

总损失\(\mathcal{L}_T = \lambda_1 \mathcal{L}_{\text{val}} + \lambda_2 \mathcal{L}_{\text{margin}} + \lambda_3 \mathcal{L}_{DP}\),其中 \(\lambda_1=5\)\(\lambda_2=0.5\)\(\lambda_3=3\times 10^4\)

训练流程:对每个块,先交替更新变换网络和量化模型(500 次迭代),再用原始+变换数据共同量化该块(\(2\times 10^4\) 次迭代)。

实验关键数据

主实验:ImageNet Top-1 准确率

方法 比特宽度 (W/A) ResNet-18 ResNet-50 MobileNetV2
全精度 32/32 71.01 76.63 72.20
QDrop 4/4 69.10 75.03 67.89
PD-Quant 4/4 69.23 75.16 68.19
Genie-M 4/4 69.35 75.21 68.65
MetaAug (Ours) 4/4 69.48 75.29 68.76
QDrop 2/2 51.14 54.74 8.46
PD-Quant 2/2 53.14 57.16 13.76
Genie-M 2/2 53.71 56.71 16.25
Bit-Shrinking* 2/2 57.33 59.03 18.23
MetaAug* (Ours) 2/2 57.89 60.50 19.61

在极低比特(W2A2)下提升最显著:ResNet-50 比 Bit-Shrinking 高 1.47%,MobileNetV2 高 1.38%。

消融实验

配置 损失组合 ResNet-18 W2A2
Genie-M 基线 53.71
(a) \(\mathcal{L}_{\text{val}}\) 53.45
(b) \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{MSE}}\) 53.64
(c) \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{KL}}\) 53.89
(d) \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{DP}}\) 54.09
(e) \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{DP}} + \mathcal{L}_{\text{margin}}\) 54.22

分布保持损失 \(\mathcal{L}_{DP}\) 一致优于 MSE 和 KL;加入 margin 损失进一步提升 0.13%。

过拟合分析

方法 W/A 测试集准确率 校准集准确率 训练-测试差距
QDrop 2/2 51.14 77.53 26.39
PD-Quant 2/2 53.14 83.30 30.16
Genie-M 2/2 53.77 80.18 27.01
MetaAug 2/2 54.22 77.64 23.42

MetaAug 不仅测试准确率最高,训练-测试差距最小(23.42 vs PD-Quant 的 30.16),证实有效缓解过拟合。

关键发现

  • 与传统增强对比:MetaAug(54.22%)优于 Random Flip(53.93%)、Cutmix(54.15%)、Mixup(54.05%),且可与它们叠加使用(MetaAug+Cutmix 达 54.63%)
  • 低比特优势更明显:4/4 设定提升有限,但 2/2 设定提升显著(+0.5%~1.5%),因为低比特下过拟合更严重,元学习带来的泛化改善更有价值
  • 变换网络的可视化验证了生成图像改变外观但保留语义结构

亮点与洞察

  1. 用验证的视角看 PTQ:首次在 PTQ 中引入训练/验证分离的概念,从数据优化角度而非模型正则化角度解决过拟合
  2. 精巧的防退化设计:margin 损失防止变换网络变为恒等映射,分布保持损失防止语义信息丢失,两者缺一不可
  3. 与传统增强互补:MetaAug 可以与 Mixup、Cutmix 等组合使用获得更大提升,说明学到的变换与随机增强本质不同
  4. 实际过拟合量化:通过训练集-测试集准确率差距定量展示过拟合程度,为 PTQ 过拟合研究提供了清晰的评估基准

局限与展望

  • 变换网络仅进行光度变换,未引入几何变换(作者建议整合 Spatial Transformer)
  • UNet 作为变换网络引入了额外的训练开销(500 次迭代/块)
  • 仅在 ImageNet 分类上验证,未扩展到检测、分割等任务
  • 超参数(\(\lambda_1, \lambda_2, \lambda_3, \epsilon\))对不同架构需要不同配置(ResNet-50 用 \(\epsilon=0.3\),其他用 \(\epsilon=0.1\)

相关工作与启发

  • 与 MAML 的联系:双层优化框架源自 MAML,但目标从"学习好的初始化"转变为"学习好的数据变换"
  • 与 MetaMix/MetaQuantNet 的区别:这些工作用元学习优化量化策略本身,MetaAug 首次从数据角度使用元学习解决 PTQ 过拟合
  • 启发方向:PTQ 元学习框架可扩展到其他需要少样本校准的模型压缩场景(剪枝、蒸馏等)

评分

  • 新颖性: ⭐⭐⭐⭐ 首次用元学习双层优化从数据角度解决 PTQ 过拟合,视角独特
  • 实验充分度: ⭐⭐⭐⭐ 多架构、多比特宽度的消融和对比实验全面,过拟合分析直观
  • 写作质量: ⭐⭐⭐⭐ 公式推导清晰,动机阐述充分
  • 价值: ⭐⭐⭐⭐ 为 PTQ 研究提供了新范式,与现有方法正交可组合

相关论文