方法概述
知识蒸馏(KD)是在低延迟需求环境中部署大规模语言模型的最有效方法之一,其通过将大规模教师模型的知识转移到小型学生模型来实现。学生模型虽因体积小而更高效,但性能通常弱于教师模型。
在计算语言学协会(ACL)会议上提出的检索增强知识蒸馏(ReAugKD)框架,通过利用教师模型的能力来提升学生模型性能,且仅产生最小延迟开销。该方法使用教师模型对历史输入生成的数据表示(嵌入)和预测(可存储于查询表中),来指导学生模型对相似输入的预测。该方案原则上可适配任何任务特定的外部知识。
训练方法
ReAugKD采用两阶段训练流程:
-
第一阶段:基于针对下游任务微调后的教师模型,在其编码器顶部添加线性投影层,将编码器嵌入(输入数据的向量表示)投影至与学生模型编码器相同的维度。使用监督对比损失微调线性投影层参数,将相同标签的训练样本作为正样本,与批次中随机采样的负样本进行对比。
-
第二阶段:为训练学生模型的输入数据生成(调整尺寸后的)教师嵌入和教师预测,创建教师嵌入的相似度矩阵(衡量各输入嵌入间的相似性)。训练学生模型时,创建学生嵌入与教师嵌入的相似度矩阵,并通过损失函数最小化教师-教师相似度分布与教师-学生相似度分布间的KL散度。这确保推理时在知识库中搜索与学生当前输入相似的教师嵌入时,双方使用相同的相似性度量标准。损失函数还包含通过交叉熵损失计算学生预测与教师预测间差异的项。
实验与结果
在测试中,使用ReAugKD将12层BERT-Base模型蒸馏为6层BERT模型,并在GLUE基准的六个数据集上评估性能。该方法在五个数据集上达到最优结果,较之前最佳KD方法平均提升0.42%,在两个基准任务上分别提升1.37%和1.43%。采用知识库检索的ReAugKD版本较无检索版本提升0.45%,验证了检索增强的有效性。
推理阶段
在ReAugKD的推理阶段,聚合教师模型对与当前样本最相似历史样本的预测,并将其与学生预测相结合,形成最终输出。
该方法在保持低延迟的同时显著提升模型性能,为资源受限环境下的模型部署提供了有效解决方案。