跳转至

Revisiting Bi-Linear State Transitions in Recurrent Neural Networks

会议: NeurIPS 2025 arXiv: 2505.21749 代码: 无 领域: 序列建模 / 循环神经网络 关键词: Bilinear RNN, State Tracking, Finite Automata, Multiplicative Interaction, Parity

一句话总结

系统性地重新审视 RNN 中的双线性状态转移(隐状态与输入的乘法交互),理论证明双线性 RNN 可模拟任意有限状态机,并展示其在去除加性项后形成了一个从对角到全结构的自然表达力层次,揭示了 Mamba 等流行线性 RNN 处于该层次最低端。

研究背景与动机

  1. 领域现状: 状态跟踪是序列决策中的基本需求(多轮对话、机器人控制、Agent-LLM),可形式化为有限自动机/正则语言的模拟。然而,许多流行序列模型(Transformer、Mamba、mLSTM)在超出训练长度时无法学习状态跟踪任务。

  2. 现有痛点:

  3. Transformer 的失败:无法在超出训练长度上学习状态跟踪,即使是大规模预训练模型和 CoT 推理也不行
  4. 线性 RNN 的受限:Mamba 等对角线性 RNN 不能学习任意长度的状态跟踪。已知条件是状态转移矩阵需要输入依赖且允许负特征值,但可学任务仍然高度受限
  5. 双线性模型的历史遗忘:虽然早期有研究(Sutskever 2011),但因三路乘法交互导致的不稳定性和优化困难而未被广泛采用

  6. 核心矛盾: 状态跟踪本质上要求隐状态参与"计算"(而不仅是"记忆"),这需要输入来"路由"隐状态的信息流。现有线性 RNN 的加性结构天然不适合此目标。

  7. 本文要解决什么: 双线性 RNN(乘法交互,无加性项)的状态跟踪能力如何?不同参数化形成怎样的表达力层次?

  8. 切入角度: 从纯乘法交互(无偏置、无加性输入项)出发,利用尺度不变性解决训练稳定性问题。

  9. 核心 idea 一句话: 去除所有加性项的纯双线性 RNN 既能模拟任意有限状态机,又因尺度不变性而允许运行时归一化保持稳定,形成了一个清晰的表达力层次。

方法详解

整体框架

核心递推公式(纯双线性,无加性项):

\[h_i^t = (h^{t-1})^\top \mathcal{W}_i x^t = \sum_{jk} \mathcal{W}_{ijk} x_k^t h_j^{t-1}\]

等价于输入依赖的状态转移矩阵:\(h^t = \mathcal{A}_x h^{t-1}\),其中 \((\mathcal{A}_x)_{ij} = \sum_k \mathcal{W}_{ijk} x_k\)

关键洞察:没有任何加性项意味着隐状态具有尺度不变性——可以在任何时步乘以常数并在后续除以该常数而不影响最终结果。

关键设计

1. 双线性模型的表达力层次

从上到下表达力递减:

