混合先验增强表格基础模型技术解析

本文介绍Mitra表格基础模型技术,通过混合合成先验分布提升模型泛化能力。该模型基于二维注意力机制,支持上下文学习,在多项基准测试中表现优异,无需真实数据预训练即可处理分类和回归任务。

Mitra:通过混合合成先验增强表格基础模型

生成多样化的合成先验分布可产生优于特定任务基线的表格基础模型。

表格数据在医疗、金融、电子商务和科学等领域支撑着关键决策。然而,传统用于表格数据的机器学习方法(如随机森林和XGBoost)通常会产生针对单个数据集定制的模型,跨分布迁移能力有限。

受大语言模型成功的启发,表格基础模型有望改变这一现状:无需为每个任务单独训练模型,单个预训练模型只需通过适量示例进行条件化(即上下文学习)即可泛化到新任务。

作为某中心自动机器学习框架AutoGluon最新版本的一部分,推出了基于上下文学习范式训练的表格基础模型Mitra。与大语言模型在多样化文本语料上训练类似,Mitra通过精心设计的先验分布混合生成的合成数据集进行预训练。

初看可能令人惊讶的是,预训练Mitra未使用任何真实数据。但真实表格数据通常有限且异构,具有不同的特征类型、依赖关系和噪声水平。实践证明,模拟覆盖广泛数据模式的多样化合成数据集更为实用。

研究发现,这些合成先验的质量对模型泛化能力起关键作用。有效的先验往往具备以下特点:(1)在真实任务上表现良好;(2)具有多样性以防过拟合;(3)提供其他先验未包含的独特模式。

基于这些原则,构建了包含结构因果模型(结合变量间因果依赖图和描述变量值变化影响的概率方程)与流行树型方法(如梯度提升、随机森林和决策树)的混合先验。这些先验共同使Mitra能够学习鲁棒表示,并有效泛化到各种真实表格问题。

框架概述

在包括结构因果模型和树型模型的合成数据先验混合上预训练表格基础模型。每个数据集分为支持集和查询集。Mitra支持跨行和列的二维注意力以及逐行的一维注意力。推理时,模型通过上下文学习基于真实数据集的支持示例预测查询标签,无需梯度更新。

在选定的先验混合上预训练Mitra。每个合成任务包含支持集和查询集。模型通过学习关注支持集来预测查询集标签,无需梯度更新。经过数百万个此类任务,Mitra学会可泛化的推理和适应模式。该架构基于跨行和特征的二维注意力,可灵活处理不同表格大小和特征交互。

评估结果

在分类和回归任务上评估Mitra,涵盖TabRepo、TabZilla、AMLB和TabArena等主要表格基准。与TabPFNv2、TabICL等强表格基础模型以及CatBoost、RealMLP和AutoGluon 1.3最佳质量预设等数据集特定模型相比,Mitra表现出最先进的性能。

在二维正弦棋盘数据上,Mitra比TabPFNv2显示出更规则、更少碎片化的决策边界。

正如基础模型重塑计算机视觉和自然语言处理领域,Mitra为表格数据预测提供了更通用有效的方法。随着领域发展,预计将出现更丰富的先验空间和自适应混合策略。Mitra已在AutoGluon 1.4版本中开源,可供直接使用。

comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计