Transformer模型中的线性层与激活函数解析

本文深入解析Transformer模型中线性层与激活函数的技术实现,包括前馈网络设计原理、GELU/SwiGLU等激活函数的数学特性,以及PyTorch代码实现细节,帮助理解非线性变换在注意力机制中的关键作用。

Transformer模型中的线性层与激活函数

注意力操作是Transformer模型的标志性特征,但它们并非唯一的构建模块。线性层和激活函数同样至关重要。本文将介绍:

  • 线性层和激活函数如何实现非线性变换
  • Transformer模型中前馈网络的典型设计
  • 常见激活函数及其特性

概述

本文分为三个部分:

  1. Transformer中为何需要线性层和激活函数
  2. 前馈网络的典型设计
  3. 激活函数的变体

Transformer中为何需要线性层和激活函数

注意力层是Transformer模型的核心功能。它对齐序列中的不同元素,并将输入序列转换为输出序列。注意力层执行输入的仿射变换,意味着输出是每个序列元素处输入的加权和。

神经网络的力量不仅来自线性层,还来自引入非线性的激活函数。在Transformer模型中,需要在注意力层后引入非线性以学习复杂模式。这是通过在每个注意力层后添加前馈网络(FFN)或多层感知网络(MLP)来实现的。典型的Transformer块结构如下:

上图中的灰色框在Transformer模型中重复多次。在每个块中(不包括归一化层),输入首先通过注意力层,然后通过前馈网络(在PyTorch中实现为nn.Linear)。前馈网络内的激活函数为变换添加非线性。

前馈网络使模型能够学习更复杂的模式。通常,它包含多个线性层:第一个扩展维度以探索不同表示,最后一个将其压缩回原始维度。激活函数通常应用于第一个线性层的输出。

基于这种设计,通常将块的前半部分称为"注意力子层",后半部分称为"MLP子层"。

前馈网络的典型设计

在BERT模型中,MLP子层实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
import torch.nn as nn

class BertMLP(nn.Module):
    def __init__(self, dim, intermediate_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, intermediate_dim)
        self.fc2 = nn.Linear(intermediate_dim, dim)
        self.gelu = nn.GELU()

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

MLP子层包含两个线性模块。当输入序列进入MLP子层时,第一个线性模块扩展维度,然后应用GELU激活函数。结果通过第二个线性模块将维度压缩回原始大小。

中间维度通常是原始维度的4倍:这是Transformer模型中的常见设计模式。

激活函数的变体

激活函数为神经网络引入非线性,使其能够学习复杂模式。虽然传统神经网络常用双曲正切(tanh)、Sigmoid和整流线性单元(ReLU),但Transformer模型通常采用GELU和SwiGLU激活函数。

以下是一些常见激活函数的数学定义:

  • Sigmoid(x) = 1/(1+e⁻ˣ)
  • tanh(x) = (eˣ - e⁻ˣ)/(eˣ + e⁻ˣ) = 2Sigmoid(2x) - 1
  • ReLU(x) = max(0, x)
  • GELU(x) = x·Φ(x) ≈ x/2(1+tanh(√(2/π)(x+0.044715x³)))
  • Swishᵦ(x) = x·Sigmoid(βx) = x/(1+e⁻ᵝˣ)
  • SiLU(x) = x/(1+e⁻ˣ) = Swish₁(x)
  • SwiGLU(x) = SiLU(xW+b)·(xV+c)

ReLU(整流线性单元)在现代深度学习中很受欢迎,因为它避免了梯度消失问题且计算简单。

GELU(高斯误差线性单元)由于使用标准正态分布的累积分布函数Φ(x)而计算成本更高。如上所示存在近似公式。GELU不是单调的,如下图所示。

单调激活函数通常更受青睐,因为它们确保一致的梯度方向,可能带来更快的收敛。然而,单调性并非严格要求——可能只需要更长的训练时间。这是在模型复杂性和训练持续时间之间的权衡。

Swish是另一个非单调激活函数,具有控制x=0处斜率的参数β。当β=1时,称为SiLU(Sigmoid线性单元)。

SwiGLU(Swish门控线性单元)是现代Transformer模型中常见的最新激活函数。它是Swish函数和线性函数的乘积,参数在训练过程中学习。其受欢迎程度源于其复杂性:展开公式显示分子中有二次项,帮助模型学习复杂模式而无需额外层。

上图显示了这些激活函数的图表。所示的SwiGLU函数为f(x)=SiLU(x)·(x+1)。

在Python代码中切换激活函数很简单。PyTorch提供内置的nn.Sigmoid、nn.ReLU、nn.Tanh和nn.SiLU。然而,SwiGLU需要特殊实现。以下是Llama模型中使用的PyTorch代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
import torch.nn as nn

class LlamaMLP(nn.Module):
    def __init__(self, dim, intermediate_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, intermediate_dim)
        self.up_proj = nn.Linear(dim, intermediate_dim)
        self.down_proj = nn.Linear(intermediate_dim, dim)
        self.act = nn.SiLU()

    def forward(self, hidden_states):
        gate = self.gate_proj(hidden_states)
        up = self.up_proj(hidden_states)
        swish = self.act(up)
        output = self.down_proj(swish * gate)
        return output

此实现使用两个线性层处理输入hidden_states。一个输出通过SiLU函数,然后与另一个输出相乘,最后通过一个线性层处理。线性层命名为"up"或"down"表示维度扩展/压缩,而连接到SiLU的层称为"gate"表示其门控机制。门控是神经网络的一种设计,意味着一个线性层的输出与权重的逐元素相乘,这里由Swish激活函数产生。

Llama模型架构如下所示,显示了MLP模块的双分支结构:

总结

本文介绍了Transformer模型中的线性层和激活函数。具体学习了:

  • 为什么线性层和激活函数对非线性变换是必要的
  • ReLU、GELU和SwiGLU激活函数的特性和实现
  • 如何构建Transformer模型中使用的完整前馈网络
comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计