模型 可模拟任务 参数量
全双线性 任意有限状态机 \(H^2 D\)
CP 分解双线性 \(R\) 增加逐步逼近 \(R(2H+D)\)
块对角双线性 块大小 \(\geq\) 状态数 \(H'^2 D \cdot B\)
\(\mathcal{R}_2\) 块对角 仅阿贝尔群(交换运算)
实数对角 仅奇偶校验 \(HD\)
正数对角(Mamba) 无状态跟踪能力

2. 全双线性 RNN 可模拟任意 FSM(Proposition 1)

证明思路:用 one-hot 编码状态,输入依赖的状态转移矩阵可编码任意转移函数 \(\delta(q, \sigma)\)

3. CP 分解降低参数量

\[\mathcal{A}_x = \mathcal{W}^{(h_1)} \text{diag}((\mathcal{W}^{(x)})^\top x) (\mathcal{W}^{(h_2)})^\top\]

参数量从 \(H^2 D\) 降至 \(R(2H+D)\)\(R\) 为分解秩。实验表明增加 \(R\) 可逐步恢复全模型的能力。

4. 复数对角双线性 RNN 仅限于阿贝尔群(Proposition 2)

\(\mathcal{P}_x = \mathcal{P}\)(共享特征基)时,所有 \(\mathcal{A}_x\) 都可交换 \(\Rightarrow\) 只能模拟交换群操作。这是一个根本性的负面结果:即使复数对角(\(2 \times 2\) 旋转块)也无法学习一般状态机。

5. 实数对角可以"平凡地"学习奇偶校验(Proposition 3)

冻结随机权重 + 仅训练线性读出层 \(\Rightarrow\) 从仅 2 个训练样本学会任意长度的奇偶校验,成功概率 \(1 - 2^{-H}\)。这完全不需要训练递推参数。

训练策略

  • 尺度不变性利用:训练时不归一化(保持梯度流),推理时归一化隐状态(防止溢出),两者不矛盾
  • 数据:长度 2-10 训练,500 测试
  • 100,000 步训练,每步 64 样本
  • 3 个学习率选最佳

实验关键数据

主实验:Modular Addition(长度泛化,OOD 长度 500)

模型 m=2 m=3 m=5 m=10 m=25 m=50
Bilinear (full) 1.00 1.00 1.00 1.00 1.00 1.00
Block-diag (size=1, 实对角) 1.00 0.00 0.00 0.10 0.00 0.02
Block-diag (size=2) 1.00 1.00 1.00 1.00 1.00 1.00
\(\mathcal{R}_2\) Block-diag 1.00 0.00 1.00 0.66 0.37 0.00
LSTM 1.00 1.00 0.98 1.00 0.00 0.02
Mamba (1/2/4层) 0.00 0.01 0.01 0.00 0.00 0.00
Transformer (1/2/4层) 0.03 0.01 0.00 0.00 0.00 0.00

Random State Machine(OOD 长度 500)

模型 m=2 m=5 m=10 m=25
Bilinear (full) 1.00 1.00 1.00 1.00
Factored Bilinear 1.00 1.00 1.00 0.50
Block-diag (size=2) 1.00 0.00 0.00 0.00
LSTM 1.00 1.00 0.00 0.00
Mamba (4层) 0.00 0.00 0.00 0.00

消融实验:无加性项 vs 有加性项

模型 Parity OOD (len 500)
纯双线性(无加性项) 1.00
含加性项(标准形式) 显著下降

关键发现

  1. 双线性模型在所有任务上一致最佳:全双线性在 modular addition、random state machine、modular arithmetic 上均达到 1.00 的 OOD 准确率
  2. 块大小为 2 是 modular addition 的充分条件:但不足以学习一般状态机
  3. Mamba 在所有 OOD 测试上完全失败:即使 4 层,modular addition 的 OOD 准确率为 0
  4. 加性项是有害的:去除偏置和加性输入贡献是学习状态跟踪的关键
  5. 冻结权重学奇偶校验:无需训练递推参数,仅 2 个样本+线性读出即可

亮点与洞察

  1. 层次结构的清晰揭示:从 full bilinear 到 factored 到 block-diagonal 到 real-diagonal 到 positive-diagonal (Mamba),形成了一条清晰的表达力退化链
  2. 负面结果的重要性\(2 \times 2\) 复数对角块只能学习阿贝尔群——这意味着不可能通过简单的"允许复数特征值"来解决一般状态跟踪
  3. "输入路由"视角:将隐状态视为"计算参与者"而非"被动记忆",乘法交互让输入决定隐状态如何变换
  4. 尺度不变性的巧妙利用:训练和推理的不一致性通过纯乘法结构完美解决——推理时归一化不影响输出

局限性/可改进方向

  1. 参数量较大:全双线性模型的三阶张量 \(\mathcal{W} \in \mathbb{R}^{H \times H \times D}\) 参数量为 \(H^2 D\),CP 分解虽然缓解但可能损失表达力
  2. 不能高效并行训练:三路乘法交互使得标准的并行扫描算法不可直接应用
  3. 仅在 toy tasks 上验证:缺乏语言建模等大规模实验
  4. 与实际 LLM 训练的差距:隐藏维度 256、训练长度 2-10 远小于实际应用规模
  5. 分解模型在大状态空间下退化:factored bilinear 在 \(m=50\) 时准确率降至 0.95

相关工作与启发

  • Mamba/GLA/mLSTM: 本文揭示其处于双线性层次的最低端,无状态跟踪能力
  • DeltaNet/DeltaProduct: 通过 Householder 结构增加表达力,正交但互补的路径
  • Grazzi et al. (ICLR 2025): 负特征值的重要性,本文从双线性视角提供了更完整的理论
  • Observable Operator Models: 双线性 RNN 可视为其连续放松版本
  • 对线性 RNN 研究的启示: 表达力的核心瓶颈不在特征值范围,而在乘法 vs 加法交互的本质差异

评分

⭐⭐⭐⭐

理论分析深入透彻,层次结构的揭示极有洞察力。但缺乏大规模实验和高效并行化方案,距离实际应用有一定距离。作为分析性工作非常出色。