跳转至

SDKD: Frequency-Aligned Knowledge Distillation for Lightweight Spatiotemporal Forecasting

会议: ICCV 2025
arXiv: 2507.02939
代码: https://github.com/itsnotacie/SDKD
领域: 自动驾驶
关键词: 时空预测, 知识蒸馏, 频域解耦, 轻量化模型, CNN-Transformer混合架构

一句话总结

提出SDKD(频域解耦知识蒸馏)框架,通过频率感知的教师模型和频率对齐的蒸馏策略,将复杂时空预测模型的多尺度频域知识迁移到轻量级学生网络,在Navier-Stokes数据集上MSE最高降低81.3%。

研究背景与动机

时空预测(如交通流量、天气演变、流体动力学)是智慧城市和气候科学中的核心任务。这些任务需要同时建模高频局部模式(如突发交通拥堵、湍流涡旋)和低频全局演化(如大气压力缓慢变化、日周期交通规律)之间的复杂耦合。

近年来,CNN-Transformer混合架构通过结合局部卷积和全局注意力机制取得了SOTA精度。但这些模型面临严重的部署挑战:Transformer的自注意力机制复杂度为 \(O(N^2d)\),加上深层CNN堆叠结构,导致计算成本和内存消耗极高,无法部署在边缘设备上。

知识蒸馏(KD)是一种自然的解决方案——将复杂教师模型的知识迁移到轻量学生模型。但现有KD框架主要针对分类任务设计,直接应用于时空预测效果很差。根本原因在于时空信号具有多频段特性:简单的特征模仿无法维持高频细节和低频趋势之间的关键平衡。

具体而言,现有方法导致一个频谱纠缠困境:没有显式机制来解耦和传递不同频率的知识——要么过度平滑高频变化(如交通突增峰值被抹去),要么无法捕获缓慢演化趋势(如季节性气候变化被忽略)。而且时空信号的最优频段在空间和时间上都是非平稳的。

SDKD的核心洞察:时空信号天然具有频谱二元性——高频分量(局部快速变化)由空间局部性控制,低频分量(全局缓慢动态)由时间连续性主导。利用CNN作为高通滤波器、Transformer作为低通滤波器的天然属性,可以在教师模型的隐空间中实现频域解耦,然后将这些解耦后的频域知识作为蒸馏监督信号。

方法详解

整体框架

SDKD包含四个阶段:(1) 教师模型预训练——频率解耦架构(CNN处理高频+Transformer处理低频);(2) 轻量学生架构设计(ResNet/U-Net/MLP-Mixer);(3) 离线频率对齐蒸馏训练;(4) 在线推理部署。

关键设计

1. 频率解耦教师模型(ST-AlterNet)

  • 功能:在隐空间中显式分离高频和低频时空模式,为蒸馏提供结构化的频域先验。
  • 核心思路:采用编码器-隐状态演化-解码器架构。编码器通过多层卷积将输入 \(X\) 映射到隐空间 \(Z = \mathcal{E}(X)\)。隐状态演化模块包含两个并行的频率敏感分支:
    • 高频提取器(CNN分支)\(Z^h = \text{ConvBlock}(Z) = \sigma(\text{Conv2D}(Z; \mathbf{W}_h) + b_h)\),CNN通过局部梯度算子 \(\nabla_{x,y}\) 充当隐式高通滤波器,捕获突发变化和细粒度空间模式。
    • 低频建模器(Transformer分支)\(Z^l = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\),自注意力机制充当低通滤波器:\(\mathcal{F}_{\text{low}}(Z)(\omega) = \frac{1}{1+\|\omega\|^2/\lambda} \cdot \hat{Z}(\omega)\),捕获全局长程依赖和趋势变化。
    • 高低频融合:\(Z_{\text{evolved}} = \text{LayerNorm}(Z^h + Z^l)\)
  • 设计动机:根据Park等人的理论分析,卷积层天然偏向高频,Transformer天然偏向低频。串联设计使隐空间的频谱能量分布与物理规律对齐(实验验证隐空间符合Kolmogorov能量谱 \(E(\omega) \propto \omega^{-5/3}\)),为蒸馏提供可解释的频域先验。

2. 频率对齐蒸馏机制

  • 功能:从教师隐空间提取多尺度频域特征,引导学生模型同时学习高频细节和低频趋势。
  • 核心思路:定义频谱传输损失: \(\mathcal{L}_{\text{KD}} = \|\Psi_h(\mathcal{G}(X)) - \Psi_h(\mathcal{F}(X))\|_2^2 + \alpha \|\Psi_l(\mathcal{G}(X)) - \Psi_l(\mathcal{F}(X))\|_2^2\)

其中 \(\Psi_h\)\(\Psi_l\) 分别提取高频和低频特征,\(\alpha\) 平衡两者。总训练目标: \(\min_{\theta_\mathcal{G}} \|\mathcal{G}(X) - Y\|_2^2 + \lambda \mathcal{L}_{\text{KD}}\) - 设计动机:这是一种架构无关的频域对齐——不依赖学生模型的具体结构,直接在频域对齐输出特征。此外使用频谱敏感的自适应加权来缓解学生模型固有的高频过拟合和低频欠拟合偏向。

