联邦学习中的持续学习
联邦学习是一种分布式设备利用本地收集的数据共同训练全局机器学习模型的技术,无需传输原始数据,既减少网络流量又保护数据隐私。持续学习则是指模型随着新数据不断更新,关键要避免"灾难性遗忘"——即新数据更新覆盖原有参数导致旧数据性能下降。
在自然语言处理顶级会议EMNLP 2023上发表的研究中,提出了一种改进的持续联邦学习方法。为防止灾难性遗忘,设备需保留已见数据的样本,新数据与旧样本合并后重新训练模型。
核心方法:梯度多样性样本选择
非协调策略
每个设备本地选择样本,通过优化梯度多样性实现:将梯度视为多维空间中的方向,选择使梯度总和趋近零的样本(即方向各异的梯度)。这转化为NP完全问题——从N个梯度中选择使和趋近零的组合。研究采用松弛方法,允许系数为分数但仍保持总和为N,最终选择系数最高的N个样本。
实验表明,当设备可存储50个以上样本时,非协调方法性能最佳;20个样本时则需要更精细的选择策略。
协调策略
中央服务器协调各设备样本选择:设备先本地优化梯度求和趋近零,将聚合梯度(非原始梯度以防逆向攻击)及系数发送至服务器;服务器调整各设备求和目标使其全局和趋近零,再将新目标返回设备。通常一次迭代即可使全局和接近零,设备最终选择系数最高的N个样本。
当存储容量为20个样本时,协调方法成为最佳选择。
实验验证
相比三种基线方法(均匀采样和两种加权采样),新方法在N=50和100时显著优于基线,N=20时协调版本最优。当N≤10时传统方法更具优势,但实际设备通常能存储更多样本。
研究为不同存储容量的设备提供了样本选择策略优化指南,推动联邦学习与持续学习的有效结合。
X轴表示每设备保留样本数,Y轴表示困惑度相对变化(概率分布预测效果指标)