做大模型优化,transformer架构的深刻理解十分重要。

索性直接手写一些关键代码和可视化结果,帮助理解整体流程。

先来个大模型整体架构流程:

用户输入:"The capital of France is" (用户输入的这个就是提示词)┌─────────────────────────────────────────┐│        Prefill 阶段                     │├─────────────────────────────────────────┤│ 一次性处理整个输入序列                  ││ 并计算 KV cache                        │├─────────────────────────────────────────┤│ 输入: [CLS] The capital of France is   ││ 长度: 6 tokens                          ││                                         ││ Forward pass:                           ││   - 计算所有 token 的 embeddings       ││   - 通过所有 Transformer 层            ││   - 保存 KV cache                      ││   - 输出最后一个 token 的隐藏状态     │└─────────────────────────────────────────┘         ↓┌─────────────────────────────────────────┐│      Decode 阶段(开始生成)            │├─────────────────────────────────────────┤│ 基于 prefill 结果生成新 token          ││ (利用之前保存的 KV cache)              │└─────────────────────────────────────────┘

什么是 Prefill?

Prefill = 预填充:在生成新 token 之前,先处理输入的提示词(prompt)。Prefill 的详细过程

步骤 1:Tokenization(分词)

from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b")prompt = "The capital of France is"# 分词tokens = tokenizer.encode(prompt, return_tensors="pt")print(tokens)# 输出: tensor([[    1,  1576,   7483,   310, 8363,  5799]])#                   ↑    ↑        ↑       ↑    ↑      ↑    #                   BOS  The     capital   of  France is  # 为了方便理解,没有把单个词拆分,比如tokenizer可能会把captial -> cap ital 拆为两个tokenseq_len = 6  # 实际长度

步骤 2:Embedding 层

# 输入: token idsinput_ids = tensor([1, 1576, 7483, 310, 8363, 5799, 338])# Embedding 层将每个 token id 转换为向量embeddings = embedding_layer(input_ids)print(embeddings.shape)# torch.Size([1, 6, 768])#             ↑  ↑  ↑#             |  |  └─ hidden dimension (隐藏维度)#             |  └───── seq_len (序列长度 = 6)#             └──────── batch_size (批大小)# 具体值# embeddings[0, 0, :] = [0.1, 0.2, ..., 0.5]  # token "The" 的向量# embeddings[0, 1, :] = [0.3, 0.1, ..., 0.7]  # token "capital" 的向量# embeddings[0, 2, :] = [0.2, 0.4, ..., 0.1]  # token "of" 的向量# ...# embeddings[0, 5, :] = [0.4, 0.5, ..., 0.3]  # token "is" 的向量

步骤 3:位置编码(Positional Encoding)

# Transformer 需要知道每个 token 的位置# 原因:为了区分语义相同但位置不同的情况position_ids = torch.arange(seq_len).unsqueeze(0)  # [1, 6]# position_ids = [[0, 1, 2, 3, 4, 5]]pos_embeddings = positional_encoding(position_ids)# pos_embeddings.shape = [1, 6, 768]# 加上位置编码embeddings = embeddings + pos_embeddings# 现在 embeddings 既包含 token 的语义,也包含位置信息

步骤 4:通过第一个 Transformer 层

