KV缓存:实时大语言模型背后的隐藏加速技术

本文深入探讨了KV缓存技术如何通过重用键值对来避免重复计算,使大语言模型推理速度提升3-4倍。详细解析了自回归解码、注意力机制的工作原理,并提供了实际代码示例和性能对比数据。

KV缓存:实时大语言模型背后的隐藏加速技术

引言:为什么LLM性能很重要

你是否注意到你的AI助手开始时反应迅速,但随后…开始变得拖沓或变慢? 这不仅仅是你一个人的问题。这种减速是大语言模型(LLMs)工作方式的固有特性。大多数LLM使用称为自回归解码的方式逐个令牌生成文本。问题是——响应越长,模型在每个步骤需要做的工作就越多。因此延迟会累积。

现在,想象一下实际场景:

  • 你正在与支持机器人聊天,突然它需要很长时间才能回复
  • 你的代码自动补全工具在你流畅编码时开始卡顿
  • 语音助手在回应中途尴尬地暂停

这不是很好的体验。在底层,这也很昂贵,因为每一秒的延迟都会消耗更多的GPU(或任何硬件)时间、能源和金钱。

为什么会发生这种情况?

模型不仅仅在思考下一个词。它还在反复重新处理所有先前的词。就像每次想要在输出中添加一个词时,都要从头重新检查你的笔记。

这不仅是低效的,而且是不必要的。

这就是KV缓存的作用所在。这是一个令人惊讶的简单想法,却改变了一切。我们不是重新进行所有过去的计算,而是缓存我们已经看到的内容并重复使用它。结果是显著的。

在本文中,我将逐步介绍:

  • 为什么LLM首先依赖Transformer?
  • 逐个令牌解码如何减慢速度?
  • 注意力机制做什么,键和值如何成为瓶颈?
  • 最后,KV缓存如何在不偷工减料的情况下加速?

在实际部署中,启用KV缓存可以意味着生成速度快四倍,而质量没有任何下降。使用相同的模型和硬件。

什么是LLM以及为什么使用Transformer?

大语言模型(LLMs)只是经过训练预测下一个词的巨大神经网络。但它们的预测方式才是魔法发生的地方。

我不会过于深入架构,因为你可以找到许多其他关于这个主题的文章。现代LLM几乎完全依赖于称为Transformer的东西,而不是使用传统的架构,如RNN(循环神经网络)或LSTM(长短期记忆)。这是GPT、Claude、Mistral甚至BERT等模型背后的相同架构。

为什么Transformer占据了主导地位?因为它们非常擅长两件事:

  1. 在训练期间并行处理输入,这使得它们即使在大规模情况下也能快速训练
  2. 使用称为注意力的机制,让模型在决定下一步说什么时查看输入的每个部分

Transformer是一堆注意力块的堆叠,每个块接收一些数据,以巧妙的方式混合,然后向前传递。在每个块内部,有用于计算查询、键和值的组件。这些在注意力步骤中扮演重要角色,我们很快就会讲到。

但这里有个转折:即使Transformer在训练时很快,它们在推理时会变慢。在训练期间,你可以并行处理数千个令牌。但在推理时,你只能一次预测一个令牌。这就是问题开始出现的地方。

从Transformer到自回归解码

Transformer训练速度快是因为它们同时查看所有内容。但这不是它们运行推理的方式。

在生成文本时,语言模型不能并行预测所有令牌。在简单情况下,它们必须一次一个令牌地进行。这就是我们所说的自回归解码。

  • 第一步,模型看到提示并预测第一个令牌
  • 第二步,它看到提示加上第一个令牌并预测第二个

这里的关键思想是每个新令牌都依赖于它之前的所有内容。因此,生成变成了一个顺序过程,即使模型是并行训练的。

正如我之前所说,序列越长,生成下一个令牌所需的计算就越多。

注意力机制如何工作

