简化基于BERT的模型以提高效率和容量
新方法使基于BERT的自然语言处理模型能够处理更长的文本字符串,在资源受限的环境中运行——有时甚至能同时实现这两个目标。
近年来,自然语言处理(NLP)领域许多性能最佳的模型都建立在BERT语言模型之上。BERT模型在大量(未标记)公共文本语料库上进行预训练,编码单词序列的概率。由于BERT模型始于对语言整体的广泛知识,它可以用相对较少的标记数据针对更具体的任务(如问答或机器翻译)进行微调。
然而,BERT模型非常庞大,基于BERT的NLP模型可能运行缓慢——对于计算资源有限的用户来说甚至慢到无法接受。它们的复杂性也限制了可处理输入的长度,因为其内存占用随输入长度的平方而缩放。
在今年计算语言学协会(ACL)会议上,同事和作者提出了一种称为Pyramid-BERT的新方法,减少了基于BERT模型的训练时间、推理时间和内存占用,而精度损失很小。减少的内存占用还使BERT模型能够处理更长的文本序列。
基于BERT的模型将句子序列作为输入,并输出整体句子及其组成单词的向量表示(嵌入)。然而,下游应用(如文本分类和排名)仅使用完整句子嵌入。为了使基于BERT的模型更高效,作者逐步消除网络中间层中冗余的单个单词嵌入,同时尽量减少对完整句子嵌入的影响。
将Pyramid-BERT与几种最先进的BERT模型效率提升技术进行比较,结果显示可以在精度仅下降1.5%的情况下将推理速度提高3到3.5倍,而在相同速度下,现有最佳方法精度损失2.5%。
此外,当将该方法应用于Performers(专为长文本设计的BERT模型变体)时,可以将模型的内存占用减少70%,同时实际提高精度。在该压缩率下,现有最佳方法精度下降4%。
令牌的处理过程
输入到BERT模型的每个句子被分解为称为令牌的单元。大多数令牌是单词,但有些是多词短语,有些是子词部分,有些是首字母缩略词的单个字母等。每个句子的开头由一个特殊令牌 demarcated(原因很快就会清楚)CLS(分类)标记。
每个令牌通过一系列编码器(通常介于4到12个之间),每个编码器为每个输入令牌生成新的嵌入。每个编码器都有一个注意力机制,决定每个令牌的嵌入应反映其他令牌携带信息的多少。
例如,给定句子“Bob told his brother that he was starting to get on his nerves”,注意力机制在编码单词“his”时应更多关注单词“Bob”,但在编码单词“he”时应更多关注“brother”。正是因为注意力机制必须将输入序列中的每个单词与其他每个单词进行比较,BERT模型的内存占用随输入长度的平方而缩放。
当令牌通过一系列编码器时,它们的嵌入考虑了序列中其他令牌的越来越多信息,因为它们关注的其他令牌也在考虑越来越多信息。当令牌通过最终编码器时,CLS令牌的嵌入最终代表整个句子(因此CLS令牌得名)。但它的嵌入也与句子中所有其他令牌的嵌入非常相似。这就是作者试图消除的冗余。
基本思想是,在网络的每个编码器中,保留CLS令牌的嵌入,但选择其他令牌嵌入的代表性子集——核心集。
嵌入是向量,因此可以解释为多维空间中的点。为了构建核心集,理想情况下将嵌入排序为等直径的簇,并选择每个簇的中心点——质心。
不幸的是,构建跨越神经网络层的核心集的问题是NP难的,意味着耗时不可行。
作为替代方案,论文提出一种贪心算法,每次选择核心集的n个成员。在每一层,取CLS令牌的嵌入,然后在表示空间中找到离它最远的n个嵌入。将这些与CLS嵌入一起添加到核心集。然后找到与核心集中已有点的最小距离最大的n个嵌入,并将这些添加到核心集。
重复此过程直到核心集达到所需大小。这被证明是最优核心集的足够近似。
最后,在论文中考虑了每层核心集应有多大的问题。使用指数延迟函数确定从一层到下一层的衰减程度,并研究选择不同衰减率时精度与加速或内存减少之间的权衡。
致谢:Ashish Khetan, Rene Bidart, Zohar Karnin