# 一个标准的 Transformer 层包含:# 1. Multi-Head Attention# 2. Feed Forward Network (FFN)# 3. Layer Normalization# 4. Residual Connections# ============ Self-Attention ============# 对所有 token 计算 Q, K, Vquery = embeddings @ W_q    # [1, 6, 768]key = embeddings @ W_k      # [1, 6, 768]value = embeddings @ W_v    # [1, 6, 768]# 计算 attention weights# Q @ K^T 的结果是一个 6×6 的矩阵# 因为:#   Q: [1, 6, 768]#   K^T: [768, 6]#   结果: [1, 6, 6]scores = query @ key.transpose(-2, -1) / sqrt(d_k)# scores[0] 是一个 6×6 矩阵:# scores[0, i, j] = token_i 对 token_j 的注意力强度# 可视化:# token:     The  capital  of  France  is#      ┌───────────────────────────────────┐# The  │ 0.2   0.1    0.05  0.3    0.35  │# cap  │ 0.15  0.3    0.1   0.25   0.2   │# of   │ 0.1   0.15   0.4   0.2    0.15  │# Fra  │ 0.25  0.2    0.15  0.25   0.15  │# is   │ 0.3   0.15   0.1   0.2    0.25  │#      └───────────────────────────────────┘## 含义:#   - "is" (最后一行) 最关注 "The" (第一列 0.3)#   - "capital" (第二行) 最关注自己 (对角线 0.3)#   - 等等# 应用 softmax 和加权求和attention_weights = softmax(scores, dim=-1)  # [1, 6, 6]attn_output = attention_weights @ value      # [1, 6, 768]# ============ FFN ============ffn_output = FFN(attn_output)  # [1, 6, 768]# ============ Residual + LayerNorm ============layer_output = LayerNorm(attn_output + ffn_output)  # [1, 6, 768]

关键观察:

# 在 prefill 的单个 forward 中,所有 token 一起被处理# 计算量:O(seq_len^2) 因为 attention 需要计算所有 token 对之间的关系# 优点:#   - 并行度高,充分利用 GPU#   - 相对较快# 缺点:#   - seq_len 较大时,内存爆炸 (O(seq_len^2))#   - 无法使用 CUDA Graph (seq_len 在单个 forward 内固定,但无法预知下一个输入的 seq_len)

我们看到以上是全序列Q都参与了softmax的计算,但在实际推理prefill阶段,我们只需要最后一个token去进行atten计算,前n-1个token可以不用进行计算。

推理中,prefill 阶段,大模型在每一层 Transformer 做注意力时,只需要最后一个位置的 Q 和整段 prompt 的 K/V 做 softmax;

但在此之前,每层对所有 token 仍然要算 hidden state 和 K/V,并把 K/V 缓存起来。

对上述Transformer层做优化:

# 一个标准的 Transformer 层包含:# 1. Multi-Head Attention# 2. Feed Forward Network (FFN)# 3. Layer Normalization# 4. Residual Connections# ============ Self-Attention ============# def forward(self, x):embeddings=self.embedding(x)# [B, L, H]# 对所有 token 计算 Q, K, Vquery=embeddings@W_q# [1, 6, 768]key=embeddings@W_k# [1, 6, 768]value=embeddings@W_v# [1, 6, 768]# 计算 attention weights# Q @ K^T 的结果是一个 6×6 的矩阵# 因为:#   Q: [1, 6, 768]#   K^T: [768, 6]#   结果: [1, 6, 6]# 只取最后一个位置token 做attenquery=query[:-1:]scores=query@key.transpose(-2,-1)/sqrt(d_k)# scores[0] 是一个 6×6 矩阵:# scores[0, i, j] = token_i 对 token_j 的注意力强度# 可视化:# token:      is#      ┌───────────────────────────────────┐# The  │    0.35  │# cap  │    0.2   │# of   │    0.15  │# Fra  │    0.15  │# is   │    0.25  │#      └───────────────────────────────────┘## 含义:#   - "is" (最后一行) 最关注 "The" (第一列 0.3)# 应用 softmax 和加权求和attention_weights=softmax(scores,dim=-1)# [1, 1, 6]attn_output=attention_weights@value# [1, 1, 768]# ============ FFN ============ffn_output=FFN(attn_output)# [1, 1, 768]# ============ Residual + LayerNorm ============layer_output=LayerNorm(attn_output+ffn_output)# [1, 1, 768]out=x.clone()# 直接拷贝输入out[:,-1:,:]=layer_output# [B, L, H]

步骤 5:通过所有 Transformer 层

# Prefill 不同于 Decode 的关键区别# Prefill 会通过所有 N 个 Transformer 层x = embeddings  # [1, 6, 768]for layer_idx in range(num_layers):  # 比如 32 层    layer = transformer_layers[layer_idx]    x = layer(x)  # x 始终是 [1, 6, 768]# 最后的 hidden statefinal_hidden = x  # [1, 6, 768]

