跳转至

Empowering Decision Trees via Shape Function Branching

会议: NeurIPS 2025
arXiv: 2510.19040
代码: 未提供
领域: 可解释机器学习 / 决策树
关键词: 决策树, 形状函数, 可解释性, 表格数据, 非线性分裂

一句话总结

提出 Shape Generalized Tree (SGT),在决策树每个内部节点使用可学习的轴对齐形状函数替代传统线性阈值分裂,以更紧凑的树结构捕捉非线性特征效应,同时保持可解释性。

背景与动机

  • 传统的轴对齐线性决策树每个节点只能进行简单的 \(x_d \leq \theta\) 分裂,面对非线性特征-目标关系时需要在同一特征上反复分裂,导致树变深变大
  • 树的可理解性对路径深度和叶节点数量高度敏感,深树直接损害可解释性
  • 广义加法模型(GAMs)通过非线性形状函数建模单特征贡献,可解释性良好但缺乏决策树的层级结构
  • 斜切树(Oblique trees)允许多特征线性组合分裂,但高维切割难以由人类理解
  • 需要一种方法:在单个节点内捕获非线性决策边界,同时保持轴对齐的可解释性

核心问题

如何在不牺牲可解释性的前提下,增强决策树单个节点的表达能力,使其能在更浅/更紧凑的树中达到更好的预测性能。

方法详解

Shape Generalized Tree (SGT) 定义

传统线性树在每个节点进行分裂:

\[\mathcal{D}^l = \{(\mathbf{x}_n, y_n) \in \mathcal{D} \mid \mathbf{w}^\top \mathbf{x}_n \leq \theta\}\]

SGT 将阈值分裂替换为形状函数分裂:

\[\mathcal{D}^l = \{(\mathbf{x}_n, y_n) \in \mathcal{D} \mid f_\Theta(\mathbf{w}^\top \mathbf{x}_n) \leq 0.0\}\]

其中 \(\mathbf{w}\) 为 one-hot 特征选择向量,\(f_\Theta: \text{dom}(\mathcal{X}_d) \to \mathbb{R}\) 为可学习的形状函数。由于 \(f\) 仅作用于单一特征,仍可直接可视化。

扩展变体

  • S2GT(二元形状函数树):每个节点允许对两个特征的联合形状函数进行分裂,\(f^2_\Theta(x_{d_1}, x_{d_2}) \leq 0\),可通过热力图可视化
  • SGT\(^K\)(多路分裂树):将二分裂推广到 \(K\) 路分裂,使用向量值形状函数 \(f^{(K)}: \text{dom}(\mathcal{X}_d) \to \mathbb{R}^K\),通过 argmax 确定分支。实验限制 \(K=3\)

表达能力保证

定理1:SGT 至少与相同节点数的线性树一样有表达力(线性树是SGT的特例)

定理2:对任意 \(B \in \mathbb{N}\),存在某些函数需要线性树比SGT多至少 \(B\) 个决策节点

ShapeCART 算法

采用类CART的自顶向下贪心构造,在每个节点处求解双层优化:

\[\min_{\mathbf{w}} \sum_{d=1}^{D} w_d \min_{\Theta_d} \left[\mathcal{L}(\{\mathcal{D}^l_d, \mathcal{D}^r_d\})\right]\]

其中加权不纯度定义为:

\[\mathcal{L}(\mathbf{D}) = \sum_{\mathcal{D} \in \mathbf{D}} |\mathcal{D}| \cdot \mathcal{H}(\Pi(\mathcal{D}))\]

形状函数学习分两阶段

  1. 分箱(Binning):用内部CART树将样本按单特征值映射到 \(L\) 个互斥箱,每个箱存储经验类分布 \(\pi_\ell\) 和权重 \(W_\ell\)
  2. 箱到分支映射(Bin-to-Branch):通过坐标下降求解离散优化,将每个箱分配到左/右(或 \(K\) 路)分支,最小化加权不纯度:
\[\min_{\mathbf{a}} \sum_{k} W_k \cdot \mathcal{H}(\Pi_k)\]

初始化策略:取加权K-Means聚类和内部CART树根节点分配二者中更优者。

二元候选对筛选

为避免 \(O(D^2)\) 的特征对枚举,利用单变量形状函数的分支集笛卡尔积来快速估计交互增益:

