PyTorch表格深度学习框架解析

本文深入解析基于PyTorch的表格数据处理框架PyTorch Tabular,涵盖其核心设计原则、模型架构、配置方法及实战应用,帮助读者快速掌握表格数据的深度学习解决方案。

PyTorch Tabular – 面向表格数据的深度学习框架

概述

众所周知,在处理表格数据时,梯度提升模型(Gradient Boosting)往往优于其他机器学习模型。深度学习在文本和图像等其他模态中展现出的惊人效果,尚未在表格数据中得到充分验证。但近期,深度学习革命已开始关注表格领域,涌现出许多专为表格数据设计的新架构和模型,其中部分模型性能甚至可与精心调优的梯度提升模型媲美。

什么是PyTorch Tabular?

PyTorch Tabular是一个框架/封装库,旨在简化表格数据的深度学习应用,使其更易于实际场景和研究使用。该库的设计核心原则包括:

  • 低阻力可用性
  • 轻松定制
  • 可扩展且易于部署

该框架基于PyTorch和PyTorch Lightning等成熟技术构建,无需从零开始开发。它集成了最先进的深度学习模型,可直接通过pandas数据框进行训练。

核心特性

  • 高级配置驱动API:支持快速使用和迭代。仅需提供pandas数据框,库自动处理归一化、标准化、分类特征编码及数据加载器准备等繁重工作
  • 可扩展模型基类:通过BaseModel抽象类可轻松实现自定义模型,同时充分利用库内置工具链
  • 先进网络集成:已实现Neural Oblivious Decision Ensembles(神经 oblivious 决策集成)和TabNet(注意力可解释表格学习)等前沿架构
  • 训练灵活性:通过PyTorch Lightning继承训练过程的灵活性和可扩展性

安装指南

推荐先安装适配CUDA版本的PyTorch(版本>1.3),随后执行:

1
2
pip install pytorch_tabular[all]  # 完整安装(含权重和偏差实验跟踪)
pip install pytorch_tabular        # 基础安装

也可通过GitHub源码编译安装。

配置详解

需配置四大核心组件(多数参数含智能默认值):

  1. DataConfig:定义目标列名、分类/数值列名及数据转换规则
  2. ModelConfig:模型专属配置,确定训练模型及超参数设置
  3. TrainerConfig:训练过程配置,包括批大小、训练轮次、早停等参数(继承自PyTorch Lightning)
  4. OptimizerConfig:优化器及学习率调度器配置,支持标准PyTorch组件
  5. ExperimentConfig(可选):实验跟踪设置,当前支持Tensorboard和Weights&Biases

配置示例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
data_config = DataConfig(
    target=['target'],  # 目标列(列表形式,回归支持多目标)
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True,  # 自动学习率查找
    batch_size=1024,
    max_epochs=100,
    gpus=1,  # GPU索引(0表示CPU)
)
model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # 层节点数
    activation="LeakyReLU",  # 层间激活函数
    learning_rate=1e-3
)

模型训练与评估

初始化模型后调用fit方法:

1
2
3
4
5
6
7
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)

使用evaluate方法评估测试数据:

1
result = tabular_model.evaluate(test)

预测结果通过predict方法获取(分类任务返回概率值和0.5阈值最终预测):

1
pred_df = tabular_model.predict(test)

模型持久化

支持模型保存与加载:

1
2
3
tabular_model.save_model("examples/basic")
loaded_model = TabularModel.load_from_checkpoint("examples/basic")
result = loaded_model.evaluate(test)

支持模型列表

  1. 类别嵌入前馈网络:集成分类列嵌入层的简单FF网络(类似fastai表格模型)
  2. 神经 oblivious 决策集成(ICLR 2020):在多数据集上击败调优梯度提升模型
  3. TabNet(谷歌研究):采用稀疏注意力机制的多步决策模型

自定义模型实现参考官方教程文档。

资源信息

  • 代码仓库:GitHub开源项目
  • 文档教程:ReadTheDocs在线文档
  • 贡献指南:欢迎社区贡献,详情参见项目文档

相关对比

与同基于PyTorch的fastai相比,PyTorch Tabular通过模块化解耦设计及标准组件使用,显著提升代码可 hacking 性和新模型集成便利性。

参考文献

[1] Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data. arXiv:1909.06312 (2019)
[2] TabNet: Attentive Interpretable Tabular Learning. arXiv:1908.07442 (2019)

后续将针对框架内已实现模型发布专题技术博客,敬请关注。

comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计