步骤 6:计算 KV Cache

# Prefill 的一个重要作用就是预计算所有 token 的 KV# 为后续的 decode 阶段做准备# 在每一层,我们保存该层的 K 和 Vkv_cache = {}  # dict: layer_idx -> (key, value)for layer_idx in range(num_layers):    # 假设我们在 prefill forward 的第 layer_idx 层    # 计算了这一层的 Q, K, V    key = ...    # [1, 6, 768]    value = ...  # [1, 6, 768]    # 保存到 cache    kv_cache[layer_idx] = (key, value)# 保存后的 kv_cache 结构# {#   0: (key_layer0, value_layer0),  # layer 0 的 KV#   1: (key_layer1, value_layer1),  # layer 1 的 KV#   ...#   31: (key_layer31, value_layer31),  # layer 31 的 KV# }

步骤 7:获取输出 logits

# 只需要最后一个 token 的隐藏状态来生成下一个 tokenlast_hidden = final_hidden[:, -1, :]  # [1, 768]#                         ↑ 取最后一个位置# 通过 language modeling headlogits = lm_head(last_hidden)  # [1, vocab_size]# 比如 vocab_size = 32000print(logits.shape)  # torch.Size([1, 32000])# 采样或贪心选择下一个 tokennext_token_id = argmax(logits)  # scalar# 或者按概率采样next_token_id = sample(logits)

注意:
在推理阶段,如果为了最大化吞吐,prefill 阶段的注意力一般只需要最后一个 token 的 Q,与全量 K/V 计算一次 attention;
K/V 仍对全序列并行计算,而不需要为全序列都生成 Q 和 attention。

这是 vLLM、TensorRT-LLM 等高性能推理框架的核心思路之一,实际部署推理只取最后一个Q,得出logits,提高吞吐量。

Prefill 的可视化流程

输入提示词: "The capital of France is"┌─────────────────────────────────────┐│ Token:    The capital of France is  ││ Position: 0   1      2  3      4    │ ← seq_len = 5└─────────────────────────────────────┘              ↓    ┌──────────────────────────┐    │  Tokenization & Embedding │    └──────────────────────────┘              ↓    ┌──────────────────────────┐    │  Layer 0 (Self-Attention) │    │  所有 5 个 token 一起处理 │    └──────────────────────────┘              ↓    ┌──────────────────────────┐    │  Layer 1 (Self-Attention) │    │  所有 5 个 token 一起处理 │    └──────────────────────────┘              ↓           ...  (重复 32 层)              ↓    ┌──────────────────────────┐    │ Layer 31 (Self-Attention) │    │ 所有 5 个 token 一起处理 │    └──────────────────────────┘              ↓    ┌────────────────────────┐    │ LM Head (最后一层)     │    │ 只取最后一个 token     │    │ 的输出: "is" -> logits │    └────────────────────────┘              ↓    ┌────────────────────────┐    │ 采样下一个 token: "Paris"    └────────────────────────┘

什么是 Decoder 阶段?

Decoder = 解码/生成阶段:在处理完输入提示词(Prefill)之后,逐个生成新的 token。

Prefill 阶段完成后:"The capital of France is"         ↓    已经准备好了 KV cache         ↓┌──────────────────────────────────┐│    Decoder 阶段(生成阶段)      │├──────────────────────────────────┤│ Step 0: 生成第 1 个新 token     ││ Step 1: 生成第 2 个新 token     ││ Step 2: 生成第 3 个新 token     ││ ...                              ││ Step N: 生成第 N 个新 token     │└──────────────────────────────────┘

Decoder 的详细执行流程

Step 0:生成第一个 token

准备状态

