简化基于BERT的模型以提高效率和容量
新方法使基于BERT的自然语言处理模型能够处理更长的文本字符串,在资源受限的环境中运行——有时甚至能同时实现两者。
背景
近年来,自然语言处理(NLP)领域许多性能最佳的模型都建立在BERT语言模型之上。BERT模型在大量(未标记)公共文本语料库上进行预训练,编码词序列的概率。由于BERT模型开始时对整体语言有广泛了解,因此可以用相对较少的标记数据对更针对性的任务(如问答或机器翻译)进行微调。
然而,BERT模型非常庞大,基于BERT的NLP模型可能运行缓慢——对于计算资源有限的用户来说,甚至慢到无法接受。其复杂性还限制了可处理输入的长度,因为其内存占用随输入长度的平方而缩放。
Pyramid-BERT架构
在今年的计算语言学协会(ACL)会议上,提出了一种名为Pyramid-BERT的新方法,该方法减少了基于BERT模型的训练时间、推理时间和内存占用,而准确性损失很小。减少的内存占用还使BERT模型能够处理更长的文本序列。
基于BERT的模型将句子序列作为输入,并输出整体句子及其组成词的向量表示(嵌入)。然而,下游应用(如文本分类和排序)仅使用完整句子嵌入。为了使基于BERT的模型更高效,该方法逐步消除网络中间层中的冗余单个词嵌入,同时尽量减少对完整句子嵌入的影响。
性能比较
将Pyramid-BERT与几种最先进的使BERT模型更高效的技术进行比较,结果显示,推理速度可提高3至3.5倍,而准确性仅下降1.5%,而在相同速度下,现有最佳方法的准确性损失为2.5%。
此外,当将该方法应用于Performers(专为长文本设计的BERT模型变体)时,可以将模型的内存占用减少70%,同时实际上提高准确性。在该压缩率下,现有最佳方法的准确性下降4%。
令牌处理流程
输入到BERT模型的每个句子被分解为称为令牌的单元。大多数令牌是单词,但有些是多词短语,有些是子词部分,有些是首字母缩略词的单个字母等。每个句子的开头由一个特殊令牌 demarcated(出于即将清楚的原因)称为CLS(分类)。
每个令牌通过一系列编码器(通常介于4到12个之间),每个编码器为每个输入令牌生成一个新的嵌入。每个编码器都有一个注意力机制,决定每个令牌的嵌入应反映其他令牌携带的信息的程度。
例如,给定句子“Bob told his brother that he was starting to get on his nerves”,注意力机制在编码“his”时应更多关注单词“Bob”,但在编码“he”时应更多关注“brother”。正是因为注意力机制必须将输入序列中的每个单词与其他每个单词进行比较,BERT模型的内存占用随输入长度的平方而缩放。
当令牌通过一系列编码器时,它们的嵌入会包含越来越多关于序列中其他令牌的信息,因为它们关注的其他令牌也在包含越来越多信息。当令牌通过最终编码器时,CLS令牌的嵌入最终代表整个句子(因此CLS令牌的名称)。但其嵌入也与句子中所有其他令牌的嵌入非常相似。这就是试图消除的冗余。
核心思想
基本思想是,在网络的每个编码器中,保留CLS令牌的嵌入,但选择其他令牌嵌入的代表性子集——核心集。
嵌入是向量,因此可以解释为多维空间中的点。为了构建核心集,理想情况下,将嵌入排序为等直径的簇,并选择每个簇的中心点——质心。
不幸的是,构建跨越神经网络层的核心集的问题是NP难的,意味着耗时不可行。
作为替代方案,论文提出一种贪心算法,一次选择核心集的n个成员。在每一层,取CLS令牌的嵌入,然后找到表示空间中离它最远的n个嵌入。将这些与CLS嵌入一起添加到核心集中。然后找到与核心集中已有点的最小距离最大的n个嵌入,并将这些添加到核心集中。
重复此过程,直到核心集达到所需大小。这被证明是优化核心集的足够近似。
最后,论文考虑了每层核心集应有多大的问题。使用指数延迟函数确定从一层到下一层的衰减程度,并研究选择不同衰减率时准确性与加速或内存减少之间的权衡。
致谢
Ashish Khetan, Rene Bidart, Zohar Karnin