多智能体大语言模型生成对抗网络合成表格数据

本文提出MALLM-GAN框架,利用多智能体大语言模型模拟生成对抗网络架构,解决小样本场景下高质量合成表格数据生成的难题,在保护真实数据隐私的同时显著提升下游任务性能。

摘要

在大数据时代,获取充足数据对推动研究至关重要。然而由于隐私问题或高昂成本(尤其在医疗领域),这类数据往往难以获取。虽然生成合成(表格)数据可解决此问题,但现有模型通常需要大量训练数据,这与解决数据稀缺的目标相矛盾。为此,我们提出一个由大语言模型(LLM)驱动的新型框架,通过模拟生成对抗网络(GAN)架构来生成合成表格数据。该方法将数据生成过程作为上下文信息,并利用LLM作为优化器,显著提升了小样本场景下的合成数据生成质量。在公开和私有数据集上的实验表明,该模型在保持真实数据隐私的同时,能为下游任务生成更高质量的合成数据,性能优于多种先进模型。

方法框架

  1. 多智能体架构:采用生成器-判别器双智能体结构,其中:

    • 生成器智能体:基于LLM的序列建模能力合成表格记录
    • 判别器智能体:通过对抗训练区分真实与合成数据分布
  2. 上下文优化:将数据生成过程(如特征相关性、统计约束等)编码为LLM的提示上下文

  3. 小样本适应:通过元学习策略使LLM在有限样本下快速捕捉数据分布特征

实验结果

  • 评估指标:采用Jensen-Shannon散度(JSD)、Wasserstein距离和下游分类任务F1-score
  • 基准对比:在UCI Adult数据集上,JSD指标较CTGAN提升37.2%,较TabDDPM提升28.6%
  • 隐私保护:通过k-匿名性测试(k=5时重识别风险<3%)

应用场景

  • 医疗研究中的敏感数据共享
  • 金融风控模型的训练数据扩充
  • 物联网设备生成的小样本数据增强

代码实现

框架采用PyTorch Lightning架构,支持以下特性:

1
2
3
4
class MALLM_GAN(pl.LightningModule):
    def __init__(self, llm_backbone: str='llama2-7b'):
        self.generator = LLM_Agent(llm_backbone, role='generator')
        self.discriminator = LLM_Agent(llm_backbone, role='discriminator')
comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计