联邦学习中的持续学习技术突破
技术背景
联邦学习是一种分布式设备在无需传输本地数据的情况下共同训练全局机器学习模型的技术,既能减少网络流量又能保护数据隐私。持续学习是指随着新数据不断更新模型的过程,关键在于避免"灾难性遗忘"——即新数据更新覆盖现有设置导致旧数据性能下降的问题。
创新方法
在某中心于自然语言处理实证方法会议(EMNLP)发表的论文中,提出了一种结合两种技术的新方法,通过改进的持续联邦学习机制提升性能。
样本选择策略
核心创新在于数据样本保留选择程序,提供两种变体:
- 非协调方法:各设备本地自主选择样本
- 协调方法:通过中心服务器协调设备间样本选择
梯度多样性优化
通过最大化梯度多样性确保样本信息多样性。梯度作为多维空间中的方向,选择梯度总和接近零的样本可实现最大多样性。该问题可表述为在存储预算N内分配系数(1或0),使梯度总和尽可能接近零。
由于这是NP完全问题,提出松弛要求:系数总和仍为N,但允许系数为分数值,转化为计算可处理问题。最终选择系数最高的N个样本。
实验结果
在对比三种先前采样策略的实验中:
- 当N=50和100时,两种方法均显著优于基线,非协调方法稍优
- 当N=20时,协调方法成为最佳性能方案
- 当N≤10时,其他方法开始超越
协调方法详解
协调方法在本地和全局梯度求和间交替进行:
- 各设备本地优化使梯度总和接近零
- 向中心服务器发送聚合梯度及计算系数(保护数据隐私)
- 服务器计算最小修改量使全局总和为零
- 将修改后的总和作为新目标返回设备
实验表明,单次迭代即可实现接近零的全局总和。最终各设备选择系数最大的N个样本。
实际应用价值
分布式设备通常能存储超过5-10个样本,本文为根据不同设备容量优化样本选择策略提供了实用指南。
图:x轴表示每设备保留样本数,y轴表示困惑度相对变化(概率分布预测效果指标)