跳转至

AdaGen: Learning Adaptive Policy for Image Synthesis

日期: 2026-03-07
arXiv: 2603.06993
代码: LeapLabTHU/AdaGen
领域: 图像生成
关键词: adaptive generation policy, reinforcement learning, adversarial reward, iterative generative models, sample-adaptive scheduling

一句话总结

提出 AdaGen 框架,将多步生成模型中的调度参数配置(如 noise level、mask ratio、guidance scale 等)建模为 MDP,通过强化学习训练轻量级策略网络实现逐样本自适应生成策略,并设计对抗式奖励模型防止奖励过拟合,在四种生成范式、五个数据集上取得显著的性能和效率提升。

研究背景与动机

现代图像生成模型(MaskGIT、自回归模型、扩散模型、Rectified Flow)的共同原则是将复杂生成任务分解为多个可控步骤。然而,这带来了一个核心挑战:大量步骤相关参数(generation policy)需要被精心配置。例如每一步的 mask ratio、noise level、temperature、guidance scale 等。以 MaskGIT 为例,仅 T=32 步就需要配置 128 个策略参数。

现有方法依赖手动设计的调度规则(如 cosine schedule、线性衰减等),存在三个关键问题:

  1. 需要专家知识:手动设计 schedule 依赖大量试错和领域经验
  2. 静态策略:所有样本共享同一套调度函数,无法适应不同样本的个体特征
  3. 次优性能:简化的 schedule 无法捕捉生成步骤间的精细需求

AdaGen 的核心洞察是:将策略设计从人工艺术转变为数据驱动的优化问题,引入可学习的策略网络,根据当前生成状态自适应地为每个样本确定最优生成策略。

方法详解

整体框架

AdaGen 将生成策略的确定建模为马尔可夫决策过程(MDP),引入轻量级策略网络作为 RL agent,通过 PPO 算法训练。整体流程:

  1. 状态空间 \(\mathcal{S}\):由当前生成步骤 \(t\) 和中间生成结果(如部分 mask 的 token、部分去噪的图像等)组成
  2. 动作空间 \(\mathcal{A}\):对应各范式的生成策略参数(mask ratio、temperature、guidance scale、timestep 等)
  3. 状态转移 \(P\):由冻结的预训练生成模型决定(对扩散/flow 模型是确定性 ODE,对 MaskGIT/AR 是随机采样)
  4. 奖励 \(R\):仅在最后一步 \(t=T\) 给出终端奖励,评估最终生成图像的质量

策略网络以生成模型的中间特征为输入(而非原始中间结果),架构为 Conv + MLP + AdaLN(引入步骤信息),输出端使用 softplus 等激活函数约束策略参数范围。训练时使用高斯噪声增加探索性(\(\pi_\phi(a_t|s_t) = \mathcal{N}(\eta_\phi(s_t), \sigma I)\)),推理时直接使用均值。

关键设计

1. 对抗式奖励建模(Adversarial Reward Modeling)

这是论文最重要的贡献之一。作者发现两种直觉式的奖励设计都会导致"奖励过拟合":

  • 统计指标(如 FID)作为奖励:FID 是批量统计量,难以提供样本级信号;且 FID 优化成功不代表视觉质量好(实验中 FID=2.56 但图像质量差)
  • 预训练奖励模型(如 ImageReward):虽提供样本级信号,但生成图像趋向单一风格,多样性严重下降

解决方案:引入判别器式的对抗奖励模型 \(r_\psi\),与策略网络进行 minimax 博弈。策略网络最大化奖励的同时,奖励模型不断更新以更好地区分真实和生成图像,从而动态抵抗过拟合,实现保真度和多样性的良好平衡。

2. 动作平滑(Action Smoothing)

当生成步数增大时(如 T=8→T=32),动作空间扩展 4 倍,训练变得不稳定。原因分析:

  • Gaussian 噪声在每步独立添加,产生高频震荡的探索轨迹
  • 这种非结构化探索对渐进式生成过程不合理(生成过程通常从高不确定性逐渐收敛)

