推理时利用教师知识增强学生模型

本文介绍了一种名为检索增强知识蒸馏(ReAugKD)的新方法,通过在推理阶段利用教师模型的预测知识库来提升学生模型性能。该方法在六项自然语言处理任务中五项达到最优表现,平均仅带来3%的延迟开销,实现了知识蒸馏领域的新突破。

方法概述

检索增强知识蒸馏(ReAugKD)是一种通过利用教师模型在历史输入上生成的数据表示(嵌入)和预测结果来指导学生模型推理的框架。这些历史数据可存储在查询表中,用于为相似输入提供预测指导。该方法可适配任何特定任务的外部知识。

训练方法

ReAugKD采用两阶段训练流程:

  1. 投影层微调:在针对下游任务微调后的教师模型编码器顶部添加线性投影层,将嵌入向量投影至与学生模型编码器相同维度。使用监督对比损失微调参数,将相同标签的训练样本作为正样本,与批次中随机采样的负样本进行对比。

  2. 相似度矩阵构建:为训练数据生成教师嵌入向量和预测结果,创建教师嵌入间的相似度矩阵。通过最小化教师-教师相似度分布与教师-学生相似度分布之间的KL散度,确保推理时两者使用相同的相似度度量标准。损失函数同时包含计算学生预测与教师预测间交叉熵的项。

实验成果

在GLUE基准的六个数据集上,将12层BERT-Base模型蒸馏至6层BERT模型:

  • 在五项任务中达到最先进性能
  • 较之前最佳KD方法平均提升0.42%
  • 在两个基准任务上分别提升1.37%和1.43%
  • 检索版本较无检索版本提升0.45%,验证了检索增强的有效性

技术特点

  • 推理阶段聚合历史相似示例的教师预测,与学生预测结合输出最终结果
  • 仅带来3%的延迟开销
  • 适用于 paraphrasing、自然语言推理、问答等多种NLP任务

图示说明:在释义数据集示例中,通过检索教师对相似示例的知识显著改善了学生模型的预测精度,最终预测结合了学生预测分数和教师聚合预测分数。

comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计