DiffusionBlocks: Block-wise Neural Network Training via Diffusion Interpretation¶
会议: ICLR2026
arXiv: 2506.14202
代码: SakanaAI/DiffusionBlocks
领域: image_restoration
关键词: block-wise training, diffusion models, score matching, memory efficiency, residual networks
一句话总结¶
提出 DiffusionBlocks,将残差网络的逐层更新解释为连续时间扩散过程的离散化步骤,从而将网络切分为可完全独立训练的 block,在保持端到端训练性能的同时按 block 数 B 倍减少训练显存。
背景与动机¶
- 端到端反向传播需要存储所有层的中间激活值,显存随网络深度线性增长,严重制约模型规模和实际部署
- 已有的 block-wise 训练方法(如 Forward-Forward、greedy layer-wise training)依赖临时性的局部目标函数,缺乏理论保证,且基本仅在分类任务上验证,无法自然扩展到生成任务
- Score-based diffusion models 的去噪目标天然具有"各噪声级别可独立优化"的性质——这恰好为 block-wise 训练提供了缺失的理论基础
- 残差连接(ResNet、Transformer 等)的更新规则 \(\mathbf{z}_{\ell+1} = \mathbf{z}_\ell + f_{\theta_\ell}(\mathbf{z}_\ell)\) 可对应扩散过程 probability flow ODE 的 Euler 离散化
核心问题¶
如何为基于 Transformer 的网络设计一种有理论根据的 block-wise 训练框架,使得:
- 每个 block 可以完全独立训练(不需要其他 block 的梯度或激活值)
- 与端到端训练保持竞争力
- 能通用于分类和生成等多种任务/架构
方法详解¶
核心洞察:残差连接 = 扩散过程离散步¶
在 Variance Exploding (VE) 扩散框架下,给定噪声级别 \(\sigma_0 > \sigma_1 > \cdots > \sigma_T\),对 probability flow ODE 做 Euler 离散化得到:
这与残差网络的 skip connection 更新规则 \(\mathbf{z}_\ell = \mathbf{z}_{\ell-1} + f_{\theta_\ell}(\mathbf{z}_{\ell-1})\) 天然对应。
三步转换流程¶
Step 1: Block 划分 — 将 \(L\) 层网络分为 \(B\) 个 block \(\mathcal{F}_1, \ldots, \mathcal{F}_B\),每个 block 包含连续的若干层。
Step 2: 噪声范围分配 — 定义噪声分布 \(p_{\text{noise}}\)(推荐 log-normal),将 \([\sigma_{\min}, \sigma_{\max}]\) 划分为 \(B\) 个区间 \(\{[\sigma_b, \sigma_{b-1}]\}_{b=1}^B\),每个 block 负责对应范围的去噪。
Step 3: 噪声条件化改造 — 扩展每个 block 的输入为 \(\tilde{\mathbf{x}} = (\mathbf{x}, \mathbf{z}_\sigma)\),其中 \(\mathbf{z}_\sigma = \mathbf{y} + \sigma\epsilon\);加入噪声级别条件(如 AdaLN)。每个 block 独立训练预测目标 \(\mathbf{y}\)。
独立训练目标¶
每个 block \(b\) 的损失函数为:
关键在于:\(B\) 个 block 各自独立优化、无需相互通信,却能共同覆盖完整的噪声分布。
等概率划分策略 (Equi-probability Partitioning)¶
不采用均匀划分噪声区间(会在高/低噪声端浪费容量),而是按 log-normal 分布的累积概率质量等分:
这确保每个 block 处理等量的训练分布,在去噪难度最大的中间噪声级别分配更细的区间,效率更优。
推理过程¶
推理时按从高噪声到低噪声的顺序依次调用各 block 的去噪步骤;对于 diffusion model,每个去噪步只需加载一个 block,带来 \(B\) 倍推理加速。
实验关键数据¶
| 任务 / 架构 | 数据集 | 端到端基线 | DiffusionBlocks | Block 数 / 显存缩减 |
|---|---|---|---|---|
| ViT 分类 | CIFAR-100 | 60.25% Acc | 59.30% Acc | B=3 / 3× |
| DiT 图像生成 | CIFAR-10 | 32.84 FID | 30.59 FID | B=3 / 3× |
| DiT 图像生成 | ImageNet 256 | 12.09 FID | 10.63 FID | B=3 / 3× |
| Masked Diffusion 文本 | text8 | 1.56 BPC | 1.45 BPC | B=3 / 3× |
| AR Transformer 文本 | LM1B | 0.50 MAUVE | 0.71 MAUVE | B=4 / 4× |
| AR Transformer 文本 | OpenWebText | 0.85 MAUVE | 0.82 MAUVE | B=4 / 4× |
| Huginn (recurrent-depth) | LM1B | 0.49 MAUVE | 0.70 MAUVE | 消除 32 次迭代 |
- Forward-Forward 在 CIFAR-100 上仅达 7.85% 准确率,远逊于 DiffusionBlocks
- ImageNet 上 B=2 时 FID=9.90,优于端到端训练 (12.09),适度划分反而提升性能
- 等概率划分在所有层分配方案下均显著优于均匀划分(CIFAR-10 FID: 38.03 vs 43.53)
亮点¶
- 理论基础扎实:从 score matching 的噪声级别独立性出发,自然推导出 block 独立训练目标,非启发式拼凑
- 通用性极强:一套三步转换流程适用于 ViT、DiT、AR Transformer、Masked Diffusion、Recurrent-depth 共五类架构
- 等概率划分是简洁而关键的设计——让每个 block 承担等量去噪难度,无需手工调整层分配
- 多重效率收益:训练 \(B\) 倍显存缩减;diffusion model 推理 \(B\) 倍加速;recurrent-depth 模型省去 BPTT
- 部分场景超越端到端:ImageNet B=2/3 的 FID 优于不分 block 的端到端训练,说明适度专业化有正收益
局限性 / 可改进方向¶
- 实验中 ViT 分类仅在 CIFAR-100 上验证(60.25→59.30),大规模 ImageNet 分类未测试
- 推理时仍需按序调用各 block,无法并行化推理步骤
- 噪声条件化改造(AdaLN 等)增加了少量参数和工程复杂度
- B 过大时性能下降(B=6 时 FID 14.43),block 粒度有下限
- 主要面向 Transformer 类残差架构,对无残差连接的网络适用性未讨论
与相关工作的对比¶
| 方法 | 理论基础 | 任务通用性 | 连续时间 | Block 独立 |
|---|---|---|---|---|
| Forward-Forward | 对比目标 | 仅分类 | ✗ | ✓ |
| NoProp | 扩散相关 | 仅分类 | ✓(CT) 或 ✗(DT) | ✗(CT) 或 ✓(DT) |
| DiffusionBlocks | Score matching | 分类+生成 | ✓ | ✓ |
- NoProp 与自定义 CNN 架构捆绑,无法直接迁移到 Transformer;DiffusionBlocks 在 NoProp 的架构上也优于其所有变体(46.88 vs 46.06/21.31/37.57)
- 与 stage-specific diffusion models (eDiff-I 等) 的区别在于:后者是联合训练或从共享参数微调,DiffusionBlocks 各 block 完全隔离
启发与关联¶
- "残差连接 ≈ 扩散离散化步骤"的视角可进一步推广:任何具有残差结构的深层模型都可能受益于这种分块独立训练
- 等概率划分思想可迁移到其他需要"分段处理不同难度子任务"的场景(如课程学习、多尺度训练)
- 对 recurrent-depth 模型消除 BPTT 的能力值得关注——随着 universal transformer / Huginn 等模型兴起,该方法可降低其训练成本
- 结合模型并行(每个 block 放不同 GPU),可实现更激进的深度扩展
评分¶
- 新颖性: ⭐⭐⭐⭐⭐ — 将扩散独立性引入 block-wise 训练是原创性极高的理论贡献
- 实验充分度: ⭐⭐⭐⭐ — 五类架构覆盖面广,但分类任务规模偏小
- 写作质量: ⭐⭐⭐⭐⭐ — 数学推导清晰,三步流程直观易懂
- 价值: ⭐⭐⭐⭐ — 为大模型训练显存瓶颈提供了有理论保证的新范式