\[\delta_{(d_1, d_2)} = \min(\mathcal{L}(\mathbf{D}_{d_1}), \mathcal{L}(\mathbf{D}_{d_2})) - \mathcal{L}(\{\mathcal{D}^i \cap \mathcal{D}^j\})\]

仅保留 \(\delta\) 值最高的 \(P\) 个候选对,并添加正则化惩罚 \(\gamma\) 抑制不必要的二元分裂。

后处理

使用 Tree Alternating Optimization (TAO) 对贪心构造的树进行全局优化,逐节点重拟合形状函数并剪枝。

实验关键数据

在 26 个真实分类数据集上评估,深度 2-6:

轴对齐方法平均测试准确率(%)

方法 Depth 2 Depth 3 Depth 4 Depth 5 Depth 6 Best
CART 83.9 81.6 82.3 84.0 85.1 85.2
SERDT 83.7 80.9 81.7 83.7 85.1 85.1
SGT-C 84.5 82.5 83.7 84.9 86.2 86.3
SGT3-C 85.7 84.6 85.9 87.3 87.8 87.8
AxTAO 83.9 82.1 82.8 84.5 85.5 85.4
DPDT 84.9 83.4 83.8 85.2 86.2 86.2
SGT-T 85.0 83.5 84.6 85.9 86.8 86.8
SGT3-T 86.4 85.1 86.2 87.5 88.0 88.0

二元方法平均测试准确率(%)

方法 Depth 2 Depth 3 Depth 4 Depth 5 Depth 6 Best
BiCART 87.3 87.6 87.9 88.7 89.6 89.6
BiTAO 87.9 88.1 88.4 89.3 90.1 90.0
S2GT-C 89.3 89.1 89.5 90.5 91.3 91.4
S2GT3-T 90.2 90.6 91.2 91.7 92.4 91.9

关键发现: - SGT-C 深度2的性能常匹配或超过 CART 深度6的性能(如 eye-movements、electricity 数据集) - S2GT-C 深度2与 BiCART/BiTAO 深度6 性能相当 - 三路分裂(SGT3)在所有深度上都优于二路分裂

亮点

  • ⭐ 形状函数分裂将 GAMs 的非线性建模能力引入决策树节点,同时保持可视化和可解释性
  • ⭐ 两阶段形状函数学习(先CART分箱再坐标下降优化分支分配)高效且有理论保证(信息增益下界 ≥ CART)
  • ⭐ 二元候选对筛选启发式大幅降低计算开销,复杂度从 \(O(D^2 \cdot NC\log N)\) 降至 \(O(P \cdot NC\log N)\)
  • 在浅深度下的性能优势特别显著,直接提升了实际部署中的可解释性

局限性 / 可改进方向

  • 作为树模型,主要面向表格数据,对图像/文本等非结构化数据适用性有限
  • 形状函数的理解比简单阈值分裂需要更多认知负担
  • 缺少人类受试者研究来系统评估SGT的实际可理解性
  • 后处理依赖TAO全局优化,增加了训练复杂度

与相关工作的对比

模型类别 代表方法 节点表达力 可解释性 性能
轴对齐线性树 CART 单特征阈值 最高 基线
斜切树 TAO-Oblique 多特征线性组合 低(高维切割) 较高
二元斜切树 BiCART 两特征线性组合 中(可可视化)
SGT(本文) ShapeCART 单特征非线性函数 高(形状函数图) 最高

启发与关联

  • 形状函数 + 决策树 = 更好的可解释表达力,这一思路可推广到其他树类模型(如随机森林、梯度提升树)
  • 坐标下降在离散优化中的应用具有参考价值:先连续放松(K-Means初始化)再离散优化(坐标下降)
  • 可以尝试用神经网络参数化形状函数(类似NAMs),可能在更复杂特征关系上进一步提升

评分

  • ⭐ 新颖性: 8/10 — 将形状函数引入决策树节点的思路自然且创新,理论保证完备
  • ⭐ 实验充分度: 8/10 — 26个数据集、多深度、多变体(SGT/S2GT/SGT\(^K\))、与强基线全面比较
  • ⭐ 写作质量: 8/10 — 定义清晰、算法伪代码完整、可视化示例直观
  • ⭐ 价值: 8/10 — 对可解释AI在高风险领域的应用有直接推动作用