DeepSeek-V3源码剖析:实现基于 PyTorch 的分布式 Transformer 文本生成模型
这段代码可以拆分为以下几个核心模块进行解析:作用& :用于读取环境变量和 JSON 配置文件。:解析命令行参数。:定义列表类型的输入和输出。& :用于加载模型、处理张量运算。支持分布式训练(如多 GPU 计算)。:加载分词器,将文本转换为 token。:加载模型参数(比传统的更安全)。:自定义的 Transformer 模型及其配置。2. 采样函数作用温度控制:通过影响采样随机性。高 → 采样更随
·
import os
import json
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
def sample(logits, temperature: float = 1.0):
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1)
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
@torch.inference_mode()
def generate(
model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Args:
model (Transformer): The transformer model used for token generation.
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
max_new_tokens (int): The maximum number of new tokens to generate.
eos_id (int): The end-of-sequence token ID.
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
"""
Main function to load the model and perform interactive or batch text generation.
Args:
ckpt_path (str): Path to the model checkpoint directory.
config (str): Path to the model configuration file.
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
if interactive:
messages = []
while True:
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
if __name__ == "__main__":
"""
Command-line interface for distributed text generation.
Arguments:
--ckpt-path (str): Path to the model checkpoint directory.
--config (str): Path to the model configuration file.
--input-file (str, optional): File containing prompts for batch processing.
--interactive (bool, optional): Enable interactive mode for generating text.
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
--temperature (float, optional): Temperature for sampling. Defaults to 0.2.
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
这段代码可以拆分为以下几个核心模块进行解析:
1. 依赖导入
import os
import json
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
作用
os&json:用于读取环境变量和 JSON 配置文件。argparse:解析命令行参数。typing.List:定义列表类型的输入和输出。torch&torch.distributed:torch用于加载模型、处理张量运算。torch.distributed支持分布式训练(如多 GPU 计算)。
transformers.AutoTokenizer:加载分词器,将文本转换为 token。safetensors.torch.load_model:加载模型参数(比传统的.pt更安全)。model.Transformer, ModelArgs:自定义的 Transformer 模型及其配置。
2. 采样函数
def sample(logits, temperature: float = 1.0):
"""
通过温度参数对 logits 进行采样,以生成下一个 token。
Args:
logits (torch.Tensor): 预测的 logits。
temperature (float): 采样温度,决定随机性。
Returns:
torch.Tensor: 选出的 token。
"""
logits = logits / max(temperature, 1e-5) # 避免除零
probs = torch.softmax(logits, dim=-1) # 转换为概率分布
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
作用
- 温度控制:通过
temperature影响采样随机性。temperature高 → 采样更随机(创造性更强)。temperature低 → 采样更确定(趋向于选取最大概率)。
softmax转换 logits 为概率分布。- 采用 Gumbel-Softmax 采样(避免直接
argmax选择最大概率 token)。
3. 文本生成函数
@torch.inference_mode()
def generate(
model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
) -> List[List[int]]:
作用
- 根据
prompt_tokens生成max_new_tokens数量的 token。 - 参数说明
model:Transformer 模型。prompt_tokens:输入的 token 列表(批量)。max_new_tokens:最大新生成 token 数量。eos_id:终止 token ID(例如</s>)。temperature:控制采样的随机性。
3.1. 预处理
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"超出模型最大长度 {model.max_seq_len}"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
- 计算 prompt 长度,确保不超过
model.max_seq_len。 - 生成 填充 token 张量(-1 表示未填充)。
- 将
prompt_tokens复制到tokens中。
3.2. 生成过程
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
prev_pos:当前生成的起始位置。finished:跟踪哪些序列已生成eos_id。prompt_mask:指示哪些 token 是 prompt(不可修改)。
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1) # 直接选最大概率 token
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
- 依次生成 token:
- 获取
logits并采样。 - 确保 prompt 部分不被修改。
- 如果遇到
eos_id,标记finished。 - 提前终止:所有序列都完成时,结束循环。
- 获取
4. 主函数
def main(ckpt_path, config, input_file="", interactive=True, max_new_tokens=100, temperature=1.0):
作用
- 读取模型,支持 交互式 和 批量处理。
4.1. 分布式计算
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.manual_seed(965)
- 分布式训练
WORLD_SIZE > 1时,初始化NCCL后端进行分布式计算。rank != 0的进程屏蔽print(只让rank=0负责打印)。bfloat16:减少显存占用,加速计算。
4.2. 加载模型
with open(config) as f:
args = ModelArgs(**json.load(f))
with torch.device("cuda"):
model = Transformer(args)
- 解析
config并实例化Transformer模型。
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
- 通过
safetensors加载模型参数。
4.3. 交互式模式
if interactive:
messages = []
while True:
prompt = input(">>> ")
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
- 支持对话,并记住历史消息。
apply_chat_template:将messages转换为输入格式。- 生成并打印
completion。
5. 入口点
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
...
args = parser.parse_args()
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
- 解析命令行参数,调用
main()。
总结
这段代码的核心是:
- 分布式环境初始化
- 加载模型
- 提供交互和批量处理
- 通过 Transformer 生成文本
这个 Python 脚本是一个用于文本生成的程序,结合了 PyTorch 和 Hugging Face Transformers 库,并支持分布式计算。它的主要功能是:
-
加载 Transformer 模型
- 使用
safetensors加载权重。 - 通过
AutoTokenizer处理文本输入和输出。
- 使用
-
生成文本
generate函数使用 Transformer 模型,根据输入的 token 生成新的 token。- 支持温度采样 (
temperature) 控制随机性,支持终止符 (eos_id) 结束生成。
-
支持交互式模式和批量模式
- 交互式模式下,用户可以输入文本,模型进行续写,并可清空对话历史。
- 批量模式下,程序从
input_file读取多个输入,生成对应的文本。
-
分布式计算
- 支持 PyTorch
torch.distributed进行多 GPU 训练。 - 通过
os.getenv("WORLD_SIZE")获取集群大小,初始化nccl后端。
- 支持 PyTorch
-
命令行参数
- 通过
argparse解析参数,如--ckpt-path指定模型路径,--config指定配置文件等。
- 通过
可能的问题:
generate方法可能会因prev_pos的更新方式导致索引超出界限。tokens[:, cur_pos] = next_token可能在cur_pos超出tokens.shape[1]时引发索引错误。logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)假设forward方法能处理prev_pos这种输入,但具体实现未知,可能需要调整。
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐

所有评论(0)