深度学习赋能表格数据的TabTransformer技术

本文介绍了TabTransformer模型如何将Transformer架构应用于表格数据,通过无监督预训练和微调技术提升分类和回归任务性能,在15个公开数据集上验证了其优越性,并详细解析了模型架构和训练方法。

近年来,深度神经网络已成为大多数顶尖人工智能系统的核心。特别是在自然语言处理领域,基于Transformer的语言模型已成为应用基础。然而在表格数据领域,基于决策树的机器学习方法通常表现更佳。

某中心的研究团队开发了TabTransformer模型,将Transformer架构从自然语言处理扩展至表格数据。该新型深度学习架构支持监督学习和半监督学习,通过Transformer为分类变量生成鲁棒的数据表示,连续变量则通过并行流处理。

TabTransformer借鉴了自然语言处理中的成功方法:先在无标签数据上进行预训练学习通用嵌入方案,然后在有标签数据上微调以学习特定任务。在15个公开数据集上的实验表明,TabTransformer在平均AUC指标上比现有表格数据深度学习方法至少提升1.0%,并与基于树的集成模型性能相当。

在半监督学习场景下,当标记数据稀缺时,TabTransformer展现出更大优势。通过新颖的无监督预训练程序,其平均AUC比最强的深度神经网络基准提高2.1%。该模型学习的情境嵌入对缺失和噪声数据特征具有高度鲁棒性,并提供更好的可解释性。

表格数据处理 TabTransformer以每个样本的特征作为输入,生成输出以最佳近似对应标签。在实际工业场景中,模型可先在无标签样本上预训练,然后在有标签样本上微调。该架构还可针对包含多个目标变量的大型表格数据进行一次性预训练,然后为不同目标变量进行多次微调。

预训练方法 研究探索了两种预训练方法:掩码语言建模和替换标记检测。在掩码语言建模中,随机选择部分特征进行掩码,使用其他特征的嵌入重建被掩码特征。在替换标记检测中,用随机值替换特征而非掩码。

可解释性研究 通过t-SNE可视化技术分析Transformer不同层生成的情境嵌入。实验显示,语义相似的类别在嵌入空间中形成聚类。例如客户相关特征聚集在中心区域,非客户相关特征分布在外围。这种聚类模式在普通多层感知机模型中未能观察到。

性能评估 在监督学习实验中,TabTransformer与最先进的梯度提升决策树模型性能相当,并显著优于之前的深度神经网络模型。

模型名称 平均AUC (%)
TabTransformer 82.8 ± 0.4
多层感知机 81.8 ± 0.4
梯度提升决策树 82.9 ± 0.4

在半监督学习实验中,当未标记数据量较大时,TabTransformer显著优于所有竞争对手。特别是在50、200和500个标记数据点场景下,平均AUC改进分别至少达1.2%、2.0%和2.1%。

目前TabTransformer已通过某中心的SageMaker JumpStart提供,可用于分类和回归任务,并已被集成到Keras官方开源库中。

comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计