近年来,深度神经网络已成为大多数顶级人工智能系统的核心。特别是在自然语言处理(NLP)领域,基于Transformer的语言模型(如BERT)已成为应用基础。然而,在表格数据领域,基于决策树的机器学习方法通常表现更优。
TabTransformer是一种新颖的深度表格数据建模架构,适用于监督和半监督学习。该模型通过Transformer为分类变量(如月份)生成鲁棒的数据表示(嵌入),连续变量(如数值)则通过并行流处理。
借鉴NLP领域的成功方法,模型先在未标记数据上进行预训练以学习通用嵌入方案,然后在标记数据上微调以学习特定任务。在15个公开数据集上的实验表明,TabTransformer在平均AUC(接收者操作特征曲线下面积)上比最先进的深度学习方法至少提升1.0%,且与基于树的集成模型性能相当。
在半监督场景下,当标记数据稀缺时,TabTransformer通过新颖的无监督预训练程序,比最强的DNN基准平均提升2.1%的AUC。此外,该模型学习的上下文嵌入对缺失和噪声数据特征具有高度鲁棒性,并提供更好的可解释性。
表格数据处理
表格数据中,行代表样本,列包含特征(预测变量)和标签(目标变量)。TabTransformer以样本特征为输入,生成输出以近似对应标签。在半监督场景中,模型可先在无标签样本上预训练,再在带标签样本上微调。对于包含多个目标变量的大规模表格,可一次性预训练后多次微调。
架构与预训练
TabTransformer架构包含分类变量的Transformer处理流和连续变量的并行处理流。探索两种预训练方法:
- 掩码语言建模(MLM):随机掩码部分特征,用其他特征嵌入重构被掩码特征
- 替换标记检测(RTD):用随机值替换特征
可解释性研究
通过t-SNE可视化Transformer不同层的嵌入,发现语义相似的类在嵌入空间中形成聚类(例如客户相关特征聚集在中心,非客户特征分布在外围)。移除Transformer层后该模式消失。
实验结果
- 监督学习:TabTransformer(82.8% AUC)与梯度提升决策树(82.9%)相当,显著优于TabNet(77.1%)和Deep VIB(80.5%)
- 半监督学习(6个大数据集):TabTransformer-RTD/MLM在50/200/500标记数据场景下分别至少提升1.2%/2.0%/2.1%平均AUC
- 半监督学习(9个小数据集):TabTransformer-RTD仍优于多数竞争对手
技术实现
可通过某中心SageMaker JumpStart的UI或Python SDK访问该模型,并已集成至Keras官方库。