3. 多教师蒸馏(CAMKD)

  • 功能:利用多个教师模型的互补知识,通过梯度空间的多目标优化来自适应加权。
  • 核心思路:对每个教师 \(\mathcal{F}_m\) 定义蒸馏损失 \(\ell_m\),然后在每个mini-batch用"Agree to Disagree"(A2D)方法求解最优权重 \(\{\alpha_m^*\}\)\(\min_{\{\alpha_m\}} \frac{1}{2}\|\sum_{m=1}^M \alpha_m \nabla_\theta \ell_m\|^2, \quad \text{s.t.} \sum \alpha_m = 1, 0 \leq \alpha_m \leq C\)

学生更新方向:\(d = -\sum_m \alpha_m^* \nabla_\theta \ell_m\)。 - 设计动机:不同教师在不同频段和场景上各有强项。简单平均可能在教师梯度冲突时丢失信息。A2D通过梯度空间优化自动降低冲突教师的权重,充分利用互补性。

损失函数 / 训练策略

  • 预测任务损失 + 频率对齐蒸馏损失,由 \(\lambda\) 平衡。
  • 使用Adam优化器,固定学习率 \(1 \times 10^{-4}\),训练300 epochs,基于验证集早停。
  • 学生隐层维度仅为教师的20%-30%,模型总参数最多缩小5倍。

实验关键数据

主实验——Weatherbench和TaxiBJ+

方法 Weatherbench MAE↓ Weatherbench PSNR↑ Weatherbench SSIM↑ TaxiBJ+ MSE↓
T1: ST-AlterNet (教师) 0.7287 33.32 0.9493 0.0683
T2: SimVP (教师) 0.7475 32.11 0.9331 0.0697
U-Net (学生baseline) 0.9822 29.37 0.8635 0.0831
U-Net + AEKD (MSE Loss) 0.8457 31.69 0.8849 0.0715
U-Net + AVER-MKD (AB Loss) 0.8541 31.45 0.8812 0.0728
ResNet + AEKD (MSE Loss) 0.9102 30.82 0.8801 0.0784
MLP-Mixer + AEKD (MSE Loss) 1.1428 25.32 0.7893 0.0901

蒸馏后U-Net的MAE从0.9822降至0.8457(↓13.9%),PSNR从29.37提升到31.69(↑7.9%),显著缩小与教师的差距。

消融实验——蒸馏方法对比(NS数据集 + U-Net学生)

方法 MSE↓ MAE↓ SSIM↑ 整体
Baseline (S) 0.141 0.239 0.658
AVER-MKD 0.136 0.233 0.670
AEKD 0.135 0.235 0.663
CAMKD(多教师频域对齐) 0.136 0.234 0.667

在RBC数据集上,CAMKD取得MSE=0.0261、MAE=0.112、SSIM=0.728的最优结果,优于所有单教师和简单多教师方法。

关键发现

  • 频谱分析证实理论:教师模型ST-AlterNet在高频和低频区域的预测误差都较低,证明设计达到了频域解耦效果。未蒸馏U-Net在高频区域误差特别大(卷积结构导致的过平滑),蒸馏后高频和低频误差均显著下降。
  • 推理加速:NS数据集上蒸馏后U-Net推理速度是教师的2.28倍(0.0050s vs. 0.0115s)。
  • 物理可解释性:教师隐空间的频谱分布与流体力学中的Kolmogorov能量谱 \(E(\omega) \propto \omega^{-5/3}\) 一致,高频对应Navier-Stokes方程的粘性耗散项,低频对应对流项。

亮点与洞察

  • 从频域视角重新理解时空预测的蒸馏问题——跳出了"简单特征模仿"的传统思路,明确指出高频和低频知识需要分别传递。
  • 理论支撑充分:利用Neural Tangent Kernel理论和Plancherel定理证明了CNN高通/Transformer低通的属性以及频域误差分解的合理性。
  • 架构无关性:学生可以是任意轻量架构(U-Net、ResNet、MLP-Mixer),不要求与教师结构匹配。

局限与展望

  • 频域解耦依赖CNN和Transformer的先验假设,当教师模型使用其他架构(如FNO或LSTM)时适用性不明。
  • 四个benchmark规模偏小(最大35K训练样本),在更大规模真实世界数据上的效果待验证。
  • 蒸馏超参数 \(\alpha\)\(\lambda\) 需要手动调整,缺乏自适应机制。
  • 时空非平稳信号的最优频率划分应随时空位置变化,当前方法使用固定截断频率。

相关工作与启发

  • 与FitNet-R(回归任务蒸馏)和FOLK(频率掩码分类蒸馏)有方法论联系,但SDKD是首个针对时空信号设计的频域解耦蒸馏框架。
  • 与FNO(Fourier Neural Operator)的思路互补:FNO在频域显式建模PDE,SDKD通过CNN/Transformer的隐式频域特性实现更灵活的解耦。
  • 启发:能否将频域解耦蒸馏用于视频预测、点云序列预测等其他时空任务?

评分

  • 新颖性: ⭐⭐⭐⭐
  • 实验充分度: ⭐⭐⭐⭐
  • 写作质量: ⭐⭐⭐
  • 价值: ⭐⭐⭐⭐

相关论文