Transformer模型中的线性层与激活函数
注意力操作是Transformer模型的标志性特征,但它们并非唯一的构建模块。线性层和激活函数同样至关重要。本文将介绍:
- 线性层和激活函数如何实现非线性变换
- Transformer模型中前馈网络的典型设计
- 常见激活函数及其特性
概述
本文分为三个部分:
- Transformer中为何需要线性层和激活函数
- 前馈网络的典型设计
- 激活函数的变体
Transformer中为何需要线性层和激活函数
注意力层是Transformer模型的核心功能。它对齐序列中的不同元素,并将输入序列转换为输出序列。注意力层执行输入的仿射变换,意味着输出是每个序列元素处输入的加权和。
神经网络的力量不仅来自线性层,还来自引入非线性的激活函数。在Transformer模型中,需要在注意力层后引入非线性以学习复杂模式。这是通过在每个注意力层后添加前馈网络(FFN)或多层感知网络(MLP)来实现的。典型的Transformer块结构如下:
上图中的灰色框在Transformer模型中重复多次。在每个块中(不包括归一化层),输入首先通过注意力层,然后通过前馈网络(在PyTorch中实现为nn.Linear)。前馈网络内的激活函数为变换添加非线性。
前馈网络使模型能够学习更复杂的模式。通常,它包含多个线性层:第一个扩展维度以探索不同表示,最后一个将其压缩回原始维度。激活函数通常应用于第一个线性层的输出。
基于这种设计,通常将块的前半部分称为"注意力子层",后半部分称为"MLP子层"。
前馈网络的典型设计
在BERT模型中,MLP子层实现如下:
|
|
MLP子层包含两个线性模块。当输入序列进入MLP子层时,第一个线性模块扩展维度,然后应用GELU激活函数。结果通过第二个线性模块将维度压缩回原始大小。
中间维度通常是原始维度的4倍:这是Transformer模型中的常见设计模式。
激活函数的变体
激活函数为神经网络引入非线性,使其能够学习复杂模式。虽然传统神经网络常用双曲正切(tanh)、Sigmoid和整流线性单元(ReLU),但Transformer模型通常采用GELU和SwiGLU激活函数。
以下是一些常见激活函数的数学定义:
- Sigmoid(x) = 1/(1+e⁻ˣ)
- tanh(x) = (eˣ - e⁻ˣ)/(eˣ + e⁻ˣ) = 2Sigmoid(2x) - 1
- ReLU(x) = max(0, x)
- GELU(x) = x·Φ(x) ≈ x/2(1+tanh(√(2/π)(x+0.044715x³)))
- Swishᵦ(x) = x·Sigmoid(βx) = x/(1+e⁻ᵝˣ)
- SiLU(x) = x/(1+e⁻ˣ) = Swish₁(x)
- SwiGLU(x) = SiLU(xW+b)·(xV+c)
ReLU(整流线性单元)在现代深度学习中很受欢迎,因为它避免了梯度消失问题且计算简单。
GELU(高斯误差线性单元)由于使用标准正态分布的累积分布函数Φ(x)而计算成本更高。如上所示存在近似公式。GELU不是单调的,如下图所示。
单调激活函数通常更受青睐,因为它们确保一致的梯度方向,可能带来更快的收敛。然而,单调性并非严格要求——可能只需要更长的训练时间。这是在模型复杂性和训练持续时间之间的权衡。
Swish是另一个非单调激活函数,具有控制x=0处斜率的参数β。当β=1时,称为SiLU(Sigmoid线性单元)。
SwiGLU(Swish门控线性单元)是现代Transformer模型中常见的最新激活函数。它是Swish函数和线性函数的乘积,参数在训练过程中学习。其受欢迎程度源于其复杂性:展开公式显示分子中有二次项,帮助模型学习复杂模式而无需额外层。
上图显示了这些激活函数的图表。所示的SwiGLU函数为f(x)=SiLU(x)·(x+1)。
在Python代码中切换激活函数很简单。PyTorch提供内置的nn.Sigmoid、nn.ReLU、nn.Tanh和nn.SiLU。然而,SwiGLU需要特殊实现。以下是Llama模型中使用的PyTorch代码:
|
|
此实现使用两个线性层处理输入hidden_states。一个输出通过SiLU函数,然后与另一个输出相乘,最后通过一个线性层处理。线性层命名为"up"或"down"表示维度扩展/压缩,而连接到SiLU的层称为"gate"表示其门控机制。门控是神经网络的一种设计,意味着一个线性层的输出与权重的逐元素相乘,这里由Swish激活函数产生。
Llama模型架构如下所示,显示了MLP模块的双分支结构:
总结
本文介绍了Transformer模型中的线性层和激活函数。具体学习了:
- 为什么线性层和激活函数对非线性变换是必要的
- ReLU、GELU和SwiGLU激活函数的特性和实现
- 如何构建Transformer模型中使用的完整前馈网络