让我们花点时间解析注意力机制。这是Transformer的核心。请阅读这篇论文以获取详细解释。

每个令牌被转换为三个向量:查询(Q)、键(K)和值(V)。

在每个步骤中,模型将当前查询与来自先前所有令牌的键进行比较。它计算每个过去令牌的相关性分数,然后使用这些分数对值进行加权。这就是注意力决定关注什么的方式。

在代码中看起来是这样的:

1
2
3
4
5
6
7
8
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / d_k**0.5
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)

这种机制很强大,因为它让模型动态决定关注上下文的哪些部分。但有一个问题。

在每个步骤中,模型需要使用所有先前的键和值重新计算注意力分数。这会很快累积起来。

KV缓存:避免重复计算

现在你已经看到了核心问题:每个新令牌都需要使用所有先前令牌计算注意力。这意味着更多的键、更多的值和更多的计算。这是线性内存,二次时间。

但这里有个技巧:你不需要为已经看到的令牌重新计算键和值。

相反,你缓存它们一次并在每个步骤中重复使用。

这就是KV缓存的精髓。

工作原理

  • 在每个解码步骤中,你只计算新令牌的Q
  • 你重复使用来自过去的K和V,从缓存中提取
  • 注意力变为:Q_t x [缓存的K],然后与[缓存的V]进行加权求和

没有重新编码。没有重新处理。只是获取和相乘。

这将二次时间的注意力循环转变为接近线性的东西,特别是与FlashAttention和分页KV内存等技巧结合时。

在代码中看起来如何?

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / d_k**0.5
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v)

# 使用KV缓存模拟解码循环
past_k = []
past_v = []

for t in range(seq_len):  # 解码循环
    token_t = input_tokens[:, t]  # (batch,)
    
    # 嵌入和投影当前令牌
    x_t = embedding(token_t)  # (batch, d_model)
    q_t = query_proj(x_t)     # (batch, d_k)
    k_t = key_proj(x_t)       # (batch, d_k)
    v_t = value_proj(x_t)     # (batch, d_v)
    
    # 缓存K和V
    past_k.append(k_t)
    past_v.append(v_t)
    
    # 堆叠所有过去的K和V
    k_stack = torch.stack(past_k, dim=1)  # (batch, t+1, d_k)
    v_stack = torch.stack(past_v, dim=1)  # (batch, t+1, d_v)
    
    # 仅使用当前Q和缓存的K,V计算注意力
    q_t = q_t.unsqueeze(1)  # (batch, 1, d_k)
    out_t = scaled_dot_product_attention(q_t, k_stack, v_stack)
    
    # 生成令牌logits或继续循环...

来自真实实验的加速数字:

模型 GPU 无KV缓存 有KV缓存 加速比
GPT-2 (1.5B) A100 12 tok/sec 45 tok/sec 3.75x
GPT-J (6B) A100 5 tok/sec 20 tok/sec 4x
Mistral-7B L4 4.2 tok/sec 14.6 tok/sec 3.5x
LLaMA 13B A100 ~6 tok/sec ~25 tok/sec 4.1x

何时应该关注

让我们放大来看。KV缓存听起来像是一个低级优化,但它不是。它是一个基本的性能解锁,为今天几乎每个实时LLM部署提供动力。

最重要的场景:

  • 处理长对话的聊天机器人(ChatGPT、Claude等)
  • 代码协同工具如GitHub Copilot、Cursor和TabNine
  • 语音助手、翻译器和自动补全工具
  • 任何具有长提示或长输出的应用程序

如果你正在构建顺序生成令牌的工具,你应该假设KV缓存是必需的。不是可选的。

不太重要的场景:

  • 具有短输出的一次性推理
  • 非自回归任务(如分类或嵌入提取)
  • 训练时间(训练已经使用并行注意力)

最终思考

KV缓存不是一个小众技巧。它是区分工作原型和真实产品的关键。

comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计