import os
import json
from argparse import ArgumentParser
from typing import List

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model

from model import Transformer, ModelArgs


def sample(logits, temperature: float = 1.0):
    """
    Samples a token from the logits using temperature scaling.

    Args:
        logits (torch.Tensor): The logits tensor for token predictions.
        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token.
    """
    logits = logits / max(temperature, 1e-5)
    probs = torch.softmax(logits, dim=-1)
    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)


@torch.inference_mode()
def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """
    Generates new tokens based on the given prompt tokens using the specified model.

    Args:
        model (Transformer): The transformer model used for token generation.
        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_id (int): The end-of-sequence token ID.
        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.

    Returns:
        List[List[int]]: A list of lists containing the generated tokens for each sequence.
    """
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
    prev_pos = 0
    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
    prompt_mask = tokens != -1
    for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
        prev_pos = cur_pos
        if finished.all():
            break
    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
        if eos_id in toks:
            toks = toks[:toks.index(eos_id)]
        completion_tokens.append(toks)
    return completion_tokens


def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """
    Main function to load the model and perform interactive or batch text generation.

    Args:
        ckpt_path (str): Path to the model checkpoint directory.
        config (str): Path to the model configuration file.
        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if world_size > 1:
        dist.init_process_group("nccl")
    global print
    if rank != 0:
        print = lambda *_, **__: None
    torch.cuda.set_device(local_rank)
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    print(args)
    with torch.device("cuda"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))

    if interactive:
        messages = []
        while True:
            if world_size == 1:
                prompt = input(">>> ")
            elif rank == 0:
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)
                prompt = objects[0]
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                messages.clear()
                continue
            messages.append({"role": "user", "content": prompt})
            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
    else:
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    """
    Command-line interface for distributed text generation.

    Arguments:
        --ckpt-path (str): Path to the model checkpoint directory.
        --config (str): Path to the model configuration file.
        --input-file (str, optional): File containing prompts for batch processing.
        --interactive (bool, optional): Enable interactive mode for generating text.
        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.

    Raises:
        AssertionError: If neither input-file nor interactive mode is specified.
    """
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()
    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)

这段代码可以拆分为以下几个核心模块进行解析:


1. 依赖导入

import os
import json
from argparse import ArgumentParser
from typing import List

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model

from model import Transformer, ModelArgs

作用

  • os & json:用于读取环境变量和 JSON 配置文件。
  • argparse:解析命令行参数。
  • typing.List:定义列表类型的输入和输出。
  • torch & torch.distributed
    • torch 用于加载模型、处理张量运算。
    • torch.distributed 支持分布式训练(如多 GPU 计算)。
  • transformers.AutoTokenizer:加载分词器,将文本转换为 token。
  • safetensors.torch.load_model:加载模型参数(比传统的 .pt 更安全)。
  • model.Transformer, ModelArgs:自定义的 Transformer 模型及其配置。

2. 采样函数

def sample(logits, temperature: float = 1.0):
    """
    通过温度参数对 logits 进行采样,以生成下一个 token。

    Args:
        logits (torch.Tensor): 预测的 logits。
        temperature (float): 采样温度,决定随机性。

    Returns:
        torch.Tensor: 选出的 token。
    """
    logits = logits / max(temperature, 1e-5)  # 避免除零
    probs = torch.softmax(logits, dim=-1)  # 转换为概率分布
    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)

作用

  • 温度控制:通过 temperature 影响采样随机性。
    • temperature 高 → 采样更随机(创造性更强)。
    • temperature 低 → 采样更确定(趋向于选取最大概率)。
  • softmax 转换 logits 为概率分布。
  • 采用 Gumbel-Softmax 采样(避免直接 argmax 选择最大概率 token)。

3. 文本生成函数

@torch.inference_mode()
def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:

作用

  • 根据 prompt_tokens 生成 max_new_tokens 数量的 token。
  • 参数说明
    • model:Transformer 模型。
    • prompt_tokens:输入的 token 列表(批量)。
    • max_new_tokens:最大新生成 token 数量。
    • eos_id:终止 token ID(例如 </s>)。
    • temperature:控制采样的随机性。

3.1. 预处理

prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"超出模型最大长度 {model.max_seq_len}"

total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
    tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
  • 计算 prompt 长度,确保不超过 model.max_seq_len
  • 生成 填充 token 张量(-1 表示未填充)。
  • prompt_tokens 复制到 tokens 中。

3.2. 生成过程

prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
  • prev_pos:当前生成的起始位置。
  • finished:跟踪哪些序列已生成 eos_id
  • prompt_mask:指示哪些 token 是 prompt(不可修改)。

for cur_pos in range(min(prompt_lens), total_len):
    logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

    if temperature > 0:
        next_token = sample(logits, temperature)
    else:
        next_token = logits.argmax(dim=-1)  # 直接选最大概率 token

    next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
    tokens[:, cur_pos] = next_token

    finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)

    prev_pos = cur_pos

    if finished.all():
        break
  • 依次生成 token:
    • 获取 logits 并采样。
    • 确保 prompt 部分不被修改
    • 如果遇到 eos_id,标记 finished
    • 提前终止:所有序列都完成时,结束循环。

4. 主函数

def main(ckpt_path, config, input_file="", interactive=True, max_new_tokens=100, temperature=1.0):

作用

  • 读取模型,支持 交互式批量处理

4.1. 分布式计算

world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))

if world_size > 1:
    dist.init_process_group("nccl")

torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.manual_seed(965)
  • 分布式训练
    • WORLD_SIZE > 1 时,初始化 NCCL 后端进行分布式计算。
    • rank != 0 的进程屏蔽 print(只让 rank=0 负责打印)。
    • bfloat16:减少显存占用,加速计算。

4.2. 加载模型

with open(config) as f:
    args = ModelArgs(**json.load(f))

with torch.device("cuda"):
    model = Transformer(args)
  • 解析 config 并实例化 Transformer 模型。
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
  • 通过 safetensors 加载模型参数。

4.3. 交互式模式

if interactive:
    messages = []
    while True:
        prompt = input(">>> ")
        if prompt == "/exit":
            break
        elif prompt == "/clear":
            messages.clear()
            continue
        messages.append({"role": "user", "content": prompt})
        prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
        completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
        print(completion)
        messages.append({"role": "assistant", "content": completion})
  • 支持对话,并记住历史消息。
  • apply_chat_template:将 messages 转换为输入格式。
  • 生成并打印 completion

5. 入口点

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    ...
    args = parser.parse_args()
    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
  • 解析命令行参数,调用 main()

总结

这段代码的核心是:

  1. 分布式环境初始化
  2. 加载模型
  3. 提供交互和批量处理
  4. 通过 Transformer 生成文本

这个 Python 脚本是一个用于文本生成的程序,结合了 PyTorch 和 Hugging Face Transformers 库,并支持分布式计算。它的主要功能是:

  1. 加载 Transformer 模型

    • 使用 safetensors 加载权重。
    • 通过 AutoTokenizer 处理文本输入和输出。
  2. 生成文本

    • generate 函数使用 Transformer 模型,根据输入的 token 生成新的 token。
    • 支持温度采样 (temperature) 控制随机性,支持终止符 (eos_id) 结束生成。
  3. 支持交互式模式和批量模式

    • 交互式模式下,用户可以输入文本,模型进行续写,并可清空对话历史。
    • 批量模式下,程序从 input_file 读取多个输入,生成对应的文本。
  4. 分布式计算

    • 支持 PyTorch torch.distributed 进行多 GPU 训练。
    • 通过 os.getenv("WORLD_SIZE") 获取集群大小,初始化 nccl 后端。
  5. 命令行参数

    • 通过 argparse 解析参数,如 --ckpt-path 指定模型路径,--config 指定配置文件等。

可能的问题:

  • generate 方法可能会因 prev_pos 的更新方式导致索引超出界限。
  • tokens[:, cur_pos] = next_token 可能在 cur_pos 超出 tokens.shape[1] 时引发索引错误。
  • logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) 假设 forward 方法能处理 prev_pos 这种输入,但具体实现未知,可能需要调整。
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