深度学习赋能表格数据的TabTransformer技术

本文详细介绍TabTransformer架构如何将Transformer技术应用于表格数据,通过对比实验证明其在监督学习和半监督学习任务中优于传统深度学习和树模型,并提供嵌入可视化和预训练方法的技术细节。

近年来,深度神经网络已成为大多数顶级人工智能系统的核心。特别是在自然语言处理(NLP)领域,基于Transformer的语言模型(如BERT)已成为应用基础。然而,在表格数据领域,基于决策树的机器学习方法通常表现更优。

TabTransformer是一种新颖的深度表格数据建模架构,适用于监督和半监督学习。该模型通过Transformer为分类变量(如月份)生成鲁棒的数据表示(嵌入),连续变量(如数值)则通过并行流处理。

借鉴NLP领域的成功方法,模型先在未标记数据上进行预训练以学习通用嵌入方案,然后在标记数据上微调以学习特定任务。在15个公开数据集上的实验表明,TabTransformer在平均AUC(接收者操作特征曲线下面积)上比最先进的深度学习方法至少提升1.0%,且与基于树的集成模型性能相当。

在半监督场景下,当标记数据稀缺时,TabTransformer通过新颖的无监督预训练程序,比最强的DNN基准平均提升2.1%的AUC。此外,该模型学习的上下文嵌入对缺失和噪声数据特征具有高度鲁棒性,并提供更好的可解释性。

表格数据处理
表格数据中,行代表样本,列包含特征(预测变量)和标签(目标变量)。TabTransformer以样本特征为输入,生成输出以近似对应标签。在半监督场景中,模型可先在无标签样本上预训练,再在带标签样本上微调。对于包含多个目标变量的大规模表格,可一次性预训练后多次微调。

架构与预训练
TabTransformer架构包含分类变量的Transformer处理流和连续变量的并行处理流。探索两种预训练方法:

  • 掩码语言建模(MLM):随机掩码部分特征,用其他特征嵌入重构被掩码特征
  • 替换标记检测(RTD):用随机值替换特征

可解释性研究
通过t-SNE可视化Transformer不同层的嵌入,发现语义相似的类在嵌入空间中形成聚类(例如客户相关特征聚集在中心,非客户特征分布在外围)。移除Transformer层后该模式消失。

实验结果

  • 监督学习:TabTransformer(82.8% AUC)与梯度提升决策树(82.9%)相当,显著优于TabNet(77.1%)和Deep VIB(80.5%)
  • 半监督学习(6个大数据集):TabTransformer-RTD/MLM在50/200/500标记数据场景下分别至少提升1.2%/2.0%/2.1%平均AUC
  • 半监督学习(9个小数据集):TabTransformer-RTD仍优于多数竞争对手

技术实现
可通过某中心SageMaker JumpStart的UI或Python SDK访问该模型,并已集成至Keras官方库。

comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计