Diffusion Transformers as Open-World Spatiotemporal Foundation Models¶
会议: NEURIPS2025
arXiv: 2411.12164
代码: tsinghua-fib-lab/UrbanDiT
领域: time_series
关键词: Diffusion Transformer, Spatiotemporal Foundation Model, Urban Computing, Prompt Learning, Zero-shot
一句话总结¶
提出 UrbanDiT,首个基于 Diffusion Transformer 的开放世界城市时空基础模型,通过统一的 prompt learning 框架整合异构数据类型(grid/graph)和多种任务(预测/插值/外推/填补),在多城市多场景下实现 SOTA 性能并展现强大的 zero-shot 泛化能力。
背景与动机¶
城市环境中的时空动态来源于多样化的人类活动,体现为不同类型的数据:grid-based 数据(如人群流量)和 graph-based 数据(如道路网络交通速度)。现有的城市时空模型存在明显局限:
- UniST:仅支持 grid-based 数据,只做预测任务
- UrbanGPT:基于 LLM,仅支持 grid-based 数据,任务单一
- GPD:仅处理 graph-based 交通数据,不支持多数据源和 zero-shot
- CityGPT:基于 LLM 处理语言形式的城市数据,不支持多数据源和 zero-shot
核心痛点在于:缺乏一个能同时处理异构数据类型、支持多种时空任务、并在开放世界场景下泛化的统一基础模型。Diffusion Transformer(如 Sora)结合了扩散模型的生成能力和 Transformer 的可扩展性,为解决这一问题提供了理想的骨架。
核心问题¶
能否构建一个类似 NLP/CV 领域的基础模型,学习通用时空模式,成为城市时空应用的通用模型?具体需解决三个子问题:
- 如何统一 grid-based 和 graph-based 异构数据格式?
- 如何用单一模型支持预测、插值、外推、填补等多种任务?
- 如何在未见过的城市/数据集上实现 zero-shot 泛化?
方法详解¶
数据统一:异构时空数据转序列¶
将三维结构(2D 空间 + 1D 时间)的数据统一为一维序列格式:
- 时间维度:采用 temporal patching(类似 PatchTST),将时间序列切分为 patch
- Grid-based 数据:使用 2D 空间 patching(类似 ViT),将 \(H \times W \times T\) 重组为序列 \(L = \frac{H \times W \times T}{p_s \times p_s \times p_t}\)
- Graph-based 数据:使用 GCN 处理每个节点,再与时间维度整合为一维序列
任务统一:基于 Mask 的多任务框架¶
将所有任务统一为「重建被 mask 部分」的范式,通过不同 mask 策略实现:
| 任务 | Mask 策略 |
|---|---|
| Forward Prediction | mask 未来时间步 |
| Backward Prediction | mask 过去时间步 |
| Temporal Interpolation | mask 序列中特定时间点 |
| Spatial Extrapolation | mask 观测区域外的空间位置 |
| Spatio-Temporal Imputation | 随机 mask 时空位置 |
输入表示为:\(X^t = X^t \ast (1-M) + X^0 \ast M\),其中 \(M\) 控制不同任务的条件信息。
Spatio-Temporal Transformer Block¶
模型由多个时空 Transformer block 堆叠而成,每个 block 包含独立的 temporal attention 和 spatial attention。分离设计避免了联合注意力的二次复杂度问题。
Unified Prompt Learning(核心创新)¶
Data-Driven Prompt:使用三组 memory pool(key-value store),分别捕获:
- 时域模式:\((K_t, V_t)\),基于输入的时域特征检索
- 频域模式:\((K_f, V_f)\),基于输入的频域特征检索
- 空间模式:\((K_s, V_s)\),基于输入的空间特征检索
通过 softmax 注意力从 memory pool 中检索最相关的 prompt,类似 RAG 的思想,帮助模型区分不同数据集的分布差异。
Task-Specific Prompt:从 mask map 生成任务特定的 prompt \(P_m = \text{Attention}(\text{Flatten}(M))\),让模型感知当前执行的任务类型。
所有 prompt 与输入序列拼接后送入 Transformer,利用 Transformer 处理变长序列的能力实现灵活输入。
训练策略¶
- 每次迭代随机选择一个数据集和一个任务进行梯度下降
- 采用 InstaFlow 的 rectified flow 训练方法(ODE 框架,直线化噪声-数据轨迹),提升生成效率
- 训练时 diffusion steps 设为 500,推理时仅需 20 步(25 倍加速)
实验关键数据¶
数据集¶
涵盖 5 个城市(纽约、北京、上海、南京等)的多领域数据集:出租车需求、蜂窝网络流量、人群流动、交通、动态人口。数据集按 6:2:2 时间划分。
Forward Prediction(Grid-based,表 2)¶
| 模型 | TaxiBJ MAE | FlowSH MAE | TaxiNYC MAE | CrowdNJ MAE | PopBJ MAE |
|---|---|---|---|---|---|
| UniST | 14.04 | 9.10 | 5.85 | 0.119 | 0.106 |
| CSDI | 14.76 | 8.77 | 5.05 | 0.094 | 0.078 |
| UrbanDiT | 12.61 | 5.61 | 5.58 | 0.092 | 0.077 |
整体相对提升 11.3%,反向预测相对提升 30.4%。
Spatial Extrapolation(表 3,50% 空间位置被 mask)¶
UrbanDiT 大幅超越所有基线,例如 TaxiBJ 上 MAE 从 CSDI 的 36.66 降至 8.10。
Zero-shot 性能¶
在 PopSH 数据集上,UrbanDiT 的 zero-shot 性能超越了几乎所有使用训练数据的基线模型,验证了开放世界泛化能力。
消融实验¶
去除任一 prompt 均显著降低性能;去除所有 prompt 后性能最差;频域 prompt 影响最大。
可扩展性¶
UrbanDiT-L 在数据量从 0.8 增加到 1.0 时展现了最陡的性能提升斜率(0.011 vs 0.0015/0.0019),表明大模型在更多数据下仍有提升空间。
亮点¶
- 真正的统一模型:首次同时支持 grid/graph 数据 + 5 种时空任务 + 多城市多领域数据
- Prompt Learning 设计精巧:数据驱动 prompt(三域 memory pool)+ 任务特定 prompt 的组合,灵活且有效
- 强大的 zero-shot 能力:无需任何目标数据训练即超越大多数有监督基线
- 高效推理:rectified flow + 20 步推理实现 25 倍加速,兼顾质量与效率
- 可扩展性良好:大模型随数据增长持续受益,符合 foundation model 的 scaling law
局限性 / 可改进方向¶
- 数据覆盖有限:目前仅关注人类活动数据(出行、交通),未覆盖环境变量(空气污染、气候指标、微气候动态)
- 计算开销:Diffusion Transformer 的训练和推理成本仍较高,尤其在大规模城市数据上
- Prompt 设计依赖先验:三组 memory pool 的结构(时域/频域/空间)是人工设计的,可以探索自动化 prompt 发现
- 空间外推的上限:50% mask 率下的空间外推虽然大幅领先但极端稀疏场景未验证
- 仅支持数值型数据:未考虑文本、图像等多模态城市数据
与相关工作的对比¶
| 维度 | UniST | UrbanGPT | GPD | UrbanDiT |
|---|---|---|---|---|
| 模型初始化 | 从头训练 | LLM | 从头训练 | 从头训练 |
| 数据类型 | Grid | Grid | Graph | Grid + Graph |
| 多数据源 | ✓ | ✓ | ✗ | ✓ |
| 任务灵活性 | ✗ | ✗ | ✗ | ✓ |
| Zero-shot | ✓ | ✓ | ✗ | ✓ |
UrbanDiT 是唯一在所有五个维度上都满足的模型。与 CSDI(第二名基线)相比,UrbanDiT 不仅在 CSDI 擅长的扩散生成任务上全面超越,还支持 CSDI 无法处理的多数据类型和多任务。
启发与关联¶
- Prompt Learning 的通用价值:数据驱动 prompt 通过 memory pool 检索的方式,可迁移到其他需要统一异构数据的基础模型场景
- Mask-based 任务统一:将预测/插值/外推/填补统一为 mask 重建的思路,类似 MAE 在视觉中的做法,值得在更多领域推广
- Rectified Flow 加速:ODE 直线化轨迹的训练策略可显著降低扩散模型推理步数,适合计算受限的实际部署
- 城市计算 × Foundation Model:随着城市数据的持续积累,领域专用基础模型的价值将越来越大
评分¶
- 新颖性: ⭐⭐⭐⭐ — 首次将 DiT 扩展到城市时空统一建模,prompt learning 框架设计有创意
- 实验充分度: ⭐⭐⭐⭐⭐ — 5 个城市多领域数据集、5 种任务、few-shot/zero-shot、消融、可扩展性实验全面
- 写作质量: ⭐⭐⭐⭐ — 结构清晰,图表丰富,方法阐述完整
- 价值: ⭐⭐⭐⭐ — 为城市时空基础模型树立了新 benchmark,开源代码和数据集有利于后续研究