# Prefill 之后我们有:# ┌─────────────────────────────────────────────────────┐# │        Prefill 完成时的状态快照                    │# ├─────────────────────────────────────────────────────┤# │                                                     │# │ 1. 输入提示词:                                    │# │    "The capital of France is"                      │# │                                                     │# │ 2. Token IDs:                                     │# │    [1576, 7483, 310, 8363, 5799, 338]             │# │    seq_len = 6                                      │# │                                                     │# │ 3. Logits:                                        │# │    shape: [1, 6, 32000]                           │# │    最重要的是最后一个:logits[:, -1, :] = [1, 32000]# │    这表示:基于所有前面的信息,下一个 token 的概率分布# │                                                     │# │ 4. KV Cache:                                      │# │    32 层 × 2 (key + value)                         │# │    每层的形状: [1, 6, 768]                         │# │    存储了所有 6 个输入 token 的 key 和 value      │# │                                                     │# └─────────────────────────────────────────────────────┘# 从 prefill 的最后一个位置的 logits 采样next_token_logits = logits[:, -1, :]  # [1, 32000]# 方法 A:贪心(取最高概率)next_token_id = torch.argmax(next_token_logits, dim=-1)print(f"Next token ID: {next_token_id.item()}")# 输出:4521 (比如这是 "Paris" 的 token ID)# 方法 B:采样(按概率分布采样)next_token_id = torch.multinomial(    torch.softmax(next_token_logits, dim=-1),     num_samples=1)# 输出:也可能是 4521,但也可能是其他高概率的 token# 方法 C:top-k 采样(只从概率最高的 k 个 token 中采样)# vLLM 等推理框架会使用这种方法# 让我们假设采样结果是:next_token_id = 4521  # "Paris"# 解码一下看看是什么token_text = tokenizer.decode([next_token_id])print(f"Generated token: '{token_text}'")# 输出:'Paris'

Forward Pass

# 输入:前一个token(第6个token)的隐藏状态# hidden: [1, 768] - prefill最后一个token的隐藏状态for layer_idx in range(num_layers):    layer = transformer_layers[layer_idx]    # 类似的过程    query = hidden @ W_q[layer_idx] # [1, 768]    new_key = hidden @ W_k[layer_idx] # [1, 768]    new_value = hidden @ W_v[layer_idx] # [1, 768]    cached_key, cached_value = kv_cache[layer_idx] # [1, seq_len, 768]    # 3. 将新的key/value 与缓存拼接    new_key_expanded = new_key.unsqueeze(1)      # [1, 1, 768]    new_value_expanded = new_value.unsqueeze(1)  # [1, 1, 768]    full_keys = torch.cat([cached_keys, new_key_expanded], dim=1)      # [1, seq_len+1, 768]    full_values = torch.cat([cached_values, new_value_expanded], dim=1) # [1, seq_len+1, 768]    query_expanded = query.unsqueeze(1) # [1, 1, 768]    # 4. 注意力分数    scores = query_expanded @ full_key.transpose(-2, -1) / sqrt(d_k) # [1, 1, seq_len + 1]    # causal mask, 确保只看到当前位置及之前的tokens    mask = torch.tril(torch.ones(1, 1, seq_len+1, seq_len+1))    scores = scores.masked_fill(mask == 0, float('-inf'))    attn_weights = softmax(scores, dim=-1) # [1, 1, 768]    # 加权求和    attn_output = attn_weights @ full_values # [1, 1, 768]    # 5. 更新隐藏状态 (移除seq维度用于后续处理)    hidden = attn_output.squeeze(1) # [1, 768]    # 6. 更新KV cache    kv_cache[layer_idx] = (full_key, full_value)    # 7. 通过FFN等其他组件    hidden = layer.ffn(hidden)    hidden = layer.norm2(hidden)# 最终输出logits_step0 = lm_head(hidden)  # [1, vocab_size]# 采样下一个 tokennext_token_id = sample(logits_step0)  # 比如 4521 表示 "Paris"

状态更新

# 完成 Step 0 后的状态:# 1. 生成的 token 列表generated_tokens = [4521]  # "Paris"# 2. 更新的 seq_lencurrent_seq_len = 7  # 原来 6 + 新生成 1# 3. 更新的 KV cache(已经自动更新了)kv_cache = {    layer_0: (key_0, value_0),      # [1, 7, 768] ← 从 6 变为 7    layer_1: (key_1, value_1),      # [1, 7, 768]    ...    layer_31: (key_31, value_31),   # [1, 7, 768]}

