跳转至

WARP: 权重空间线性循环神经网络

会议: ICLR 2026
arXiv: 2506.01153
领域: 时间序列
关键词: 权重空间学习, 线性RNN, 自适应预测, 动力系统重建, 无梯度适应

一句话总结

提出 WARP(Weight-space Adaptive Recurrent Prediction),将线性 RNN 的隐状态显式参数化为辅助 MLP 的权重和偏置,利用输入差分驱动线性递推来更新权重,结合非线性解码实现高效序列建模,在分类、预测和动力系统重建等任务上达到 SOTA。

研究背景与动机

深度序列模型面临两大根本性限制:

泛化能力不足:无法在训练分布之外可靠工作,需要梯度下降进行适应

难以注入领域先验:前向传播过程中无法融入物理约束等领域知识

同时,两大新兴范式各有优势但尚未结合:

范式 优势 局限
权重空间学习 将神经网络权重作为数据点处理 仅用于输入/输出,未作为中间表征
线性 RNN (S4, Mamba) 硬件高效、可并行化训练 表达能力受限,信息压缩不足

核心洞察:线性 RNN 缺乏非线性导致表达力不足,但将非线性重新引入又牺牲了训练效率。WARP 通过将隐状态定义为 MLP 权重,在保持线性递推效率的同时引入解码时的非线性。

方法详解

整体框架

WARP 的核心递推关系和解码过程:

\[\theta_t = A\theta_{t-1} + B\Delta\mathbf{x}_t, \quad \mathbf{y}_t = \text{MLP}_{\theta_t}(\tau)\]

其中: - \(\theta_t \in \mathbb{R}^{D_\theta}\) 是辅助 MLP("根网络")的展平权重 - \(\Delta\mathbf{x}_t = \mathbf{x}_t - \mathbf{x}_{t-1}\)输入差分(受大脑信号处理启发) - \(A \in \mathbb{R}^{D_\theta \times D_\theta}\) 是状态转移矩阵 - \(B \in \mathbb{R}^{D_\theta \times D_x}\) 是输入转移矩阵 - \(\tau\) 是坐标系统(归一化像素位置、时间步等)

关键设计 1:自解码机制

\(\theta_t\) 同时扮演隐状态解码器参数两个角色——自己解码自己。这大幅节省了参数量,因为不需要额外的解码器网络。

关键设计 2:输入差分驱动

使用 \(\Delta\mathbf{x}_t\) 而非 \(\mathbf{x}_t\) 驱动递推: - 当输入变化缓慢时,权重更新成比例地减小 - 学习将输入差分转化为网络更新——本质上是无梯度的持续适应

关键设计 3:初始化策略

  • \(A\) 初始化为单位矩阵 \(I\):模拟残差连接,促进梯度流动
  • \(B\) 初始化为零矩阵 \(\mathbf{0}\):确保训练早期 \(\theta_t\) 不发散
  • \(\theta_0 = \phi(\mathbf{x}_0)\):通过超网络 \(\phi\) 从首个观测生成初始权重

训练与推理

训练模式: - 卷积模式:展开线性递推为卷积核 \(K\),实现并行训练 - 循环模式:区分自回归(AR)和非 AR 两种设置

损失函数

\[\mathcal{L}_{\text{MSE}} = \frac{1}{T}\sum_{t=0}^{T-1}\|\mathbf{y}_t - \hat{\mathbf{y}}_t\|_2^2\]

概率预测时使用负对数似然 NLL,分类使用交叉熵 CCE。

物理先验注入(WARP-Phys)

通过替换根网络的前向传播为物理公式(如 \(\tau \mapsto \sin(2\pi\tau + \hat{\varphi})\)),实现领域知识注入,在动力系统重建上性能提升超 10 倍。

实验关键数据

图像补全(MNIST, L=300 上下文像素)

模型 MSE ↓ BPD ↓
GRU 0.054 0.573
LSTM 0.057 0.611
S4 0.049 0.520
WARP 0.042 0.516

交通流预测(PEMS08)

模型 MAE ↓ RMSE ↓
STIDGCN (GNN-SOTA) 13.45 23.28
D2STGNN 14.35 24.18
WARP 6.59 10.10

WARP 在不使用图结构的情况下,MAE 降低超过 50%,大幅超越使用空间信息的 GNN 模型。

动力系统重建

数据集 GRU MSE LSTM MSE Transformer MSE WARP MSE WARP-Phys MSE
MSD 1.43 1.46 0.34 0.94 0.03
MSD-Zero 0.55 0.57 0.48 0.32 0.04
LV 5.83 6.18 11.27 4.72
SINE* 4.90 9.48 1728 2.77 0.62

WARP-Phys 在 MSD 上比 WARP 提升超过 30 倍(0.94 → 0.03)。

多变量时间序列分类(6 个 UEA 数据集)

WARP 在 6 个方法中 4 个数据集进入前三名,包括在 SCP2 和 Heartbeat 上达到 SOTA,在极长序列(EigenWorms, 17984 步)上表现出色。

亮点与洞察

  1. 范式级创新:首次将权重空间特征作为循环网络的中间隐状态表征,统一了权重空间学习和线性递推
  2. 大脑启发的输入差分:不处理绝对输入而处理变化量,天然支持持续学习和测试时适应
  3. 无梯度适应:快变权重 \(\theta_t\) 通过线性递推更新(非梯度下降),实现高效的运行时适应
  4. 物理先验注入的灵活性:可将任意领域知识嵌入根网络前向传播,WARP-Phys 性能提升 10 倍以上
  5. 惊人的 PEMS08 结果:不使用图结构却将 MAE 降低 50%,挑战了 GNN 在交通预测中的主导地位

局限性

  1. 状态转移矩阵 \(A \in \mathbb{R}^{D_\theta \times D_\theta}\) 可能非常大,限制了根网络的规模
  2. 物理先验注入(WARP-Phys)需要已知的领域公式,通用性受限
  3. 输入差分假设等间隔采样,对不规则时间序列的处理未讨论
  4. 分类实验中数据集数量有限(6 个),统计显著性可进一步加强
  5. 与 Mamba、Griffin 等最新线性 RNN 的直接对比不够全面

评分 ⭐⭐⭐⭐⭐

极具创新性的范式级工作。将权重空间学习与线性递推优雅结合,在简洁的框架下实现了强大的表达能力和适应能力。PEMS08 上 50% 的 MAE 降低和 WARP-Phys 的 10x 提升都是令人印象深刻的结果。唯一的顾虑是状态转移矩阵的规模问题。

相关论文