解决方案:对策略网络的原始输出序列施加指数移动平均(EMA)滤波器

\[a_t = \beta \cdot a_{t-1} + (1 - \beta) \cdot \tilde{a}_t\]

\(\beta=0.8\) 时效果最佳。该方法满足因果性(不依赖未来信息)和低通滤波性(抑制高频振荡),有效稳定了训练。

3. 推理时精炼(Inference-time Refinement)

训练中的辅助网络在推理时可被复用:

  • 重复采样:利用对抗奖励模型 \(r_\psi\) 对多次生成结果打分,选择最高分的输出
  • 前瞻采样(Lookahead Sampling):对随机转移的模型(如 MaskGIT),在每步采样 \(K\) 个候选状态,用价值网络 \(V_\phi\) 选择期望奖励最高的状态继续生成

4. 可控的保真度-多样性权衡

引入第二个"保真度导向"策略网络,通过用户可调参数 \(\lambda\) 线性插值两个策略的输出:

\[a_t^{\text{blend}} = (1-\lambda) \cdot a_t + \lambda \cdot a_t'\]

同时奖励信号也做相应混合,建立 \(\lambda\) 与保真度-多样性谱的显式映射。

实验关键数据

主实验

表1:ImageNet 256×256 不同范式 FID-50K 结果

模型 范式 T=4 T=8 T=10 T=16
MaskGIT-S Baseline Mask. 7.65 5.01 - 4.88
AdaGen-MaskGIT-S Mask. 4.54 (-3.11) 3.71 (-1.30) - 3.36 (-1.52)
MaskGIT-L Baseline Mask. 6.91 4.65 - 3.79
AdaGen-MaskGIT-L Mask. 3.63 (-3.28) 2.86 (-1.79) - 2.41 (-1.38)
DiT-XL Baseline Diff. 9.71 5.18 - 3.31
AdaGen-DiT-XL Diff. 5.31 (-4.40) 2.82 (-2.36) - 2.19 (-1.12)
SiT-XL Baseline Flow 9.33 4.90 - 2.99
AdaGen-SiT-XL Flow 4.25 (-5.08) 2.72 (-2.18) - 2.12 (-0.87)
VAR-d16 Baseline AR - - 3.30 -
AdaGen-VAR-d16 AR - - 2.62 (-0.68) -
VAR-d30 Baseline AR - - 1.92 -
AdaGen-VAR-d30 AR - - 1.59 (-0.33) -

表2:系统级对比(ImageNet 256×256 精选)

方法 范式 Params Steps TFLOPs FID-50K
DiT-XL Diff. 675M 250 59.6 2.27
DiT-XL (DPM-Solver) Diff. 675M 50 12.2 2.29
AdaGen-DiT-XL Diff. 688M 16 4.1 2.19
SiT-XL Flow 675M 50 12.2 2.18
AdaGen-SiT-XL Flow 688M 16 4.1 2.12
VAR-d30 AR 2.0B 10 2.0 1.92
AdaGen-VAR-d30 AR 2.0B 10 2.0 1.59
MDT Diff. 676M 250 59.6 1.79
MAGVIT-v2 Mask. 307M 64 - 1.78

消融实验

表3:核心组件消融(MaskGIT-S / DiT-XL, T=4)

Learnable? Adaptive? MaskGIT-S FID DiT-XL FID
7.65 9.71
5.40 (-2.25) 6.03 (-3.68)
4.54 (-3.11) 5.31 (-4.40)

表4:动作平滑权重 \(\beta\) 消融(MaskGIT-S, T=16)

\(\beta\) 0 0.2 0.4 0.8 0.95
FID-50K 3.97 3.61 3.47 3.36 3.70

表5:推理时精炼效果(MaskGIT-L T=32 / DiT-XL T=16)

方法 MaskGIT-L FID DiT-XL FID
无精炼 2.28 2.19
+ 重复采样 2.07 2.06
+ 重复 + 前瞻 1.94 -

关键发现

  1. 步数越少增益越大:DiT-XL 在 T=4 时 FID 提升 4.40,T=16 时提升 1.12,说明 AdaGen 在计算受限场景优势最为显著
  2. 等效约 3× 推理加速:AdaGen-DiT-XL 16 步(4.1 TFLOPs)超越原版 50 步(12.2 TFLOPs)的性能
  3. 计算开销极小:策略网络仅增加 0.07%–0.40% 的推理计算量
  4. 可学习性 > 自适应性:从 baseline 到 learnable 的提升(30-38%)远大于从 learnable 到 adaptive 的提升(12-16%),说明手工 schedule 确实严重次优
  5. 对抗奖励优于静态奖励:FID 和 ImageReward 作为奖励都会导致过拟合(质量差或多样性低),对抗奖励实现了保真度和多样性的平衡
  6. 推理时精炼有效:MaskGIT-L 从 2.28 降至 1.81(3 次重复+前瞻),进一步挖掘了辅助网络的价值

亮点与洞察

  1. 统一性极强:一个框架覆盖 MaskGIT、DiT、SiT、VAR、Stable Diffusion 五种模型和四种范式,MDP 建模简洁优雅
  2. 奖励过拟合的发现与解决:深刻揭示了直接优化 FID 或预训练奖励模型的缺陷,对抗奖励方案既优雅又有效
  3. EMA 动作平滑:从信号处理角度解决 RL 探索不稳定问题,思路巧妙且实现极简
  4. 辅助网络复用:训练时的判别器和价值网络在推理时作为质量评估器,无需额外开销
  5. 不修改生成模型参数:策略网络作为外挂模块,冻结原始生成模型,避免了 fine-tuning 带来的灾难性遗忘风险
  6. VAR FID 1.92→1.59:对已经很强的 SOTA 模型仍有显著提升,说明手工 schedule 在各个层级都存在优化空间

局限性 / 可改进方向

  1. 训练成本:需要 PPO + 判别器联合训练,对每个新的生成模型需要独立训练策略网络
  2. 仅限终端奖励:中间步骤无奖励信号,可能导致长 horizon MDP 中信用分配困难
  3. Text-to-Image 规模受限:大规模实验仅到 Stable Diffusion 1.x(1.4B),未在 SDXL、FLUX 等更大模型上验证
  4. 动作空间离散化:对 top-k 等离散策略参数的处理不够自然(论文中用连续近似 + 离散化)
  5. 对抗训练稳定性:虽然用了 PPO,但对抗奖励的 minimax 训练仍可能面临模式崩塌风险
  6. 缺少与搜索方法的公平对比:如进化搜索、数值梯度方法等自动调度方法,在相同计算预算下的对比不够充分

相关工作与启发

  • 与 RLHF 的关系:RLHF 直接 fine-tune 生成模型参数,AdaGen 冻结模型只学策略,更轻量且通用
  • 与 GAN 训练的联系:借鉴 GAN 判别器思想设计对抗奖励,但不用于训练生成器本身
  • 与 inference-time scaling 的关系:推理时精炼策略(重复采样 + 前瞻)是一种 test-time compute 的利用方式,与当前 LLM 的 inference-time scaling 趋势呼应
  • 启发:这种"外挂策略网络优化调度"的思路可以推广到视频生成、3D 生成等更复杂的多步生成场景

评分

维度 分数 (1-10) 说明
新颖性 8 MDP 建模 scheduling 问题 + 对抗奖励设计巧妙
技术深度 8 RL + GAN + 信号处理多个方向融合,分析透彻
实验充分度 9 4 种范式、5 个数据集、6 个模型、大量消融
实用性 8 即插即用、计算开销极小、性能提升显著
写作质量 8 结构清晰、表格丰富、motivation 有说服力
综合 8.2 统一且通用的框架,实验扎实,是自适应生成策略方向的代表性工作