Step 1:生成第二个 token

现在输入变成了 “The capital of France is Paris”,Step 2, 3, … 类似重复上述过程。

# 每一步都重复类似的过程:# 1. 获取上一步生成的 token# 2. embedding 它# 3. 通过所有层,使用更新的 KV cache# 4. 生成新 token# 5. 更新 KV cache

Decoder 的可视化流程

初始状态(Prefill 完成):┌─────────────────────────────────────────┐│ 序列:The capital of France is          ││ seq_len = 6                             ││ KV cache 尺寸: [1, 6, 768] × 32 层      │└─────────────────────────────────────────┘         ↓Step 0:┌─────────────────────────────────────────┐│ 输入:[is] (最后一个 token)              ││ ↓ embedding & layers                    ││ Attention with 6 + 1 = 7 positions      ││ ↓ 生成                                   ││ 输出:Paris                              ││ KV cache 更新: [1, 7, 768] × 32 层      │└─────────────────────────────────────────┘...

各类算子在 prefill 阶段的并行性

2.1 Embedding / RoPE / Position encoding

以最简单的 embedding 为例:

python# ids: [B, L]x = embedding_table[ids]  # [B, L, H]x = x + positional_encoding[:, :L, :]  # 或 RoPE
  • 这一步本质上是对 B * L 个 token 并行查表 / 加法;
  • 是一个典型的“大批量元素级运算(element-wise op)”。

对 decode 来说,每步只对 B * 1 个 token 做同样的操作,单次并行元素个数少很多。

2.2 LayerNorm / RMSNorm

prefill 时,每层的 LN 通常这样:

pythonx_norm = layernorm(x)  # x: [B, L, H]
  • LN 的标准实现是对最后一维 H 做归一化;
  • 但对 [B, L, H] 来说,其实等价于对 B * L 个长度为 H 的向量并行做同样的操作;
  • 改成 decode 时,每步只对 B * 1 个向量做 LN。

也就是说:

  • prefill:一次 LN seeing B * L 个 token;
  • decode:一次 LN seeing B 个 token;

算子内部的并行粒度立刻差了 L 倍。

2.3 Attention 本身的差异

这里有两个层次:

  1. prefill 的 Q 只用最后一个位置

    (高效实现)

  2. 但 K/V 是对全序列并行生成并缓存

对于 K/V

python# prefillK = x @ W_K     # [B, L, H] -> [B, L, H]V = x @ W_V     # 同上# 写入 KV cache,一次写 L 个位置
  • 这仍然是用前面说的大 GEMM;
  • 再加一次“大块连续写”到 KV cache,非常利于带宽利用。

对于 Q + Attention(prefill 中只算最后一个 token 的):

pythonq_last = Q[:, -1:, :]        # [B, 1, H]scores = q_last @ K.transpose(-2, -1) / sqrt(d)  # [B, 1, L]attn   = softmax(scores, dim=-1)                 # [B, 1, L]out    = attn @ V                                # [B, 1, H]
  • 这里计算量其实和 decode 中每一步的一次 attention 非常接近(都是 O(L * H) 级别);
  • 并行度主要来自 batch 维和 head 维。

prefill 的高并行,真正主要在「K/V+线性+FFN」这块。

读者福利:如果大家对大模型感兴趣,这套大模型学习资料一定对你有用

对于0基础小白入门:

如果你是零基础小白,想快速入门大模型是可以考虑的。

一方面是学习时间相对较短,学习内容更全面更集中。
二方面是可以根据这些资料规划好学习计划和方向。

作为一名老互联网人,看着AI越来越火,也总想为大家做点啥。干脆把我这几年整理的AI大模型干货全拿出来了。
包括入门指南、学习路径图、精选书籍、视频课,还有我录的一些实战讲解。全部免费,不搞虚的。
学习从来都是自己的事,我能做的就是帮你把路铺平一点。资料都放在下面了,有需要的直接拿,能用到多少就看你自己了。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以点击文章最下方的VX名片免费领取【保真100%】
在这里插入图片描述

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