PyTorch Tabular – 面向表格数据的深度学习框架
概述
众所周知,在处理表格数据时,梯度提升模型(Gradient Boosting)往往优于其他机器学习模型。深度学习在文本和图像等其他模态中展现出的惊人效果,尚未在表格数据中得到充分验证。但近期,深度学习革命已开始关注表格领域,涌现出许多专为表格数据设计的新架构和模型,其中部分模型性能甚至可与精心调优的梯度提升模型媲美。
什么是PyTorch Tabular?
PyTorch Tabular是一个框架/封装库,旨在简化表格数据的深度学习应用,使其更易于实际场景和研究使用。该库的设计核心原则包括:
- 低阻力可用性
- 轻松定制
- 可扩展且易于部署
该框架基于PyTorch和PyTorch Lightning等成熟技术构建,无需从零开始开发。它集成了最先进的深度学习模型,可直接通过pandas数据框进行训练。
核心特性
- 高级配置驱动API:支持快速使用和迭代。仅需提供pandas数据框,库自动处理归一化、标准化、分类特征编码及数据加载器准备等繁重工作
- 可扩展模型基类:通过BaseModel抽象类可轻松实现自定义模型,同时充分利用库内置工具链
- 先进网络集成:已实现Neural Oblivious Decision Ensembles(神经 oblivious 决策集成)和TabNet(注意力可解释表格学习)等前沿架构
- 训练灵活性:通过PyTorch Lightning继承训练过程的灵活性和可扩展性
安装指南
推荐先安装适配CUDA版本的PyTorch(版本>1.3),随后执行:
|
|
也可通过GitHub源码编译安装。
配置详解
需配置四大核心组件(多数参数含智能默认值):
- DataConfig:定义目标列名、分类/数值列名及数据转换规则
- ModelConfig:模型专属配置,确定训练模型及超参数设置
- TrainerConfig:训练过程配置,包括批大小、训练轮次、早停等参数(继承自PyTorch Lightning)
- OptimizerConfig:优化器及学习率调度器配置,支持标准PyTorch组件
- ExperimentConfig(可选):实验跟踪设置,当前支持Tensorboard和Weights&Biases
配置示例:
|
|
模型训练与评估
初始化模型后调用fit方法:
|
|
使用evaluate方法评估测试数据:
|
|
预测结果通过predict方法获取(分类任务返回概率值和0.5阈值最终预测):
|
|
模型持久化
支持模型保存与加载:
|
|
支持模型列表
- 类别嵌入前馈网络:集成分类列嵌入层的简单FF网络(类似fastai表格模型)
- 神经 oblivious 决策集成(ICLR 2020):在多数据集上击败调优梯度提升模型
- TabNet(谷歌研究):采用稀疏注意力机制的多步决策模型
自定义模型实现参考官方教程文档。
资源信息
- 代码仓库:GitHub开源项目
- 文档教程:ReadTheDocs在线文档
- 贡献指南:欢迎社区贡献,详情参见项目文档
相关对比
与同基于PyTorch的fastai相比,PyTorch Tabular通过模块化解耦设计及标准组件使用,显著提升代码可 hacking 性和新模型集成便利性。
参考文献
[1] Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data. arXiv:1909.06312 (2019)
[2] TabNet: Attentive Interpretable Tabular Learning. arXiv:1908.07442 (2019)
后续将针对框架内已实现模型发布专题技术博客,敬请关注。