Preface

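In the previous post, "Stanford University | CS336 | Language Modeling from Scratch | Spring 2025 | Notes | Assignment 1: BPE Tokenizer", we went over the requirements of the BPE tokenizer part of the assignment. Now let's look at how to implement it. This post records my BPE tokenizer implementation for CS336 Assignment 1: Basics, for my own reference only 😄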

Note: I did not follow the from-scratch spirit of the course; almost all of the code was written by ChatGPT.

Assignment 1: https://github.com/stanford-cs336/assignment1-basics

reference: https://chatgpt.com/

reference: https://github.com/donglinkang2021/cs336-assignment1-basics

reference: https://github.com/Louisym/Stanford-CS336-spring25

1. Problem (unicode1): Understanding Unicode (1 point)

(a) What Unicode character does chr(0) return?

Deliverable: chr(0) returns the Unicode null character (NUL, U+0000).

(b) How does this character's string representation (__repr__()) differ from its printed representation (print)?

>>> chr(0)
'\x00'
>>> print(chr(0))

>>> 

Deliverable: Its repr shows the visible escape form '\x00', whereas print(chr(0)) usually appears to print nothing, because NUL is an invisible control character.

(c) What happens when this character occurs in text? You can try the following in the Python interpreter and see whether the results match your intuition:

>>> chr(0)
>>> print(chr(0))
>>> "this is a test" + chr(0) + "string"
>>> print("this is a test" + chr(0) + "string")
>>> chr(0)
'\x00'
>>> print(chr(0))

>>> "this is a test" + chr(0) + "string"
'this is a test\x00string'
>>> print("this is a test" + chr(0) + "string")
this is a teststring

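Deliverable: When it occurs in text, it is a real character of the string and takes part in concatenation and length computation, but many terminals and display environments render it as nothing, so print("this is a test" + chr(0) + "string") usually looks as if the two parts were simply joined together.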
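A quick check in the same spirit shows that the NUL character is really there even when print hides it (a minimal sketch; the variable name s is only for illustration):

```python
s = "this is a test" + chr(0) + "string"

# NUL is a real element of the string ...
print(len(s))           # 14 + 1 + 6 = 21
print(s.index(chr(0)))  # it sits right after "this is a test"
print(repr(s))          # repr makes it visible as \x00
# ... and it survives UTF-8 encoding as the single byte 0x00
print(s.encode("utf-8")[14])
```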

2. Problem (unicode2): Unicode Encodings (3 points)

(a) Why is UTF-8 preferred over UTF-16 / UTF-32 for tokenizer training?

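Deliverable: UTF-8 is a variable-length encoding (1 to 4 bytes per code point) that encodes ASCII characters in a single byte, so it is compact for English and ASCII-heavy text while still covering every Unicode character. UTF-16 / UTF-32 use at least 2 / exactly 4 bytes even for basic characters, padding ASCII text with zero bytes; this redundancy lengthens the byte sequences a byte-level tokenizer has to process and increases compute.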
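To make the size argument concrete, here is a small sketch comparing the encoded length of an ASCII string in the three encodings. It assumes the little-endian codec names utf-16-le / utf-32-le, which Python provides and which, unlike plain utf-16 / utf-32, do not prepend a byte-order mark:

```python
text = "hello"  # ASCII-only sample

sizes = {}
zero_bytes = {}
for enc in ("utf-8", "utf-16-le", "utf-32-le"):
    data = text.encode(enc)
    sizes[enc] = len(data)           # total bytes a byte-level tokenizer would see
    zero_bytes[enc] = data.count(0)  # padding bytes that carry no information here

print(sizes)       # {'utf-8': 5, 'utf-16-le': 10, 'utf-32-le': 20}
print(zero_bytes)  # {'utf-8': 0, 'utf-16-le': 5, 'utf-32-le': 15}
```

For this input, two thirds of the UTF-16 output and three quarters of the UTF-32 output are zero bytes.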

(b) The following UTF-8 decoding function is incorrect. Why? Give an example input that produces an incorrect result, and explain the cause.

def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

>>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))
'hello'
>>> decode_utf8_bytes_to_str_wrong("é".encode("utf-8"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    decode_utf8_bytes_to_str_wrong("é".encode("utf-8"))
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
  File "<stdin>", line 2, in decode_utf8_bytes_to_str_wrong
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])
                    ~~~~~~~~~~~~~~~~~^^^^^^^^^
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xc3 in position 0: unexpected end of data
>>> "é".encode("utf-8")
b'\xc3\xa9'

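Deliverable: The UTF-8 encoding of "é" is a multi-byte sequence (0xC3 0xA9). The function decodes each byte as if it were a complete UTF-8 character on its own, so it fails on 0xC3, a lead byte that must be followed by a continuation byte, and raises UnicodeDecodeError (unexpected end of data). UTF-8 has to be decoded over complete byte sequences.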
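For contrast, here is a sketch of two correct alternatives (the function names are hypothetical): decode the whole sequence at once, or, if bytes really must arrive one at a time, use the standard library's incremental decoder from codecs, which buffers an incomplete multi-byte sequence until it is complete:

```python
import codecs

def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    # Decode the whole sequence at once so multi-byte characters stay intact
    return bytestring.decode("utf-8")

def decode_utf8_byte_by_byte(bytestring: bytes) -> str:
    # The incremental decoder keeps an incomplete lead byte (e.g. 0xC3) in its
    # internal buffer until the continuation byte arrives, instead of failing
    decoder = codecs.getincrementaldecoder("utf-8")()
    parts = [decoder.decode(bytes([b])) for b in bytestring]
    parts.append(decoder.decode(b"", final=True))  # flush; errors if input was truncated
    return "".join(parts)

data = "héllo".encode("utf-8")
print(decode_utf8_bytes_to_str(data))  # héllo
print(decode_utf8_byte_by_byte(data))  # héllo
```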

(c) Give a 2-byte sequence that does not decode to any Unicode character.

b"\xC0\xAF"
>>> b"\xC0\xAF".decode("utf-8")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    b"\xC0\xAF".decode("utf-8")
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xc0 in position 0: invalid start byte

Deliverable: This is an overlong UTF-8 encoding, an ill-formed sequence explicitly forbidden by the standard, so it cannot be decoded to any Unicode character.


Q & A

Q: What exactly is an overlong UTF-8 encoding, and why is it forbidden?

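A: See the explanation on [Wikipedia]. In short, an overlong encoding uses more bytes than necessary to encode a Unicode code point, whereas the UTF-8 standard requires each code point to be encoded in its shortest form. For example, b"\xC0\xAF" in binary is 11000000 10101111; combining the payload bits (the low 5 bits of the first byte and the low 6 bits of the second) gives code point 0x2F (decimal 47), the character /. But / should be encoded as the single byte 0x2F, and here it has been stretched to 2 bytes. Overlong encodings are forbidden mainly for security, e.g. to stop 0xC0 0xAF from slipping past a filter that looks for /, and so that every character has exactly one valid encoding with no ambiguity. Consequently, any conforming UTF-8 decoder must reject an overlong sequence as ill-formed; Python raises UnicodeDecodeError, as shown above.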
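The bit arithmetic in the answer can be checked with a small hand decoder for the 2-byte pattern 110xxxxx 10yyyyyy. This is only an illustrative sketch (the helper names are hypothetical and it handles the 2-byte form only); real code should simply call bytes.decode and let it reject ill-formed input:

```python
def decode_two_byte_sequence(b0: int, b1: int) -> int:
    # 2-byte UTF-8 layout: 110xxxxx 10yyyyyy -> code point with bits xxxxxyyyyyy
    if b0 >> 5 != 0b110 or b1 >> 6 != 0b10:
        raise ValueError("not a 2-byte UTF-8 bit pattern")
    return ((b0 & 0x1F) << 6) | (b1 & 0x3F)

def is_overlong_two_byte(b0: int, b1: int) -> bool:
    # Code points below 0x80 fit in one byte, so any 2-byte form of them is overlong
    return decode_two_byte_sequence(b0, b1) < 0x80

cp = decode_two_byte_sequence(0xC0, 0xAF)
print(hex(cp), chr(cp), is_overlong_two_byte(0xC0, 0xAF))  # 0x2f / True
cp = decode_two_byte_sequence(0xC3, 0xA9)
print(hex(cp), chr(cp), is_overlong_two_byte(0xC3, 0xA9))  # 0xe9 é False
```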


3. Problem (train_bpe): BPE Tokenizer Training (15 points)

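Deliverable: Write a function that, given the path to an input text file, trains a byte-level BPE tokenizer. Your BPE training function should support (at least) the following input parameters: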

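  • input_path: str: path to a text file containing the BPE tokenizer training data
  • vocab_size: int: a positive integer defining the maximum final vocabulary size, including the initial byte vocabulary, vocabulary items produced by merging, and any special tokens
  • special_tokens: list[str]: a list of strings to add to the vocabulary; these special tokens do not otherwise affect BPE training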

Your BPE training function should return the resulting vocabulary and merges:

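  • vocab: dict[int, bytes]: the tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes)
  • merges: list[tuple[bytes, bytes]]: the BPE merges produced during training; each item is a tuple of bytes (<token1>, <token2>) meaning that <token1> was merged with <token2>, and the merges are ordered by order of creation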

To test your BPE training function against the provided tests, first implement the test adapter [adapters.run_train_bpe], then run:

uv run pytest tests/test_train_bpe.py

Your implementation should pass all tests.

3.1 Basic implementation

The implementation of train_bpe.py is as follows:

import os

# GPT-2-style regex pre-tokenizer (requires third-party `regex`)
GPT2_PRETOKENIZE_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT2_REGEX = None

def get_gpt2_regex():
    global GPT2_REGEX
    if GPT2_REGEX is None:
        import regex as re

        GPT2_REGEX = re.compile(GPT2_PRETOKENIZE_PATTERN)
    return GPT2_REGEX

def train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    if vocab_size <= 0:
        raise ValueError("vocab_size must be a positive integer")
    if vocab_size < 256 + len(special_tokens):
        raise ValueError("vocab_size too small: must be >= 256 + len(special_tokens)")

    # ---- vocab init: 256 single-byte tokens + special tokens ----
    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    next_id = 256
    for token in special_tokens:
        vocab[next_id] = token.encode("utf-8")
        next_id += 1
    
    # ---- pre-tokenization + counting (special tokens are hard boundaries) ----
    rx = get_gpt2_regex()
    word_freq: dict[tuple[bytes, ...], int] = {}

    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    if not text:
        return vocab, []

    # split by each special token; drop the special tokens from training stats
    spans = [text]
    for s_tok in special_tokens:
        new_spans: list[str] = []
        for sp in spans:
            if sp:
                new_spans.extend(sp.split(s_tok))
        spans = new_spans
    
    for sp in spans:
        if not sp:
            continue
        for m in rx.finditer(sp):
            piece = m.group(0)
            if not piece:
                continue
            bts = piece.encode("utf-8")
            key = tuple(bytes([b]) for b in bts)
            word_freq[key] = word_freq.get(key, 0) + 1
    
    # ---- BPE merges ----
    merges: list[tuple[bytes, bytes]] = []
    while next_id < vocab_size:
        # count pair frequencies
        pair_counts: dict[tuple[bytes, bytes], int] = {}
        for word, freq in word_freq.items():
            if len(word) < 2:
                continue
            prev = word[0]
            for cur in word[1:]:
                pair = (prev, cur)
                pair_counts[pair] = pair_counts.get(pair, 0) + freq
                prev = cur
        
        if not pair_counts:
            break
    
        # choose most frequent; tie-break by lexicographically largest pair
        (a, b), best_count = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))
        if best_count <= 0:
            break
        
        new_token = a + b
        merges.append((a, b))
        vocab[next_id] = new_token
        next_id += 1

        # replace occurrences of (a,b) in every word
        new_word_freq: dict[tuple[bytes, ...], int] = {}
        for word, freq in word_freq.items():
            if len(word) < 2:
                new_word_freq[word] = new_word_freq.get(word, 0) + freq
                continue
            
            merged: list[bytes] = []
            i = 0
            L = len(word)
            while i < L:
                if i < L - 1 and word[i] == a and word[i + 1] == b:
                    merged.append(new_token)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            
            key = tuple(merged)
            new_word_freq[key] = new_word_freq.get(key, 0) + freq
        
        word_freq = new_word_freq
    
    return vocab, merges

The test adapter [adapters.run_train_bpe] is implemented as follows:

def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, run train a BPE tokenizer and
    output its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            These strings will never be split into multiple tokens, and will always be
            kept as a single token. If these special tokens occur in the `input_path`,
            they are treated as any other string.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.
    """
    from cs336_basics.train_bpe import train_bpe
    
    return train_bpe(
        input_path=input_path,
        vocab_size=vocab_size,
        special_tokens=special_tokens
    )

Commands to run:

# 1. install uv (via pip)
conda activate base
pip install uv

# optional: use a PyPI mirror if the default index is slow
# uv sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple

# 2. run
uv run pytest tests/test_train_bpe.py

Running uv run pytest tests/test_train_bpe.py produces the following output:

[Screenshot: output of uv run pytest tests/test_train_bpe.py]

Now let's briefly walk through how the code works:

step 1. vocab init

vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
next_id = 256
for token in special_tokens:
    vocab[next_id] = token.encode("utf-8")
    next_id += 1

vocab is a dictionary: its key is the token id (an integer) and its value is the byte sequence for that token. It starts out holding the base vocabulary, i.e. all 256 possible single bytes.

next_id is the id to assign to the next new token. Since ids 0 to 255 are already taken, it starts at 256: the special tokens passed in as arguments are added first, followed by the tokens produced by merges.
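A small sketch of the same initialization, wrapped in a hypothetical helper init_vocab, also shows how many merges the while next_id < vocab_size loop can learn at most: vocab_size minus 256 minus the number of special tokens (vocab_size = 500 below is an arbitrary example value):

```python
def init_vocab(special_tokens: list[str]) -> tuple[dict[int, bytes], int]:
    vocab = {i: bytes([i]) for i in range(256)}  # ids 0..255: every possible single byte
    next_id = 256
    for token in special_tokens:
        vocab[next_id] = token.encode("utf-8")   # special tokens take the next ids
        next_id += 1
    return vocab, next_id

vocab, next_id = init_vocab(["<|endoftext|>"])
vocab_size = 500
# The training loop runs "while next_id < vocab_size", so at most this many merges are learned
max_merges = vocab_size - next_id
print(vocab[65], vocab[256], next_id, max_merges)  # b'A' b'<|endoftext|>' 257 243
```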

step 2. pre-tokenization

rx = get_gpt2_regex()
word_freq: dict[tuple[bytes, ...], int] = {}

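The original BPE implementation pre-tokenized by simply splitting on whitespace. Here we use a regex-based pre-tokenizer (the one used by GPT-2): get_gpt2_regex() returns the compiled GPT-2 pre-tokenization regex, which cuts a large body of text into pieces. Pair statistics are then only collected within each piece, which reduces the cost of the BPE merge stage and keeps merges from crossing piece boundaries.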

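word_freq is a dictionary: its key is a tuple representing a pre-tokenized piece as a token sequence, where each element of the tuple is one token; its value is the number of times that piece occurs in the corpus. In the end, this dictionary holds the token sequence of every piece seen in the training corpus, together with its frequency.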

Note: when training BPE, we do not count pairs on the raw string; we count them on these word_freq keys (token sequences), weighted by their frequencies.

with open(input_path, "r", encoding="utf-8") as f:
    text = f.read()
if not text:
    return vocab, []

Next, f.read() reads the entire file into text; an empty file simply returns the initial vocab and no merges.

spans = [text]
for s_tok in special_tokens:
    new_spans = []
    for sp in spans:
        if sp:
            new_spans.extend(sp.split(s_tok))
    spans = new_spans

Then the text is split on the special tokens. spans is a list of strings: the ordinary text segments obtained by cutting the original text at every special token.

For example, if text = "A<|endoftext|>B" and the special token is <|endoftext|>, then after splitting, spans = ["A", "B"]; the special token itself is simply dropped and does not enter the later statistics.

On the training file used here, splitting by the special token yields 6488 text segments; after this loop, spans holds exactly these segments.
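An alternative way to do the same split, assuming only the standard-library re module (split_on_special_tokens is a hypothetical helper, not part of the assignment code): escape each special token, since characters like | are regex metacharacters, join them into one alternation, and split once. Sorting longest-first makes the alternation prefer the longer token when one special token is a prefix of another; for training statistics this makes no difference because the tokens are dropped either way, but the same pattern can be reused when special tokens must be kept during encoding.

```python
import re

def split_on_special_tokens(text: str, special_tokens: list[str]) -> list[str]:
    if not special_tokens:
        return [text] if text else []
    # Longest first, so a token that is a prefix of another does not win the alternation
    ordered = sorted(special_tokens, key=len, reverse=True)
    pattern = "|".join(re.escape(tok) for tok in ordered)  # "|" inside tokens is escaped
    # Drop empty strings, mirroring the "if not sp: continue" checks in train_bpe
    return [span for span in re.split(pattern, text) if span]

print(split_on_special_tokens("A<|endoftext|>B", ["<|endoftext|>"]))  # ['A', 'B']
```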

for sp in spans:
    if not sp:
        continue
    for m in rx.finditer(sp):
        piece = m.group(0)
        if not piece:
            continue
        bts = piece.encode("utf-8")
        key = tuple(bytes([b]) for b in bts)
        word_freq[key] = word_freq.get(key, 0) + 1

Finally, rx.finditer(sp) performs regex pre-tokenization on each span, where:

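  • sp is an ordinary segment that contains no special tokens
  • rx.finditer(sp) finds successive matches from the start of the span to its end
  • each m is a match object, and m.group(0) is the matched string piece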

As a simple example, if sp = " hello, world!\n\n", it would be cut into roughly the pieces " hello", ",", " world", "!", "\n\n". The exact split depends on the regex rules, but it looks roughly like this.
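The GPT-2 pattern relies on \p{L} / \p{N}, which need the third-party regex package. To experiment without it, here is an approximation using only the standard-library re module, where [^\W\d_] stands in for \p{L}, \d for \p{N}, and (?:[^\s\w]|_) for [^\s\p{L}\p{N}]. This is an assumption-laden stand-in, not the pattern used in train_bpe.py, and it can differ on some Unicode categories (for example \d only covers decimal digits, while \p{N} also includes other numerals):

```python
import re

# Stdlib approximation of the GPT-2 pre-tokenization pattern (not identical to the regex version)
APPROX_GPT2_PATTERN = re.compile(
    r"""'(?:[sdmt]|ll|ve|re)| ?[^\W\d_]+| ?\d+| ?(?:[^\s\w]|_)+|\s+(?!\S)|\s+"""
)

def pretokenize(text: str) -> list[str]:
    return [m.group(0) for m in APPROX_GPT2_PATTERN.finditer(text)]

print(pretokenize(" hello, world!\n\n"))  # [' hello', ',', ' world', '!', '\n\n']
print(pretokenize("I'm here"))            # ['I', "'m", ' here']
print(pretokenize("a  b"))                # ['a', ' ', ' b']
```

The last example shows what \s+(?!\S) is for: of the two spaces before b, the first becomes its own piece and the second stays attached to the word, so " b" produces the same piece no matter how much whitespace precedes it.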

Each string piece is turned into bytes with encode; for example, " hello".encode("utf-8") gives b" hello".

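key = tuple(bytes([b]) for b in bts) splits the byte string into a sequence of single-byte tokens. bts is the byte string from the previous step, e.g. b" hello", which becomes (b' ', b'h', b'e', b'l', b'l', b'o'). Iterating over a bytes object yields ints, which is why each one is wrapped back into a one-byte bytes object with bytes([b]).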

Finally, we accumulate each piece's count. word_freq.get(key, 0) is a dictionary method for reading a value safely: dict.get(key, default) returns the value for key, or default (here 0) if key is not present.
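A minimal sketch of these two steps on a few hand-written pieces (the piece list and the helper name to_byte_tokens are illustrative). It also shows that a non-ASCII character such as é becomes two single-byte tokens, and that collections.Counter produces the same counts as the dict.get loop:

```python
from collections import Counter

pieces = [" hello", " héllo", " hello"]  # e.g. pieces produced by pre-tokenization

def to_byte_tokens(piece: str) -> tuple[bytes, ...]:
    # Iterating over bytes yields ints, so wrap each one back into a 1-byte bytes object
    return tuple(bytes([b]) for b in piece.encode("utf-8"))

word_freq: dict[tuple[bytes, ...], int] = {}
for piece in pieces:
    key = to_byte_tokens(piece)
    word_freq[key] = word_freq.get(key, 0) + 1

# collections.Counter does the same counting in one line
assert word_freq == dict(Counter(to_byte_tokens(p) for p in pieces))
print(to_byte_tokens(" héllo"))  # (b' ', b'h', b'\xc3', b'\xa9', b'l', b'l', b'o')
```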

OK, at this point we have finished vocabulary initialization and pre-tokenization, and we have the following key data:

vocab: dict[int, bytes]

  • The current vocabulary: a mapping from token_id to token bytes
  • Already contains: the single bytes + special tokens
  • Every later merge adds a new token to it

next_id

  • The id of the next new token (initially 256 + the number of special_tokens)
  • Each merge does vocab[next_id] = new_token; next_id += 1

word_freq: dict[tuple[bytes, ...], int]

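  • A compressed representation of the training data
  • key: the token sequence a piece is represented as
  • value: the number of times this sequence occurs in the corpus
  • Later merges update the token sequences in the keys (replacing (a, b) with new_token)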

The main BPE merge loop needs all of the data above, so let's move on to the core step of the BPE algorithm: BPE merge.

step 3. BPE merge

merges: list[tuple[bytes, bytes]] = []

Initialize the merges list. It records each merge rule (a, b) in order; the tokenizer's encode step later applies the merges in this same order.
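To see why the order matters, here is a simplified sketch of replaying merges in creation order on a single pre-token (apply_merges_in_order and the toy merge list are hypothetical; a full encoder would also map the resulting bytes to ids through the vocab):

```python
def apply_merges_in_order(piece: bytes, merges: list[tuple[bytes, bytes]]) -> list[bytes]:
    tokens = [bytes([b]) for b in piece]  # start from single bytes
    for a, b in merges:                   # creation order = priority order
        out, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
                out.append(a + b)         # merge this occurrence of (a, b)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        tokens = out
    return tokens

merges = [(b"s", b"t"), (b"e", b"st"), (b"o", b"w")]
print(apply_merges_in_order(b"lowest", merges))  # [b'l', b'ow', b'est']
```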

while next_id < vocab_size:

This is the termination condition of the outer while training loop: keep training as long as vocab still has room for a new token.

pair_counts: dict[tuple[bytes, bytes], int] = {}
for word, freq in word_freq.items():
    if len(word) < 2:
        continue
    prev = word[0]
    for cur in word[1:]:
        pair = (prev, cur)
        pair_counts[pair] = pair_counts.get(pair, 0) + freq
        prev = cur

The first for loop counts the frequency of every adjacent token pair, producing pair_counts.

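word is a token sequence that comes from a pre-tokenized piece (or its updated form after some rounds of merging), e.g. (b' ', b'h', b'e', b'l', b'l', b'o'), and freq is the number of times this token sequence occurs in the corpus. For example, if " hello" occurs 50 times, every adjacent pair inside it contributes 50 to the count.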

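pair_counts[(x, y)] is the total, frequency-weighted number of times the token pair (x, y) occurs across the training corpus. For example, if (b'h', b'e') appears in many sequences for a weighted total of 1000 occurrences, then pair_counts[(b'h', b'e')] = 1000.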
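A small worked example of this counting on the classic low / lower / widest / newest toy corpus (the word frequencies are illustrative, and whitespace is left out of the words for readability):

```python
from collections import Counter

def to_tokens(s: str) -> tuple[bytes, ...]:
    return tuple(bytes([b]) for b in s.encode("utf-8"))

# toy corpus: low x5, lower x2, widest x3, newest x6
word_freq = {to_tokens("low"): 5, to_tokens("lower"): 2,
             to_tokens("widest"): 3, to_tokens("newest"): 6}

pair_counts: Counter = Counter()
for word, freq in word_freq.items():
    for pair in zip(word, word[1:]):  # adjacent (prev, cur) pairs
        pair_counts[pair] += freq     # each occurrence is weighted by the word's frequency

print(pair_counts[(b"l", b"o")], pair_counts[(b"e", b"s")], pair_counts[(b"s", b"t")])  # 7 9 9
```

(b'l', b'o') gets 5 from low plus 2 from lower, while (b'e', b's') and (b's', b't') each get 3 from widest plus 6 from newest.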

if not pair_counts:
    break

(a, b), best_count = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))
if best_count <= 0:
    break

If there are no adjacent pairs left to merge, training ends here and we simply break.

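max picks the pair with the highest frequency first; on a tie it compares the pairs themselves (kv[0]). kv[0] is a (bytes, bytes) tuple, which Python compares lexicographically, so among equally frequent pairs the lexicographically greatest one is chosen, as the assignment requires.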
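A tiny check of the tie-breaking rule (the counts are made up for illustration):

```python
pair_counts = {(b"e", b"s"): 9, (b"s", b"t"): 9, (b"l", b"o"): 7}

# Sort key (count, pair): the higher count wins; on equal counts the tuples of bytes are
# compared element by element, so (b"s", b"t") beats (b"e", b"s") because b"s" > b"e"
(a, b), best = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))
print(a, b, best)  # b's' b't' 9

# bytes compare by byte value, so uppercase ASCII sorts before lowercase
print((b"A", b"z") < (b"a", b"b"))  # True
```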

new_token = a + b
merges.append((a, b))
vocab[next_id] = new_token
next_id += 1

Once the most frequent pair has been found, we create the new token, write it into vocab, and record the merge in merges.

# replace occurrences of (a,b) in every word
new_word_freq: dict[tuple[bytes, ...], int] = {}
for word, freq in word_freq.items():
    if len(word) < 2:
        new_word_freq[word] = new_word_freq.get(word, 0) + freq
        continue
    
    merged: list[bytes] = []
    i = 0
    L = len(word)
    while i < L:
        if i < L - 1 and word[i] == a and word[i + 1] == b:
            merged.append(new_token)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    
    key = tuple(merged)
    new_word_freq[key] = new_word_freq.get(key, 0) + freq

word_freq = new_word_freq

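After adding the new token, we still have to replace every occurrence of (a, b) in all words with new_token to obtain the new word_freq. Since (a, b) is now treated as a single token, the next round's pair counts must be computed on the updated token sequences; otherwise the counts would not change, the same pair would be selected again, and training would make no progress.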

We first create an empty new_word_freq and iterate over the old word_freq. Sequences shorter than 2 cannot contain an adjacent pair, so they are copied over unchanged.

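For sequences of length 2 or more, we scan with a while loop: if word[i] and word[i+1] are exactly (a, b), the two tokens are combined into one new_token and the index advances by two; otherwise the current token is appended to merged unchanged and the index advances by one.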

Then merged is converted back into a tuple to serve as the dictionary key, and its frequency is accumulated.

A concrete example makes this clear. Suppose this round selected a = b'h' and b = b'e', so new_token is b'he', and the old word_freq contains the token sequence:

word = (b'h', b'e', b'l', b'l', b'o')
freq = 3

After the while scan we end up with:

merged = [b'he', b'l', b'l', b'o']
key = (b'he', b'l', b'l', b'o')
new_word_freq[key] += 3

In the next round, the piece "hello" is no longer h e l l o but he l l o.
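The same scan can be written as a standalone function (a sketch; this apply_merge takes only a and b and builds new_token itself). The second call shows a subtle case: when a == b, the left-to-right scan merges overlapping matches greedily, so three identical tokens become one merged token plus one leftover:

```python
def apply_merge(word: tuple[bytes, ...], a: bytes, b: bytes) -> tuple[bytes, ...]:
    new_token = a + b
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == a and word[i + 1] == b:
            merged.append(new_token)  # consume both tokens of the pair
            i += 2
        else:
            merged.append(word[i])    # keep the token unchanged
            i += 1
    return tuple(merged)

print(apply_merge((b"h", b"e", b"l", b"l", b"o"), b"h", b"e"))  # (b'he', b'l', b'l', b'o')
print(apply_merge((b"a", b"a", b"a"), b"a", b"a"))             # (b'aa', b'a')
```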

Finally, we return the resulting vocabulary vocab and the merge rules merges:

return vocab, merges

With that, we have a complete implementation of BPE tokenizer training that passes the tests.

Next, let's use some of the tips in the assignment to see whether this implementation can be optimized further.

3.2 Optimization (Parallelizing pre-tokenization)

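Section 2.5 of the assignment handout describes parallelizing pre-tokenization: pre-tokenization is a major performance bottleneck of the whole pipeline, and it can be sped up by parallelizing the code with Python's built-in multiprocessing library.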

The implementation is as follows:

import os
from collections import Counter
from multiprocessing import Pool
from cs336_basics.pretokenization_example import find_chunk_boundaries

# GPT-2-style regex pre-tokenizer (requires third-party `regex`)
GPT2_PRETOKENIZE_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# Each worker process compiles its own regex instance
RX = None

def init_worker():
    global RX
    import regex as re
    RX = re.compile(GPT2_PRETOKENIZE_PATTERN)

def count_word_freq_from_text(text: str, special_tokens: list[str]) -> dict[tuple[bytes, ...], int]:
    """
    Perform pre-tokenization on a text chunk:
      1) Split on special tokens (special tokens do not participate in training stats)
      2) Apply GPT-2 regex pre-tokenization
      3) Encode each piece to UTF-8 bytes and split into single-byte tokens
      4) Count token-sequence frequencies
    """
    if not text:
        return {}

    # split by each special token; drop the special tokens from training stats
    spans = [text]
    for s_tok in special_tokens:
        new_spans: list[str] = []
        for sp in spans:
            if sp:
                new_spans.extend(sp.split(s_tok))
        spans = new_spans
    
    word_freq: dict[tuple[bytes, ...], int] = {}
    for sp in spans:
        if not sp:
            continue
        for m in RX.finditer(sp):
            piece = m.group(0)
            if not piece:
                continue
            bts = piece.encode("utf-8")
            key = tuple(bytes([b]) for b in bts)
            word_freq[key] = word_freq.get(key, 0) + 1
    return word_freq

def process_chunk(args) -> dict[tuple[bytes, ...], int]:
    """
    Worker entry point for processing a single file chunk.
    """
    input_path, start, end, special_tokens = args
    with open(input_path, "rb") as f:
        f.seek(start)
        chunk = f.read(end - start)

    # Decode with errors ignored to avoid UTF-8 boundary issues
    text = chunk.decode("utf-8", errors="ignore")
    return count_word_freq_from_text(text, special_tokens)

def build_word_freq_serial(input_path : str | os.PathLike, special_tokens: list[str]) -> dict[tuple[bytes, ...], int]:
    init_worker()
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return count_word_freq_from_text(text, special_tokens)

def build_word_freq_parallel(
    input_path: str | os.PathLike,
    special_tokens: list[str],
    num_processes: int
) -> dict[tuple[bytes, ...], int]:
    """
    Build word frequency statistics using multiprocessing.
    Chunk boundaries are aligned to special-token boundaries.    
    """
    if num_processes <= 1 or not special_tokens:
        return build_word_freq_serial(input_path, special_tokens)
    
    split_special_token = special_tokens[0].encode("utf-8")  # e.g. b"<|endoftext|>"
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, num_processes, split_special_token)
    
    tasks = [(str(input_path), s, e, special_tokens) for s, e in zip(boundaries[:-1], boundaries[1:])]

    merged = Counter()
    with Pool(processes=num_processes, initializer=init_worker) as pool:
        # merge partial counts as each worker finishes; arrival order does not matter for counting
        for partial in pool.imap_unordered(process_chunk, tasks, chunksize=1):
            merged.update(partial)
    
    return dict(merged)

def train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    *,
    num_processes: int | None = None
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    Train a byte-level BPE tokenizer.

    Returns:
        vocab: dict[int, bytes]
        merges: list[tuple[bytes, bytes]]    
    """
    if vocab_size <= 0:
        raise ValueError("vocab_size must be a positive integer")
    if vocab_size < 256 + len(special_tokens):
        raise ValueError("vocab_size too small: must be >= 256 + len(special_tokens)")

    # ---- vocab init: 256 single-byte tokens + special tokens ----
    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    next_id = 256
    for token in special_tokens:
        vocab[next_id] = token.encode("utf-8")
        next_id += 1
    
    # ---- pre-tokenization + counting (parallelization) ----
    if num_processes is None:
        num_processes = min(8, os.cpu_count() or 1)
    
    file_size = os.path.getsize(input_path)

    # For small files, multiprocessing overhead dominates; use serial
    if num_processes <= 1 or file_size < 1_000_000:  # ~1MB
        word_freq = build_word_freq_serial(input_path, special_tokens)
    else:
        word_freq = build_word_freq_parallel(input_path, special_tokens, num_processes)
    
    if not word_freq:
        return vocab, []

    # ---- BPE merges ----
    merges: list[tuple[bytes, bytes]] = []
    while next_id < vocab_size:
        # count pair frequencies
        pair_counts: dict[tuple[bytes, bytes], int] = {}
        for word, freq in word_freq.items():
            if len(word) < 2:
                continue
            prev = word[0]
            for cur in word[1:]:
                pair = (prev, cur)
                pair_counts[pair] = pair_counts.get(pair, 0) + freq
                prev = cur
        
        if not pair_counts:
            break
    
        # choose most frequent; tie-break by lexicographically largest pair
        (a, b), best_count = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))
        if best_count <= 0:
            break
        
        new_token = a + b
        merges.append((a, b))
        vocab[next_id] = new_token
        next_id += 1

        # replace occurrences of (a,b) in every word
        new_word_freq: dict[tuple[bytes, ...], int] = {}
        for word, freq in word_freq.items():
            if len(word) < 2:
                new_word_freq[word] = new_word_freq.get(word, 0) + freq
                continue
            
            merged: list[bytes] = []
            i = 0
            L = len(word)
            while i < L:
                if i < L - 1 and word[i] == a and word[i + 1] == b:
                    merged.append(new_token)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            
            key = tuple(merged)
            new_word_freq[key] = new_word_freq.get(key, 0) + freq
        
        word_freq = new_word_freq
    
    return vocab, merges

Compared with the serial basic implementation, the parallel pre-tokenization version differs in three main ways:

1. Pre-tokenization counting goes from "serial over the whole text" to "parallel counting per chunk"

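Previously we ran text = f.read() and then did the regex pre-tokenization (rx.finditer) and built word_freq in a single process. Now find_chunk_boundaries splits the corpus into several chunks whose boundaries fall on <|endoftext|>; each chunk runs the same regex pre-tokenization and counting in its own worker, and the partial results are merged into the global word_freq with Counter().update(). This is what parallelizing pre-tokenization means.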

2. The regex compilation strategy changes from "a singleton in the main process" to "each worker compiles its own"

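The old version used get_gpt2_regex() to compile a single GPT2_REGEX in the main process. For multiprocessing safety and performance, the new version uses Pool(initializer=init_worker) so that each child process compiles its own RX once at startup, instead of passing the regex object between processes or recompiling it for every task.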

3. A "use serial for small files" performance branch is added

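Multiprocessing has process startup and IPC overhead, which makes small corpora slower, so we still use the serial path when the file is smaller than about 1 MB.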

The rest of the core logic, such as vocab initialization, pair_counts counting, and BPE merging, is essentially unchanged.
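The chunk-then-merge idea from point 1 can be sketched in memory. This is a hypothetical simplification, not the course's find_chunk_boundaries (which works on an open file): it guesses evenly spaced offsets and moves each one forward to the next special token, with "<|e|>" standing in for <|endoftext|> and a whitespace-split word count standing in for the real per-worker pre-tokenization. A plain map plays the role that pool.imap_unordered plays in the parallel version:

```python
from collections import Counter

def chunk_boundaries(data: bytes, num_chunks: int, split_token: bytes) -> list[int]:
    # Guess evenly spaced offsets, then move each guess forward to the next occurrence
    # of split_token so that no document is cut in half
    size = len(data)
    step = max(1, size // num_chunks)
    bounds = [0]
    for i in range(1, num_chunks):
        pos = data.find(split_token, i * step)
        bounds.append(size if pos == -1 else pos)
    bounds.append(size)
    return sorted(set(bounds))  # duplicates collapse when two guesses land on the same spot

def count_chunk(chunk: bytes) -> Counter:
    # Stand-in for per-worker pre-tokenization: drop the special token, count words
    return Counter(chunk.decode("utf-8").replace("<|e|>", " ").split())

data = b"low low<|e|>lower<|e|>low newest"
bounds = chunk_boundaries(data, 3, b"<|e|>")
chunks = [data[s:e] for s, e in zip(bounds[:-1], bounds[1:])]

total = Counter()
for partial in map(count_chunk, chunks):  # the parallel version gets these from worker processes
    total.update(partial)                 # summing partial counts gives the global counts
print(bounds, dict(total))  # [0, 17, 32] {'low': 3, 'lower': 1, 'newest': 1}
```

Because every chunk after the first starts exactly at a special token, no pre-token straddles two chunks, which is why the merged counts equal the counts over the whole file.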

3.3 Optimization (Optimizing the merging step)

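Section 2.5 of the assignment handout also suggests optimizing the BPE merge step: index the counts of all pairs and update only the affected counts incrementally after each merge, instead of explicitly recounting every pair each time. This can speed up BPE training significantly.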

The implementation is as follows:

import os
from collections import Counter
from multiprocessing import Pool
from cs336_basics.pretokenization_example import find_chunk_boundaries

# GPT-2-style regex pre-tokenizer (requires third-party `regex`)
GPT2_PRETOKENIZE_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# Each worker process compiles its own regex instance
RX = None

def init_worker():
    global RX
    import regex as re
    RX = re.compile(GPT2_PRETOKENIZE_PATTERN)

def count_word_freq_from_text(text: str, special_tokens: list[str]) -> dict[tuple[bytes, ...], int]:
    """
    Perform pre-tokenization on a text chunk:
      1) Split on special tokens (special tokens do not participate in training stats)
      2) Apply GPT-2 regex pre-tokenization
      3) Encode each piece to UTF-8 bytes and split into single-byte tokens
      4) Count token-sequence frequencies
    """
    if not text:
        return {}

    # split by each special token; drop the special tokens from training stats
    spans = [text]
    for s_tok in special_tokens:
        new_spans: list[str] = []
        for sp in spans:
            if sp:
                new_spans.extend(sp.split(s_tok))
        spans = new_spans
    
    word_freq: dict[tuple[bytes, ...], int] = {}
    for sp in spans:
        if not sp:
            continue
        for m in RX.finditer(sp):
            piece = m.group(0)
            if not piece:
                continue
            bts = piece.encode("utf-8")
            key = tuple(bytes([b]) for b in bts)
            word_freq[key] = word_freq.get(key, 0) + 1
    return word_freq

def process_chunk(args) -> dict[tuple[bytes, ...], int]:
    """
    Worker entry point for processing a single file chunk.
    """
    input_path, start, end, special_tokens = args
    with open(input_path, "rb") as f:
        f.seek(start)
        chunk = f.read(end - start)

    # Decode with errors ignored to avoid UTF-8 boundary issues
    text = chunk.decode("utf-8", errors="ignore")
    return count_word_freq_from_text(text, special_tokens)

def build_word_freq_serial(input_path : str | os.PathLike, special_tokens: list[str]) -> dict[tuple[bytes, ...], int]:
    init_worker()
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return count_word_freq_from_text(text, special_tokens)

def build_word_freq_parallel(
    input_path: str | os.PathLike,
    special_tokens: list[str],
    num_processes: int
) -> dict[tuple[bytes, ...], int]:
    """
    Build word frequency statistics using multiprocessing.
    Chunk boundaries are aligned to special-token boundaries.    
    """
    if num_processes <= 1 or not special_tokens:
        return build_word_freq_serial(input_path, special_tokens)
    
    split_special_token = special_tokens[0].encode("utf-8")  # e.g. b"<|endoftext|>"
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, num_processes, split_special_token)
    
    tasks = [(str(input_path), s, e, special_tokens) for s, e in zip(boundaries[:-1], boundaries[1:])]

    merged = Counter()
    with Pool(processes=num_processes, initializer=init_worker) as pool:
        # merge partial counts as each worker finishes; arrival order does not matter for counting
        for partial in pool.imap_unordered(process_chunk, tasks, chunksize=1):
            merged.update(partial)
    
    return dict(merged)

def pairs_in_word(word: tuple[bytes, ...]) -> dict[tuple[bytes, bytes], int]:
    """
    Count adjacent pair occurrences within a single word sequence.
    """
    counts: dict[tuple[bytes, bytes], int] = {}
    if len(word) < 2:
        return counts
    prev = word[0]
    for cur in word[1:]:
        p = (prev, cur)
        counts[p] = counts.get(p, 0) + 1
        prev = cur
    return counts

def apply_merge(word: tuple[bytes, ...], a: bytes, b: bytes, new_token: bytes) -> tuple[bytes, ...]:
    """
    Replace occurrences of (a,b) with new_token.
    """
    if len(word) < 2:
        return word
    merged: list[bytes] = []
    i = 0
    L = len(word)
    while i < L:
        if i < L - 1 and word[i] == a and word[i + 1] == b:
            merged.append(new_token)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    return tuple(merged)

def build_pair_stats(
    word_freq: dict[tuple[bytes, ...], int]
) -> tuple[dict[tuple[bytes, bytes], int], dict[tuple[bytes, bytes], set[tuple[bytes, ...]]]]:
    """
    Build:
      - pair_counts: global weighted counts for each adjacent pair
      - pair_to_words: inverted index (pair -> set of words containing that pair)
    """
    pair_counts: dict[tuple[bytes, bytes], int] = {}
    pair_to_words: dict[tuple[bytes, bytes], set[tuple[bytes, ...]]] = {}

    for word, freq in word_freq.items():
        if len(word) < 2:
            continue
        local = pairs_in_word(word)
        for p, occ in local.items():
            pair_counts[p] = pair_counts.get(p, 0) + occ * freq
            s = pair_to_words.get(p)
            if s is None:
                pair_to_words[p] = {word}
            else:
                s.add(word)
    
    return pair_counts, pair_to_words

def remove_word_contrib(
    word: tuple[bytes, ...],
    freq: int,
    pair_counts: dict[tuple[bytes, bytes], int],
    pair_to_words: dict[tuple[bytes, bytes], set[tuple[bytes, ...]]],
) -> None:
    """
    Remove a word's contribution from pair_counts and pair_to_words.
    """
    local = pairs_in_word(word)
    for p, occ in local.items():
        s = pair_to_words.get(p)
        if s is not None:
            s.discard(word)
            if not s:
                del pair_to_words[p]
        
        new_c = pair_counts.get(p, 0) - occ * freq
        if new_c <= 0:
            pair_counts.pop(p, None)
        else:
            pair_counts[p] = new_c

def add_word_contrib(
    word: tuple[bytes, ...],
    add_freq: int,
    pair_counts: dict[tuple[bytes, bytes], int],
    pair_to_words: dict[tuple[bytes, bytes], set[tuple[bytes, ...]]],
    *,
    word_is_new: bool,
) -> None:
    """
    Add a word's contribution to pair_counts and pair_to_words.
    """
    if len(word) < 2:
        return
    local = pairs_in_word(word)
    for p, occ in local.items():
        pair_counts[p] = pair_counts.get(p, 0) + occ * add_freq
        if word_is_new:
            s = pair_to_words.get(p)
            if s is None:
                pair_to_words[p] = {word}
            else:
                s.add(word)

def train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    *,
    num_processes: int | None = None
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    Train a byte-level BPE tokenizer.

    Returns:
        vocab: dict[int, bytes]
        merges: list[tuple[bytes, bytes]]    
    """
    if vocab_size <= 0:
        raise ValueError("vocab_size must be a positive integer")
    if vocab_size < 256 + len(special_tokens):
        raise ValueError("vocab_size too small: must be >= 256 + len(special_tokens)")

    # ---- vocab init: 256 single-byte tokens + special tokens ----
    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    next_id = 256
    for token in special_tokens:
        vocab[next_id] = token.encode("utf-8")
        next_id += 1
    
    # ---- pre-tokenization + counting (parallelization) ----
    if num_processes is None:
        num_processes = min(8, os.cpu_count() or 1)
    
    file_size = os.path.getsize(input_path)

    # For small files, multiprocessing overhead dominates; use serial
    if num_processes <= 1 or file_size < 1_000_000:  # ~1MB
        word_freq = build_word_freq_serial(input_path, special_tokens)
    else:
        word_freq = build_word_freq_parallel(input_path, special_tokens, num_processes)
    
    if not word_freq:
        return vocab, []

    # ---- BPE merges ----
    pair_counts, pair_to_words = build_pair_stats(word_freq)
    merges: list[tuple[bytes, bytes]] = []
    while next_id < vocab_size:
        if not pair_counts:
            break
    
        # choose most frequent; tie-break by lexicographically largest pair
        (a, b), best_count = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))
        if best_count <= 0:
            break
        
        new_token = a + b
        merges.append((a, b))
        vocab[next_id] = new_token
        next_id += 1

        affected = pair_to_words.get((a, b))
        if not affected:
            pair_counts.pop((a, b), None)
            continue

        # replace occurrences of (a,b) in every word
        add_back: dict[tuple[bytes, ...], int] = {}
        for word in list(affected):
            freq = word_freq.get(word)
            if freq is None:
                continue

            remove_word_contrib(word, freq, pair_counts, pair_to_words)
            del word_freq[word]

            new_word = apply_merge(word, a, b, new_token)
            add_back[new_word] = add_back.get(new_word, 0) + freq
        
        for new_word, add_freq in add_back.items():
            existed = new_word in word_freq
            word_freq[new_word] = word_freq.get(new_word, 0) + add_freq
            add_word_contrib(new_word, add_freq, pair_counts, pair_to_words, word_is_new=not existed)

    return vocab, merges

[Screenshot: pytest output after optimizing the merge step]

The speedup is large: the test time drops from 7.71 s to 0.84 s.

In the old (naive) version, every merge round does two "full" passes:

A. Rebuilding pair_counts from scratch

pair_counts = {}
for word, freq in word_freq.items():
    for pair in zip(word, word[1:]):
        pair_counts[pair] = pair_counts.get(pair, 0) + freq

Every round scans the adjacent pairs of all words.

B. Rebuilding new_word_freq from scratch

new_word_freq = {}
for word, freq in word_freq.items():
    new_word = apply_merge(word, a, b, new_token)
    new_word_freq[new_word] = new_word_freq.get(new_word, 0) + freq
word_freq = new_word_freq

每一轮都要把 所有 word 都走一遍 “替换+拷贝”

所以旧版瓶颈的本质是:每次 merge 都对整个语料做一次全遍历

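The new version replaces these full rebuilds with updates to only the affected part.

Before entering the while loop, it initializes once:

  • pair_counts: frequency-weighted global pair counts (built only once)
  • pair_to_words: an inverted index (pair -> the words that contain this pair)

Then, in each merge round:

1. Pick the most frequent pair (a, b) directly from pair_counts

2. Use pair_to_words[(a, b)] to fetch only the affected words (i.e., the words that contain this pair)

3. For these affected words:

  • subtract their "old pair contributions" from the global statistics (remove)
  • merge each word into a new word

4. Add the merged new words back into word_freq, and add only these new words' pair contributions back into the global statistics (add)

The key point is that each round only processes the words that contain the current best pair; all other, unrelated words are neither modified nor rescanned.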
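To make the difference concrete, here is a small self-contained sketch (not the post's train_bpe) that trains on a toy word-frequency table twice: once with a full recount and rebuild every round, as in the naive version, and once with a one-time pair-count/inverted-index initialization followed by remove/add updates for the affected words only. The helper names (merge_word, pair_counter, best_pair, train_naive, train_incremental) are hypothetical; the tie-break rule (highest count, then lexicographically largest pair) follows the code above.

```python
from collections import Counter

Word = tuple[bytes, ...]
Pair = tuple[bytes, bytes]


def merge_word(word: Word, a: bytes, b: bytes) -> Word:
    """Replace non-overlapping (a, b) occurrences, scanning left to right."""
    out: list[bytes] = []
    i = 0
    while i < len(word):
        if i + 1 < len(word) and word[i] == a and word[i + 1] == b:
            out.append(a + b)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)


def pair_counter(word: Word) -> Counter:
    """Adjacent pairs inside one word and how often each occurs."""
    return Counter(zip(word, word[1:]))


def best_pair(counts: Counter) -> Pair:
    """Highest count wins; ties go to the lexicographically largest pair."""
    return max(counts.items(), key=lambda kv: (kv[1], kv[0]))[0]


def train_naive(word_freq: dict[Word, int], n_merges: int) -> list[Pair]:
    word_freq = dict(word_freq)
    merges: list[Pair] = []
    for _ in range(n_merges):
        counts: Counter = Counter()
        for w, f in word_freq.items():  # full rescan of every word
            for p, occ in pair_counter(w).items():
                counts[p] += occ * f
        if not counts:
            break
        a, b = best_pair(counts)
        merges.append((a, b))
        rebuilt: Counter = Counter()
        for w, f in word_freq.items():  # full rebuild of word_freq
            rebuilt[merge_word(w, a, b)] += f
        word_freq = dict(rebuilt)
    return merges


def train_incremental(word_freq: dict[Word, int], n_merges: int) -> list[Pair]:
    word_freq = dict(word_freq)
    counts: Counter = Counter()
    index: dict[Pair, set[Word]] = {}  # pair -> words containing it
    for w, f in word_freq.items():  # one-time initialization
        for p, occ in pair_counter(w).items():
            counts[p] += occ * f
            index.setdefault(p, set()).add(w)
    merges: list[Pair] = []
    for _ in range(n_merges):
        if not counts:
            break
        a, b = best_pair(counts)
        merges.append((a, b))
        for w in list(index.get((a, b), ())):  # only the affected words
            f = word_freq.pop(w)
            for p, occ in pair_counter(w).items():  # remove the old contribution
                counts[p] -= occ * f
                if counts[p] <= 0:
                    del counts[p]
                index[p].discard(w)
            new_w = merge_word(w, a, b)
            word_freq[new_w] = word_freq.get(new_w, 0) + f
            for p, occ in pair_counter(new_w).items():  # add the new contribution
                counts[p] += occ * f
                index.setdefault(p, set()).add(new_w)
    return merges


words = [("low", 5), ("lower", 2), ("widest", 3), ("newest", 6)]
word_freq = {tuple(bytes([c]) for c in w.encode("utf-8")): f for w, f in words}
print(train_naive(word_freq, 2))  # [(b's', b't'), (b'e', b'st')]
print(train_incremental(word_freq, 10) == train_naive(word_freq, 10))  # True
```

Both versions learn the same merges; the incremental one simply avoids touching words that cannot be affected, at the cost of keeping the inverted index in memory.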

For the concrete implementation, we only look at what happens after new_token is built (everything else is essentially the same as the naive version):

affected = pair_to_words.get((a, b))
if not affected:
    pair_counts.pop((a, b), None)
    continue

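First, we look up the pair (a, b) in the inverted index to get every word in the corpus affected by this merge, stored in affected. If it is empty, we delete the pair from pair_counts and skip this round to choose the next pair; otherwise we continue with the logic below.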

add_back: dict[tuple[bytes, ...], int] = {}
for word in list(affected):
    freq = word_freq.get(word)
    if freq is None:
        continue

    remove_word_contrib(word, freq, pair_counts, pair_to_words)
    del word_freq[word]

    new_word = apply_merge(word, a, b, new_token)
    add_back[new_word] = add_back.get(new_word, 0) + freq

The first loop removes every affected old word from the corpus and produces its merged new word.

def remove_word_contrib(
    word: tuple[bytes, ...],
    freq: int,
    pair_counts: dict[tuple[bytes, bytes], int],
    pair_to_words: dict[tuple[bytes, bytes], set[tuple[bytes, ...]]],
) -> None:
    """
    Remove a word's contribution from pair_counts and pair_to_words.
    """
    local = pairs_in_word(word)
    for p, occ in local.items():
        s = pair_to_words.get(p)
        if s is not None:
            s.discard(word)
            if not s:
                del pair_to_words[p]
        
        new_c = pair_counts.get(p, 0) - occ * freq
        if new_c <= 0:
            pair_counts.pop(p, None)
        else:
            pair_counts[p] = new_c

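remove_word_contrib(...) undoes an old word's contribution to the global structures. Internally it:

  • (A) finds all adjacent pairs in the word and how many times each occurs (local)
  • (B) for each pair, removes the word from the inverted index
  • (C) for each pair, subtracts the word's contribution from the global count, where occ is the number of times the pair occurs inside the word and freq is the number of times the word occurs in the corpus

apply_merge(...) then rewrites a word, scanning left to right and replacing each non-overlapping occurrence of (a, b) with new_token:
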
def apply_merge(word: tuple[bytes, ...], a: bytes, b: bytes, new_token: bytes) -> tuple[bytes, ...]:
    """
    Replace occurrences of (a,b) with new_token.
    """
    if len(word) < 2:
        return word
    merged: list[bytes] = []
    i = 0
    L = len(word)
    while i < L:
        if i < L - 1 and word[i] == a and word[i + 1] == b:
            merged.append(new_token)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    return tuple(merged)

del word_freq[word]

new_word = apply_merge(word, a, b, new_token)
add_back[new_word] = add_back.get(new_word, 0) + freq

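Next, del actually removes the old word from the corpus, apply_merge produces the merged new word, and add_back accumulates, for each new_word, the total frequency of all old words that turn into it.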

for new_word, add_freq in add_back.items():
    existed = new_word in word_freq
    word_freq[new_word] = word_freq.get(new_word, 0) + add_freq
    add_word_contrib(new_word, add_freq, pair_counts, pair_to_words, word_is_new=not existed)

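The second loop adds the merged new words back into the corpus and adds their contributions to the global structures; existed is mainly used for the inverted-index update (a word that is already present is already indexed under its pairs).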

def add_word_contrib(
    word: tuple[bytes, ...],
    add_freq: int,
    pair_counts: dict[tuple[bytes, bytes], int],
    pair_to_words: dict[tuple[bytes, bytes], set[tuple[bytes, ...]]],
    *,
    word_is_new: bool,
) -> None:
    """
    Add a word's contribution to pair_counts and pair_to_words.
    """
    if len(word) < 2:
        return
    local = pairs_in_word(word)
    for p, occ in local.items():
        pair_counts[p] = pair_counts.get(p, 0) + occ * add_freq
        if word_is_new:
            s = pair_to_words.get(p)
            if s is None:
                pair_to_words[p] = {word}
            else:
                s.add(word)

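add_word_contrib(...) adds a new word's contribution back to the global statistics, doing exactly the opposite of remove_word_contrib:

  • (A) computes all adjacent pairs in the new word and how many times each occurs (local)
  • (B) for each pair, adds the contribution to the global count
  • (C) if word_is_new, adds the word to the inverted index for each pair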
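Both helpers rely on pairs_in_word(word), whose body the post does not show. A minimal version, assuming it simply counts every adjacent pair (overlapping occurrences included, which matches how pair_counts is built in the naive version), could look like this; contribution is a hypothetical extra helper showing the occ * freq amount each pair adds to pair_counts:

```python
from collections import Counter


def pairs_in_word(word: tuple[bytes, ...]) -> dict[tuple[bytes, bytes], int]:
    """Count every adjacent (left, right) pair inside one word (hypothetical helper)."""
    return dict(Counter(zip(word, word[1:])))


def contribution(word: tuple[bytes, ...], freq: int) -> dict[tuple[bytes, bytes], int]:
    """What this word adds to (or removes from) pair_counts: occ * freq per pair."""
    return {p: occ * freq for p, occ in pairs_in_word(word).items()}


print(pairs_in_word((b"l", b"o", b"w")))  # {(b'l', b'o'): 1, (b'o', b'w'): 1}
print(contribution((b"a", b"a", b"a"), 5))  # {(b'a', b'a'): 10}
```

Note the (b"a", b"a", b"a") case: the pair is counted twice even though apply_merge can only merge it once (giving (b"aa", b"a")). That is fine as long as remove_word_contrib and add_word_contrib use the same helper, because every count that is added is later subtracted with exactly the same value.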

OK, that covers the BPE training implementation and its optimization. Let's move on to the next problem.

4. Problem (train_bpe_tinystories): BPE Training on TinyStories (2 points)

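(a) Train a byte-level BPE tokenizer on the TinyStories dataset with a maximum vocabulary size of 10,000. Make sure to add the TinyStories special token <|endoftext|> to the vocabulary. Serialize the resulting vocabulary and merges to disk for further inspection.

Answer the following:

  • How many hours did training take, and how much memory did it use?
  • What is the longest token in the vocabulary? Does it make sense?

Resource requirements: ≤ 30 minutes (no GPUs), ≤ 30GB RAM

Hint: If you use multiprocessing during pre-tokenization, you should be able to finish BPE training in under 2 minutes, using these two facts:

  • (a) the <|endoftext|> token delimits documents in the data files;
  • (b) the <|endoftext|> token is handled as a special case before the BPE merges are applied.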

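Deliverable: Training a byte-level BPE tokenizer with a 10,000-token vocabulary on TinyStories took about 1.9 minutes and about 0.08GB of memory (the main process's RSS as measured by the training script). The longest token in the learned vocabulary is " accomplishment", which makes sense: byte-level BPE tends to merge frequent multi-byte substrings (usually common words with a leading space) into single tokens.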

(b) Profile your code. Which part of the tokenizer training process takes the most time?

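Deliverable: Profiling shows that most of the time goes to selecting the most frequent pair during merging (max(pair_counts.items(), key=...)) and to the result-collection step of parallel pre-tokenization (a large amount of time in _thread.lock.acquire, which in the main process typically means blocking while waiting for worker results). So the bottlenecks are mainly the linear scan for the best pair and the waiting during multiprocess aggregation.

The TinyStories training script train_bpe_tinystories.py is implemented as follows: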

import os
import time
import pickle
import psutil

from cs336_basics.train_bpe import train_bpe

def main():
    input_path = "data/TinyStoriesV2-GPT4-train.txt"
    output_dir = "workspace"
    os.makedirs(output_dir, exist_ok=True)

    vocab_size = 10_000
    special_tokens = ["<|endoftext|>"]

    proc = psutil.Process(os.getpid())
    peak_rss = 0

    t0 = time.perf_counter()
    vocab, merges = train_bpe(
        input_path=input_path,
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        num_processes=8
    )
    t1 = time.perf_counter()

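    # NOTE: current RSS of the main process after training, not a true peak;
    # the Pool worker processes are not included.
    peak_rss = proc.memory_info().rss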

    # Save to disk
    vocab_path = os.path.join(output_dir, "tinystories_bpe_vocab.pkl")
    merges_path = os.path.join(output_dir, "tinystories_bpe_merges.pkl")
    with open(vocab_path, "wb") as f:
        pickle.dump(vocab, f)
    with open(merges_path, "wb") as f:
        pickle.dump(merges, f)

    # Longest token in vocab (by byte length)
    longest_id, longest_bytes = max(vocab.items(), key=lambda kv: len(kv[1]))
    longest_str = longest_bytes.decode("utf-8", errors="replace")

    elapsed_s = t1 - t0
    elapsed_min = elapsed_s / 60.0
    elapsed_hr  = elapsed_s / 3600.0
    mem_gb = peak_rss / (1024**3) 

    print(f"Saved vocab -> {vocab_path}")
    print(f"Saved merges -> {merges_path}")
    print(f"Elapsed: {elapsed_s:.2f}s ({elapsed_min:.2f} min, {elapsed_hr:.4f} hr)")
    print(f"RSS of main process (approx): {mem_gb:.2f} GB")
    print(f"Longest token id={longest_id}, bytes_len={len(longest_bytes)}")
    print(f"Longest token (decoded): {repr(longest_str)}")

if __name__ == "__main__":
    main()
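Because proc.memory_info().rss above is the main process's current RSS once train_bpe returns, not a peak, and it excludes the Pool workers, the reported memory likely understates the real peak. One option on Linux/macOS is the standard-library resource module; the sketch below is an assumption-labeled helper (peak_rss_gb is a hypothetical name, not part of the assignment code):

```python
import resource
import sys


def peak_rss_gb() -> tuple[float, float]:
    """Return (peak RSS of this process, peak RSS of the largest reaped child) in GB.

    Unix only. ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    RUSAGE_CHILDREN only covers children that have already exited and been waited
    for, so call this after the multiprocessing Pool has been closed.
    """
    unit = 1 if sys.platform == "darwin" else 1024
    self_peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * unit
    child_peak = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss * unit
    return self_peak / 1024**3, child_peak / 1024**3


main_gb, worker_gb = peak_rss_gb()
print(f"peak main: {main_gb:.2f} GB, peak largest worker: {worker_gb:.2f} GB")
```

Calling it right after train_bpe(...) gives the main process's peak plus the peak of the single largest worker; it is not the sum across workers, so total usage can be higher when several workers run at once.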

Run it with:

# 1. download data
mkdir -p data
cd data

wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt

# 2. run
uv run cs336_basics/train_bpe_tinystories.py

The output is shown below:

[Figure: console output of train_bpe_tinystories.py]

To profile the code:

# 1. profile
uv run python -m cProfile -o workspace/tinystories.prof cs336_basics/train_bpe_tinystories.py
# 2. check
uv run python - <<'EOF'
import pstats
p = pstats.Stats("workspace/tinystories.prof")
p.sort_stats("tottime").print_stats(30)
EOF

5. Problem (train_bpe_expts_owt): BPE Training on OpenWebText (2 points)

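Next, we train a byte-level BPE tokenizer on the OpenWebText dataset. As before, we recommend looking through the dataset first to get a better sense of what the data looks like.

(a) Train a byte-level BPE tokenizer on the OpenWebText dataset with a maximum vocabulary size of 32,000. Serialize the resulting vocabulary and merges to disk for further inspection.

Answer: What is the longest token in the vocabulary? Does it make sense?

Resource requirements: ≤ 12 hours (no GPUs), ≤ 100GB RAM

Deliverable: The longest token decodes to 'ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ'. Such mojibake is reasonable for byte-level BPE: OpenWebText contains a lot of web noise, incorrectly encoded text, and long repeated patterns, and byte-level merging turns frequently occurring byte sequences into very long tokens that need not correspond to readable UTF-8 text, so a decoded form with replacement characters or 'ÃÂ...'-style garbage is expected.

(b) Compare and contrast the tokenizers you trained on TinyStories and OpenWebText.

Deliverable: Both learn common English fragments such as a space plus a common word, or common punctuation combinations. TinyStories, however, is cleaner and more narrowly distributed, so its longest token tends to be a readable word piece such as ' accomplishment'. OpenWebText is web text with more noise, including mojibake and long runs of repeated symbols, so its tokenizer learns more non-natural-language tokens such as mis-encoded strings; these tokens are less readable but help compress such frequent noise patterns.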

The OpenWebText training script train_bpe_owt.py is implemented as follows:

import os
import time
import pickle
import psutil
from cs336_basics.train_bpe import train_bpe

def main():
    input_path = "data/owt_train.txt"
    output_dir = "workspace"
    os.makedirs(output_dir, exist_ok=True)

    vocab_size = 32_000
    special_tokens = ["<|endoftext|>"]

    proc = psutil.Process(os.getpid())

    t0 = time.perf_counter()
    vocab, merges = train_bpe(
        input_path=input_path,
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        num_processes=8
    )
    t1 = time.perf_counter()

    rss_gb = proc.memory_info().rss / (1024 ** 3)

    vocab_path = os.path.join(output_dir, "owt_bpe_vocab_32000.pkl")
    merges_path = os.path.join(output_dir, "owt_bpe_merges_32000.pkl")
    with open(vocab_path, "wb") as f:
        pickle.dump(vocab, f)
    with open(merges_path, "wb") as f:
        pickle.dump(merges, f)

    longest_id, longest_bytes = max(vocab.items(), key=lambda kv: len(kv[1]))
    longest_str = longest_bytes.decode("utf-8", errors="replace")

    print(f"Saved vocab -> {vocab_path}")
    print(f"Saved merges -> {merges_path}")
    print(f"Elapsed: {(t1 - t0):.2f}s")
    print(f"RSS (approx): {rss_gb:.2f} GB")
    print(f"Longest token id={longest_id}, bytes_len={len(longest_bytes)}")
    print(f"Longest token (decoded): {repr(longest_str)}")

if __name__ == "__main__":
    main()

Run it with:

# 1. download data
mkdir -p data
cd data

wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
gunzip owt_train.txt.gz
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
gunzip owt_valid.txt.gz

# 2. run
uv run python scripts/train_bpe_owt.py

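When testing with the code above, I ran into an OOM (out of memory) problem: debugging showed it stalled in build_word_freq_parallel, after which the program crashed and the machine froze. The problem did not occur on TinyStories or on the OpenWebText validation set. My test machine has 64GB of RAM, and owt_train.txt is 12GB.

The cause is that each chunk is too large when word_freq is built in parallel on OWT. Our current boundaries are:

boundaries = find_chunk_boundaries(f, num_processes, split_special_token)

This splits the file evenly by num_processes, i.e., into 8 chunks of about 12GB / 8 ≈ 1.5GB each, and each worker process does the following:

  • f.read(end - start) reads ~1.5GB of bytes
  • decode("utf-8") turns them into a huge Python str, which takes additional memory
  • regex finditer walks the text and builds a large number of keys
  • a very large dict is returned to the main process for merging

With 8 processes running concurrently, peak memory is far higher than expected: it is not just 8 × 1.5GB of text but also a large Python object overhead.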

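The fix is to lower the concurrent peak: instead of using as many chunks as processes, cut the file into many small chunks that the workers process in turn, and set maxtasksperchild so that each worker process is replaced after a fixed number of tasks, which limits memory fragmentation/growth.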
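With num_chunks = num_processes * 32 = 256, each task reads roughly 12GB / 256 ≈ 48MB instead of 1.5GB, and at most num_processes tasks are being processed at a time. The post does not show find_chunk_boundaries, so the sketch below is a hypothetical stand-in (find_aligned_boundaries) built on the same idea: place evenly spaced guesses, then snap each one forward to the next special token so that no document is cut in half.

```python
import io
import os
from typing import BinaryIO


def find_aligned_boundaries(
    f: BinaryIO, num_chunks: int, split_token: bytes, probe_size: int = 4096
) -> list[int]:
    """Hypothetical helper: ~num_chunks byte ranges whose inner edges sit on split_token."""
    if probe_size <= len(split_token):
        raise ValueError("probe_size must be larger than the split token")
    f.seek(0, os.SEEK_END)
    size = f.tell()
    step = max(1, size // num_chunks)
    bounds = [min(i * step, size) for i in range(num_chunks)] + [size]  # evenly spaced guesses
    for i in range(1, len(bounds) - 1):
        pos = bounds[i]
        while True:
            f.seek(pos)
            block = f.read(probe_size)
            j = block.find(split_token)
            if j != -1:
                pos += j  # snap forward to the start of the next special token
                break
            if len(block) < probe_size:
                pos = size  # no token before EOF: fold the rest into the last chunk
                break
            # advance with an overlap so a token straddling two probes is still found
            pos += probe_size - len(split_token) + 1
        bounds[i] = pos
    return sorted(set(bounds))  # snapping can collapse neighbours; drop duplicates


demo = io.BytesIO(b"a<|endoftext|>bb<|endoftext|>ccc")
print(find_aligned_boundaries(demo, 3, b"<|endoftext|>", probe_size=16))  # [0, 16, 32]
```

Each (start, end) pair of consecutive boundaries can then be handed to a worker, which only ever holds one small chunk in memory.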

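The modified code:

import os
from collections import Counter
from multiprocessing import Pool

# find_chunk_boundaries, init_worker, process_chunk and build_word_freq_serial are defined elsewhere in train_bpe.py
def build_word_freq_parallel(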
    input_path: str | os.PathLike,
    special_tokens: list[str],
    num_processes: int,
    *,
    num_chunks: int | None = None
) -> dict[tuple[bytes, ...], int]:
    """
    Build word frequency statistics using multiprocessing.
    Chunk boundaries are aligned to special-token boundaries.    
    """
    if num_processes <= 1 or not special_tokens:
        return build_word_freq_serial(input_path, special_tokens)
    
    if num_chunks is None:
        num_chunks = max(num_processes * 32, num_processes)

    split_special_token = special_tokens[0].encode("utf-8")  # e.g. b"<|endoftext|>"
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, num_chunks, split_special_token)
    
    tasks = [(str(input_path), s, e, special_tokens) for s, e in zip(boundaries[:-1], boundaries[1:])]

    merged = Counter()
    with Pool(processes=num_processes, initializer=init_worker, maxtasksperchild=8) as pool:
        for partial in pool.imap_unordered(process_chunk, tasks, chunksize=1):
            merged.update(partial)
    
    return dict(merged)

word_freq = build_word_freq_parallel(input_path, special_tokens, num_processes, num_chunks=num_processes * 32)

After the change, running the training script again produces the following output:

[Figure: console output of train_bpe_owt.py]

6. Problem (tokenizer): Implementing the tokenizer (15 points)

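Deliverable: Implement a Tokenizer class that, given a vocabulary and a list of merges, encodes text into integer token IDs and decodes token IDs back into text. Your tokenizer should also support user-provided special tokens, appending them to the vocabulary if they are not already present.

We recommend the following interface:

def __init__(self, vocab, merges, special_tokens=None)

Construct a tokenizer from a given vocabulary, list of merges, and (optionally) a list of special tokens.

Parameters:

  • vocab: dict[int, bytes]: the vocabulary, a mapping from token ID (int) to token bytes
  • merges: list[tuple[bytes, bytes]]: the list of BPE merges
  • special_tokens: list[str] | None = None: an optional list of special tokens

def from_files(cls, vocab_filepath, merges_filepath, special_tokens=None)

A class method that constructs and returns a Tokenizer from a serialized vocabulary and list of merges (in the same format that your BPE training code outputs) and, optionally, a list of special tokens.

Additional parameters:

  • vocab_filepath: str: path to the vocabulary file
  • merges_filepath: str: path to the merges file
  • special_tokens: list[str] | None = None: an optional list of special tokens

def encode(self, text: str) -> list[int]

Encode an input text into a sequence of token IDs.

def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]

Given an iterable of strings (e.g., a Python file handle), return a lazy generator that yields token IDs one at a time. This is required for memory-efficient tokenization of large files that cannot be loaded into memory at once.

def decode(self, ids: list[int]) -> str

Decode a sequence of token IDs into text.

To test your Tokenizer against the provided tests, first implement the test adapter [adapters.get_tokenizer], then run:

uv run pytest tests/test_tokenizer.py

Your implementation should pass all the tests.

The tokenizer.py implementation: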

import os
import pickle
import regex as re
from typing import Any, Iterable, Iterator

GPT2_PRETOKENIZE_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

class Tokenizer:
    """
    Byte-level BPE tokenizer compatible with GPT-2-style pre-tokenization.
    Vocab maps token_id -> bytes. Merges are list[(bytes, bytes)] in creation order.
    """

    def __init__(
        self,
        vocab: dict[int, bytes],
        merges: list[tuple[bytes, bytes]],
        special_tokens: list[str] | None = None,
    ):
        self.vocab: dict[int, bytes] = dict(vocab)  # copy, so appending special tokens won't mutate the caller's dict
        self.merges: list[tuple[bytes, bytes]] = merges
        
        # reverse vocab: bytes -> id
        self.byte_to_id: dict[bytes, int] = {b: i for i, b in self.vocab.items()}

        # merge ranks: lower rank = higher priority
        self.merge_ranks: dict[tuple[bytes, bytes], int] = {
            pair: idx for idx, pair in enumerate(self.merges)
        }

        # regex for GPT-2 pre-tokenization
        self.rx = re.compile(GPT2_PRETOKENIZE_PATTERN)

        # special tokens
        self.special_tokens: list[str] = special_tokens or []
        self.special_bytes: list[bytes] = []
        self.special_id: dict[str, int] = {}

        if self.special_tokens:
            # append missing special tokens to vocab
            for s in self.special_tokens:
                b = s.encode("utf-8")
                if b not in self.byte_to_id:
                    new_id = len(self.vocab)
                    self.vocab[new_id] = b
                    self.byte_to_id[b] = new_id
                self.special_id[s] = self.byte_to_id[b]
                self.special_bytes.append(b)
            
            # build a "longest-first" special-token matcher
            # we keep them as strings for boundary-safe matching in encode()
            sorted_special = sorted(self.special_tokens, key=len, reverse=True)
            self._special_re = re.compile("|".join(re.escape(s) for s in sorted_special))
            self._max_special_len = max(len(s) for s in self.special_tokens)
        else:
            self._special_re = None
            self._max_special_len = 0

        # cache for BPE on pre-token byte sequence
        self._bpe_cache: dict[bytes, list[bytes]] = {}

    @classmethod
    def from_files(
        cls,
        vocab_filepath: str | os.PathLike,
        merges_filepath: str | os.PathLike,
        special_tokens: list[str] | None = None
    ) -> "Tokenizer":
        # match the training outputs used earlier (pickle dump)
        with open(vocab_filepath, "rb") as f:
            vocab = pickle.load(f)
        with open(merges_filepath, "rb") as f:
            merges = pickle.load(f)
        return cls(vocab=vocab, merges=merges, special_tokens=special_tokens)
    
    # ---------------------------
    # Public API
    # ---------------------------
    def encode(self, text: str) -> list[int]:
        if not text:
            return []    
        
        ids: list[int] = []
        if not self._special_re:
            ids.extend(self._encode_plain(text))
            return ids
        
        # split text by special tokens while preserving them as standalone parts
        last = 0
        for m in self._special_re.finditer(text):
            if m.start() > last:
                ids.extend(self._encode_plain(text[last : m.start()]))
            s = m.group(0)
            ids.append(self.special_id[s])
            last = m.end()
        if last < len(text):
            ids.extend(self._encode_plain(text[last:]))
        return ids
    
    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        """
        Memory-efficient streaming encode that matches Tokenizer.encode(full_text)
        """
        buf = ""

        for chunk in iterable:
            if not chunk:
                continue
            buf += chunk

            while True:
                matches = list(self.rx.finditer(buf))
                if len(matches) <= 1:
                    break

                # keep the last match unprocessed; emit everything before it.
                cut = matches[-1].start()
                if cut <= 0:
                    break

                process_part = buf[:cut]
                buf = buf[cut:]

                for _id in self.encode(process_part):
                    yield _id
        
        # flush the remainder
        if buf:
            for _id in self.encode(buf):
                yield _id

    def decode(self, ids: list[int]) -> str:
        if not ids:
            return ""            
        b = b"".join(self.vocab[i] for i in ids)
        return b.decode("utf-8", errors="replace")

    # ---------------------------
    # Internal helpers
    # ---------------------------
    def _encode_plain(self, text: str) -> list[int]:
        """
        Encode a piece of text with no special tokens inside it.
        """
        out: list[int] = []
        for m in self.rx.finditer(text):
            piece = m.group(0)
            if not piece:
                continue
            piece_bytes = piece.encode("utf-8")
            for tok_bytes in self._bpe(piece_bytes):
                out.append(self.byte_to_id[tok_bytes])
        return out
    
    def _bpe(self, token_bytes: bytes) -> list[bytes]:
        """
        Apply BPE merges (by rank) on a single pre-token byte sequence.
        Returns a list of vocab byte tokens.        
        """
        cached = self._bpe_cache.get(token_bytes)
        if cached is not None:
            return cached
        
        # start from single bytes
        word: list[bytes] = [bytes([b]) for b in token_bytes]
        if len(word) <= 1:
            self._bpe_cache[token_bytes] = word
            return word
        
        while True:
            best_pair = None
            best_rank = None

            # find best ranked pair among adjacent pairs
            prev = word[0]
            for cur in word[1:]:
                p = (prev, cur)
                r = self.merge_ranks.get(p)
                if r is not None and (best_rank is None or r < best_rank):
                    best_rank = r
                    best_pair = p
                prev = cur
            
            if best_pair is None:
                break
        
            a, b = best_pair
            new_token = a + b

            # merge all occurrences of (a, b)
            merged: list[bytes] = []
            i = 0
            L = len(word)
            while i < L:
                if i < L - 1 and word[i] == a and word[i + 1] == b:
                    merged.append(new_token)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            word = merged
            if len(word) <= 1:
                break
        
        self._bpe_cache[token_bytes] = word
        return word

The test adapter [adapters.get_tokenizer] is implemented as:

def get_tokenizer(
    vocab: dict[int, bytes],
    merges: list[tuple[bytes, bytes]],
    special_tokens: list[str] | None = None,
) -> Any:
    """Given a vocabulary, a list of merges, and a list of special tokens,
    return a BPE tokenizer that uses the provided vocab, merges, and special tokens.

    Args:
        vocab (dict[int, bytes]): The tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
            to bytes (token bytes)
        merges (list[tuple[bytes, bytes]]): BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
            representing that <token1> was merged with <token2>.
            Merges are ordered by order of creation.
        special_tokens (list[str] | None): A list of string special tokens for the tokenizer. These strings will never
            be split into multiple tokens, and will always be kept as a single token.

    Returns:
        A BPE tokenizer that uses the provided vocab, merges, and special tokens.
    """
    from cs336_basics.tokenizer import Tokenizer
    return Tokenizer(vocab=vocab, merges=merges, special_tokens=special_tokens)

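Running uv run pytest tests/test_tokenizer.py gives the following output:

[Figure: pytest output for tests/test_tokenizer.py]

All tests pass. The single XFAIL can be ignored: that test is expected to fail by design, since Tokenizer.encode() cannot tokenize a large text within 1MB of memory.

The overall flow of Tokenizer.encode is:

1. First handle special tokens (if any) at the string level, so they are never split and never take part in ordinary tokenization

2. Apply GPT-2 regex pre-tokenization to the ordinary text (self.rx.finditer)

3. For each pre-token piece:

  • piece.encode("utf-8") gives its bytes
  • the bytes are split into single bytes [b0, b1, ...]
  • BPE merges are applied to this sequence in merge-rank order

4. Each resulting BPE token is looked up in byte_to_id to get its token id

decode does the reverse: it concatenates the bytes of the ids and decodes them back as UTF-8 (with errors="replace").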
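A toy end-to-end sketch of these four steps, assuming a made-up vocabulary (256 single bytes, four merges, one special token) and an ASCII-only approximation of the GPT-2 pre-tokenization regex, since the standard re module lacks \p{L}/\p{N} (the real code uses the third-party regex package). Names such as MERGES, bpe and encode here are hypothetical and not the Tokenizer's API:

```python
import re

# ASCII-only stand-in for GPT2_PRETOKENIZE_PATTERN (stdlib re has no \p{L}/\p{N})
PAT = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+""")
SPECIALS = ["<|endoftext|>"]
SPECIAL_SPLIT = re.compile("(" + "|".join(re.escape(s) for s in SPECIALS) + ")")

# Toy vocab: ids 0-255 are single bytes, then one id per merge, then the special token
MERGES = [(b"h", b"e"), (b"l", b"l"), (b"he", b"ll"), (b" ", b"w")]
VOCAB = {i: bytes([i]) for i in range(256)}
for left, right in MERGES:
    VOCAB[len(VOCAB)] = left + right
VOCAB[len(VOCAB)] = SPECIALS[0].encode("utf-8")
BYTE_TO_ID = {tok: i for i, tok in VOCAB.items()}
RANKS = {pair: r for r, pair in enumerate(MERGES)}


def bpe(piece: bytes) -> list[bytes]:
    word = [bytes([x]) for x in piece]  # step 3: start from single bytes
    while len(word) > 1:
        ranked = [RANKS[p] for p in zip(word, word[1:]) if p in RANKS]
        if not ranked:
            break
        a, b = MERGES[min(ranked)]  # lowest rank = earliest learned merge wins
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and word[i] == a and word[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(word[i])
                i += 1
        word = out
    return word


def encode(text: str) -> list[int]:
    ids: list[int] = []
    for part in SPECIAL_SPLIT.split(text):  # step 1: the capture group keeps the specials
        if part in SPECIALS:
            ids.append(BYTE_TO_ID[part.encode("utf-8")])
            continue
        for m in PAT.finditer(part):  # step 2: pre-tokenize the ordinary text
            for tok in bpe(m.group(0).encode("utf-8")):
                ids.append(BYTE_TO_ID[tok])  # step 4: bytes -> id
    return ids


def decode(ids: list[int]) -> str:
    return b"".join(VOCAB[i] for i in ids).decode("utf-8", errors="replace")


ids = encode("hello world<|endoftext|>hi")
print(ids)  # [258, 111, 259, 111, 114, 108, 100, 260, 104, 105]
print(decode(ids))  # hello world<|endoftext|>hi
```

"hello" ends up as "hell" + "o" because the merges are replayed in rank order (he, then ll, then he+ll), while " world" only gets the " w" merge and "hi" gets none.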

Let's briefly walk through the code:

def __init__(
    self,
    vocab: dict[int, bytes],
    merges: list[tuple[bytes, bytes]],
    special_tokens: list[str] | None = None,
):
    self.vocab: dict[int, bytes] = dict(vocab)  # copy, so appending special tokens won't mutate the caller's dict
    self.merges: list[tuple[bytes, bytes]] = merges
    
    # reverse vocab: bytes -> id
    self.byte_to_id: dict[bytes, int] = {b: i for i, b in self.vocab.items()}

    # merge ranks: lower rank = higher priority
    self.merge_ranks: dict[tuple[bytes, bytes], int] = {
        pair: idx for idx, pair in enumerate(self.merges)
    }

    # regex for GPT-2 pre-tokenization
    self.rx = re.compile(GPT2_PRETOKENIZE_PATTERN)

    # special tokens
    self.special_tokens: list[str] = special_tokens or []
    self.special_bytes: list[bytes] = []
    self.special_id: dict[str, int] = {}

    if self.special_tokens:
        # append missing special tokens to vocab
        for s in self.special_tokens:
            b = s.encode("utf-8")
            if b not in self.byte_to_id:
                new_id = len(self.vocab)
                self.vocab[new_id] = b
                self.byte_to_id[b] = new_id
            self.special_id[s] = self.byte_to_id[b]
            self.special_bytes.append(b)
        
        # build a "longest-first" special-token matcher
        # we keep them as strings for boundary-safe matching in encode()
        sorted_special = sorted(self.special_tokens, key=len, reverse=True)
        self._special_re = re.compile("|".join(re.escape(s) for s in sorted_special))
        self._max_special_len = max(len(s) for s in self.special_tokens)
    else:
        self._special_re = None
        self._max_special_len = 0

    # cache for BPE on pre-token byte sequence
    self._bpe_cache: dict[bytes, list[bytes]] = {}

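The __init__ method builds the index structures the tokenizer needs.

It first stores the raw vocab and merges inputs (vocab maps token ids to token bytes; merges is the ordered list of BPE merge rules), then builds the reverse bytes -> id table byte_to_id, builds merge_ranks, and compiles the GPT-2 pre-tokenization regex self.rx.

To guarantee that special tokens are never split, we need to:

  • (a) append each special token to vocab (if it is not already there)
  • (b) build a regex matcher for the special tokens (longest first)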
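Why longest first? Python's regex alternation tries the alternatives left to right and takes the first one that matches at a given position, so if one special token is a prefix of another, the shorter one would win. re.escape is needed as well, because | is a regex metacharacter. A small sketch with two overlapping special tokens (special_matcher is a hypothetical helper name):

```python
import re


def special_matcher(specials: list[str], longest_first: bool = True) -> re.Pattern:
    """Compile one alternation over the special tokens (hypothetical helper)."""
    ordered = sorted(specials, key=len, reverse=True) if longest_first else list(specials)
    # re.escape is required: "<|endoftext|>" contains the regex metacharacter "|"
    return re.compile("|".join(re.escape(s) for s in ordered))


specials = ["<|endoftext|>", "<|endoftext|><|endoftext|>"]
text = "a<|endoftext|><|endoftext|>b"
print(special_matcher(specials).findall(text))  # ['<|endoftext|><|endoftext|>']
print(special_matcher(specials, longest_first=False).findall(text))  # ['<|endoftext|>', '<|endoftext|>']
```

With longest-first ordering the doubled token is recognized as one special token; with the given order it would be split into two.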

Finally, _bpe_cache stores the BPE result of each pre-token so repeated pre-tokens can reuse it, which saves time.

@classmethod
def from_files(
    cls,
    vocab_filepath: str | os.PathLike,
    merges_filepath: str | os.PathLike,
    special_tokens: list[str] | None = None
) -> "Tokenizer":
    # match the training outputs used earlier (pickle dump)
    with open(vocab_filepath, "rb") as f:
        vocab = pickle.load(f)
    with open(merges_filepath, "rb") as f:
        merges = pickle.load(f)
    return cls(vocab=vocab, merges=merges, special_tokens=special_tokens)

The from_files class method loads vocab/merges from disk and constructs a Tokenizer.

def encode(self, text: str) -> list[int]:
    if not text:
        return []    
    
    ids: list[int] = []
    if not self._special_re:
        ids.extend(self._encode_plain(text))
        return ids
    
    # split text by special tokens while preserving them as standalone parts
    last = 0
    for m in self._special_re.finditer(text):
        if m.start() > last:
            ids.extend(self._encode_plain(text[last : m.start()]))
        s = m.group(0)
        ids.append(self.special_id[s])
        last = m.end()
    if last < len(text):
        ids.extend(self._encode_plain(text[last:]))
    return ids

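The encode method converts an input text string into its token ids.

Without special tokens it simply calls self._encode_plain; with special tokens, it splits the text on them while keeping them, as follows:

  • (a) ordinary segments between special tokens go through _encode_plain
  • (b) each special token directly appends its id
  • (c) any remaining text at the end also goes through _encode_plain

This guarantees that each special token always comes out of encode as a single standalone token, and that BPE merges on ordinary text never cross a special-token boundary.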

def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
    """
    Memory-efficient streaming encode that matches Tokenizer.encode(full_text)
    """
    buf = ""

    for chunk in iterable:
        if not chunk:
            continue
        buf += chunk

        while True:
            matches = list(self.rx.finditer(buf))
            if len(matches) <= 1:
                break

            # keep the last match unprocessed; emit everything before it.
            cut = matches[-1].start()
            if cut <= 0:
                break

            process_part = buf[:cut]
            buf = buf[cut:]

            for _id in self.encode(process_part):
                yield _id
    
    # flush the remainder
    if buf:
        for _id in self.encode(buf):
            yield _id

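encode_iterable is the streaming (memory-friendly) encoder. It is meant to encode input that arrives in chunks, such as lines read from a file, while producing the same result as encode on the full text.

Our strategy is to accumulate text in buf and always leave the last regex match unprocessed (it may continue in the next chunk); once the input is exhausted, everything left in buf is encoded.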
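A quick way to check the "hold back the last match" idea is to compare streamed and one-shot pre-tokenization. The sketch below covers pre-tokenization only (no BPE) and uses an ASCII-only stand-in for GPT2_PRETOKENIZE_PATTERN, because the stdlib re module has no \p{L}/\p{N}; both are simplifications of this example, and pretokenize/pretokenize_stream are hypothetical names. With line-sized chunks, as produced by iterating over a file, the results agree; a chunk boundary in the middle of a token can still diverge with this strategy.

```python
import re
from typing import Iterable, Iterator

# ASCII-only stand-in for GPT2_PRETOKENIZE_PATTERN (the stdlib re module has no \p{L}/\p{N})
PAT = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+""")


def pretokenize(text: str) -> list[str]:
    return [m.group(0) for m in PAT.finditer(text)]


def pretokenize_stream(chunks: Iterable[str]) -> Iterator[str]:
    """Same hold-back-the-last-match strategy as encode_iterable, pre-tokens only."""
    buf = ""
    for chunk in chunks:
        buf += chunk
        matches = list(PAT.finditer(buf))
        if len(matches) <= 1:
            continue  # nothing is safe to emit yet
        cut = matches[-1].start()  # the last match may still grow with the next chunk
        yield from pretokenize(buf[:cut])
        buf = buf[cut:]
    if buf:  # flush what is left once the input is exhausted
        yield from pretokenize(buf)


text = "Hello world, it's 2025!\n\n  Indented line\tand tabs\nlast line without newline"
lines = text.splitlines(keepends=True)  # line chunks, like iterating over a file handle
print(list(pretokenize_stream(lines)) == pretokenize(text))  # True

# A chunk boundary in the middle of a token can still change the result:
print(list(pretokenize_stream(["x'l", "l\n"])))  # ['x', "'", 'll', '\n']
print(pretokenize("x'll\n"))  # ['x', "'ll", '\n']
```

Line chunks are safe here because each chunk ends in whitespace, so the emitted prefix always ends right before a whitespace match; if encode_iterable is ever fed arbitrary fixed-size chunks, holding back more context would be needed.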

def decode(self, ids: list[int]) -> str:
    if not ids:
        return ""            
    b = b"".join(self.vocab[i] for i in ids)
    return b.decode("utf-8", errors="replace")

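decode converts token ids back into a string. Using errors="replace" here is important: a byte-level (GPT-2-style) BPE vocabulary can contain arbitrary byte sequences, so a single token, or a sequence of ids, may not be valid UTF-8 on its own (e.g., a lone 0xC3, the first half of a two-byte character); replace avoids raising an error.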
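A small sketch of why the bytes are joined first and decoded once, using a made-up three-token vocabulary in which "é" (b"\xc3\xa9" in UTF-8) is split across two tokens:

```python
def decode(vocab: dict[int, bytes], ids: list[int]) -> str:
    # join all token bytes first, then decode once; invalid sequences become U+FFFD
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")


# Toy vocab (made up): "é" is b"\xc3\xa9" in UTF-8 and is split across tokens 1 and 2
vocab = {0: b"caf", 1: b"\xc3", 2: b"\xa9"}

print(decode(vocab, [0, 1, 2]) == "café")  # True: the two halves meet again
print(decode(vocab, [0, 1]) == "caf\ufffd")  # True: a dangling lead byte, no exception
print("".join(decode(vocab, [i]) for i in [0, 1, 2]) == "caf\ufffd\ufffd")  # True: per-token decoding loses "é"
```

Decoding token by token would turn each half of the character into a replacement character, which is why decode concatenates first.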

def _encode_plain(self, text: str) -> list[int]:
    """
    Encode a piece of text with no special tokens inside it.
    """
    out: list[int] = []
    for m in self.rx.finditer(text):
        piece = m.group(0)
        if not piece:
            continue
        piece_bytes = piece.encode("utf-8")
        for tok_bytes in self._bpe(piece_bytes):
            out.append(self.byte_to_id[tok_bytes])
    return out

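_encode_plain is what actually converts a string (containing no special tokens) into ids:

  • a pre-token is a string piece cut out by the regex
  • each piece is converted to bytes and run through BPE on its own (merges never cross pieces)
  • _bpe outputs a sequence of byte tokens, which are then mapped to ids
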
def _bpe(self, token_bytes: bytes) -> list[bytes]:
    """
    Apply BPE merges (by rank) on a single pre-token byte sequence.
    Returns a list of vocab byte tokens.        
    """
    cached = self._bpe_cache.get(token_bytes)
    if cached is not None:
        return cached
    
    # start from single bytes
    word: list[bytes] = [bytes([b]) for b in token_bytes]
    if len(word) <= 1:
        self._bpe_cache[token_bytes] = word
        return word
    
    while True:
        best_pair = None
        best_rank = None

        # find best ranked pair among adjacent pairs
        prev = word[0]
        for cur in word[1:]:
            p = (prev, cur)
            r = self.merge_ranks.get(p)
            if r is not None and (best_rank is None or r < best_rank):
                best_rank = r
                best_pair = p
            prev = cur
        
        if best_pair is None:
            break
    
        a, b = best_pair
        new_token = a + b

        # merge all occurrences of (a, b)
        merged: list[bytes] = []
        i = 0
        L = len(word)
        while i < L:
            if i < L - 1 and word[i] == a and word[i + 1] == b:
                merged.append(new_token)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = merged
        if len(word) <= 1:
            break
    
    self._bpe_cache[token_bytes] = word
    return word

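_bpe is the core step of encode: it applies BPE merges to the byte sequence of a single pre-token.

It first checks the cache and returns the cached result if there is one; otherwise it initializes token_bytes as a sequence of single bytes. It then loops: each round picks the pair with the lowest rank, merges every occurrence of that pair in the word, and stops when no mergeable pair remains.

The final word is a sequence of byte tokens, which is cached and returned.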
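The rank rule matters: merges are replayed in the order they were learned, not greedily from left to right. A standalone sketch of the same loop without the cache (bpe_by_rank is a hypothetical name, and the merge lists are made up):

```python
def bpe_by_rank(piece: bytes, merges: list[tuple[bytes, bytes]]) -> list[bytes]:
    """Apply merges to one pre-token, always merging the lowest-rank pair first."""
    ranks = {pair: r for r, pair in enumerate(merges)}
    word = [bytes([x]) for x in piece]
    while len(word) > 1:
        candidates = [ranks[p] for p in zip(word, word[1:]) if p in ranks]
        if not candidates:
            break
        a, b = merges[min(candidates)]
        merged, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and word[i] == a and word[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = merged
    return word


# (b, c) was learned before (a, b), so it wins even though (a, b) comes first in the text
print(bpe_by_rank(b"abc", [(b"b", b"c"), (b"a", b"b")]))  # [b'a', b'bc']
# every occurrence of the chosen pair is merged in the same round, left to right
print(bpe_by_rank(b"aaaa", [(b"a", b"a"), (b"aa", b"aa")]))  # [b'aaaa']
print(bpe_by_rank(b"aaa", [(b"a", b"a")]))  # [b'aa', b'a']
```

A left-to-right greedy scan would have produced [b"ab", b"c"] for the first example, which does not match how the merges were learned during training.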

OK, that's the full Tokenizer implementation.

7. Problem (tokenizer_experiments): Experiments with tokenizers (4 points)

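(a) Sample 10 documents each from TinyStories and OpenWebText. Using your previously trained TinyStories and OpenWebText tokenizers (10K and 32K vocabulary, respectively), encode these sampled documents into integer IDs.

What is each tokenizer's compression ratio (in bytes/token)?

Deliverable: The TinyStories tokenizer's compression ratio is about 4.0536 bytes/token, and the OpenWebText tokenizer's is about 4.5357 bytes/token.

(b) What happens if you tokenize your OpenWebText sample with the TinyStories tokenizer? Compare the compression ratio and qualitatively describe what you observe.

Deliverable: Tokenizing OpenWebText with the TinyStories tokenizer makes the segmentation noticeably more fragmented and the token count much larger: the compression ratio drops from 4.54 to 3.31 bytes/token. This is because the TinyStories training corpus is too simple to cover the more complex language of OpenWebText.

(c) Estimate the throughput of your tokenizer (e.g., in bytes/second). How long would it take to tokenize 82GB of text?

Deliverable: The OpenWebText tokenizer's throughput is about 3.85 MB/s, so tokenizing 82GB of text would take about 6.36 hours.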
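The estimate is simply size divided by throughput, and the result depends on the units. Reading 82GB as 82 × 1024³ bytes and 3.85 MB/s as 3.85 × 10⁶ bytes/s gives roughly the reported figure (up to rounding of the throughput); with decimal units throughout it comes out nearer 5.9 h, so "about 6 hours" is the safer summary:

```latex
t \approx \frac{82 \times 1024^{3}\ \mathrm{B}}{3.85 \times 10^{6}\ \mathrm{B/s}} \approx 2.29 \times 10^{4}\ \mathrm{s} \approx 6.35\ \mathrm{h},
\qquad
\frac{82 \times 10^{9}\ \mathrm{B}}{3.85 \times 10^{6}\ \mathrm{B/s}} \approx 2.13 \times 10^{4}\ \mathrm{s} \approx 5.9\ \mathrm{h}
```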

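(d) Using your trained TinyStories and OpenWebText tokenizers, encode the respective training and development sets into sequences of integer token IDs; we will use these later to train the language model. We recommend serializing the token IDs as NumPy arrays of dtype uint16. Why is uint16 an appropriate choice?

Deliverable: Our vocabulary sizes are 10K and 32K, so every token ID (at most 31,999) fits losslessly in uint16, whose range is 0 to 65535, and each token takes only 2 bytes, which saves both storage and I/O.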
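A minimal round trip illustrating the point, with made-up ids and a made-up file name (tokens.bin): the ids fit in uint16, each costs 2 bytes on disk, and the raw file can be mapped back with np.memmap instead of being read fully into RAM.

```python
import os
import tempfile

import numpy as np

ids = [0, 255, 9_999, 31_999]  # e.g. the largest ids of a 10K and a 32K vocabulary
assert max(ids) <= np.iinfo(np.uint16).max  # 65535, so no id is truncated

arr = np.asarray(ids, dtype=np.uint16)
print(arr.itemsize, arr.nbytes)  # 2 8

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tokens.bin")
    arr.tofile(path)  # raw bytes in native order, like fout.write(buf.tobytes())
    mm = np.memmap(path, dtype=np.uint16, mode="r")  # maps the file instead of reading it all
    roundtrip = mm.tolist()
    del mm  # drop the mapping before the temporary directory is deleted

print(roundtrip == ids)  # True
```

The same dtype must be used when reading the .bin files back, since the files store no header describing the element type.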

The experiment code:

import os
import time
import random
import numpy as np
from typing import Iterator, TextIO

from cs336_basics.tokenizer import Tokenizer

EOT = "<|endoftext|>"

def iter_docs_by_eot(f: TextIO, eot: str) -> Iterator[str]:
    buf = ""
    for chunk in f:
        buf += chunk
        while True:
            idx = buf.find(eot)
            if idx < 0:
                break
            doc = buf[:idx]
            buf = buf[idx + len(eot) :]
            if doc.strip():
                yield doc
    # Emit the tail as the last doc (if non-empty).
    if buf.strip():
        yield buf


def reservoir_sample_docs(path: str, eot: str, k: int, seed: int) -> list[str]:
    rnd = random.Random(seed)
    sample: list[str] = []
    n = 0
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        for doc in iter_docs_by_eot(f, eot):
            n += 1
            if len(sample) < k:
                sample.append(doc)
            else:
                j = rnd.randrange(n)
                if j < k:
                    sample[j] = doc
    return sample

def bytes_per_token(tokenizer: Tokenizer, docs: list[str]) -> tuple[float, int, int]:
    total_bytes = 0
    total_tokens = 0
    for d in docs:
        b = len(d.encode("utf-8"))
        ids = tokenizer.encode(d)
        total_bytes += b
        total_tokens += len(ids)
    bpt = total_bytes / max(1, total_tokens)
    return bpt, total_bytes, total_tokens

def iter_first_n_bytes_as_lines(path: str, nbytes: int) -> Iterator[str]:
    seen = 0
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if not line:
                break
            seen += len(line.encode("utf-8"))
            yield line
            if seen >= nbytes:
                break

def measure_throughput_bytes_per_sec(tokenizer: Tokenizer, it: Iterator[str], repeats: int = 1) -> float:
    lines = list(it)
    total_bytes = sum(len(x.encode("utf-8")) for x in lines)

    t0 = time.perf_counter()
    for _ in range(repeats):
        for _id in tokenizer.encode_iterable(lines):
            pass
    t1 = time.perf_counter()

    secs = (t1 - t0) / max(1, repeats)
    return total_bytes / max(1e-9, secs)

def encode_to_uint16_bin(tokenizer: Tokenizer, input_path: str, output_path: str, *, chunk_tokens: int = 1_000_000) -> None:
    """
    Stream-encode a large text file into uint16 token IDs and write them to a .bin file.

    This avoids holding the entire token sequence in RAM: token IDs are buffered in a
    fixed-size array and flushed to disk every `chunk_tokens` IDs. The output is raw
    native-endian uint16 with no header, so it can be read back with
    np.memmap(output_path, dtype=np.uint16, mode="r").
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    # Open the output once in write-binary mode; chunks are written sequentially as we flush
    with open(input_path, "r", encoding="utf-8", errors="ignore") as fin, open(output_path, "wb") as fout:
        buf = np.empty(chunk_tokens, dtype=np.uint16)
        n = 0

        for tid in tokenizer.encode_iterable(fin):
            # Safety check: uint16 can only hold IDs in [0, 65535]
            if tid < 0 or tid > 65535:
                raise ValueError(f"token id {tid} out of uint16 range")

            buf[n] = tid
            n += 1

            # Buffer full: write it out and start refilling from the beginning
            if n == chunk_tokens:
                fout.write(buf.tobytes())
                n = 0

        # Flush the partially filled tail buffer
        if n:
            fout.write(buf[:n].tobytes())

def main():
    # 1) paths to trained tokenizers
    tin_vocab_path = "workspace/tinystories_bpe_vocab.pkl"
    tin_merges_path = "workspace/tinystories_bpe_merges.pkl"

    owt_vocab_path = "workspace/owt_bpe_vocab_32000.pkl"
    owt_merges_path = "workspace/owt_bpe_merges_32000.pkl"

    # 2) paths to datasets
    tinystories_path = "data/TinyStoriesV2-GPT4-train.txt"
    owt_path = "data/owt_train.txt"

    # 3) load tokenizers
    tin_tok = Tokenizer.from_files(tin_vocab_path, tin_merges_path, special_tokens=[EOT])
    owt_tok = Tokenizer.from_files(owt_vocab_path, owt_merges_path, special_tokens=[EOT])

    # 4) sample 10 documents from each corpus WITHOUT loading the full file
    seed = 42
    print("Sampling TinyStories docs (streaming)...")
    tin_docs = reservoir_sample_docs(tinystories_path, EOT, k=10, seed=seed)
    print("Sampling OWT docs (streaming)...")
    owt_docs = reservoir_sample_docs(owt_path, EOT, k=10, seed=seed)

    # (a) In-domain compression efficiency
    tin_on_tin, tin_bytes, tin_tokens = bytes_per_token(tin_tok, tin_docs)
    owt_on_owt, owt_bytes, owt_tokens = bytes_per_token(owt_tok, owt_docs)

    # (b) Cross-domain compression efficiency
    tin_on_owt, x_bytes, x_tokens = bytes_per_token(tin_tok, owt_docs)

    print("\n=== (a) In-domain bytes/token ===")
    print(f"TinyStories tokenizer on TinyStories: {tin_on_tin:.4f} bytes/token "
          f"(bytes={tin_bytes}, tokens={tin_tokens})")
    print(f"OWT tokenizer on OWT:               {owt_on_owt:.4f} bytes/token "
          f"(bytes={owt_bytes}, tokens={owt_tokens})")

    print("\n=== (b) Cross-domain bytes/token ===")
    print(f"TinyStories tokenizer on OWT:       {tin_on_owt:.4f} bytes/token "
          f"(bytes={x_bytes}, tokens={x_tokens})")

    # (c) Throughput estimation
    # Use a small (~5 MB) prefix of the corpus to keep the measurement short
    sample_for_speed = iter_first_n_bytes_as_lines(owt_path, nbytes=5_000_000)  # ~5 MB
    thr = measure_throughput_bytes_per_sec(owt_tok, sample_for_speed, repeats=1)
    total_bytes_82gb = 82 * (1024 ** 3)  # note: binary units, i.e. 82 GiB
    est_secs = total_bytes_82gb / thr
    est_hours = est_secs / 3600

    print("\n=== (c) Throughput estimate ===")
    print(f"Measured throughput (OWT tokenizer): {thr:.2f} bytes/s")
    print(f"Estimated time for 82GB: {est_hours:.2f} hours")

    # (d) Serialize token IDs for LM training
    print("\n=== (d) Encoding datasets to uint16 ===")
    
    encode_to_uint16_bin(tin_tok, "data/TinyStoriesV2-GPT4-train.txt", "workspace/tinystories_train.uint16.bin")
    encode_to_uint16_bin(tin_tok, "data/TinyStoriesV2-GPT4-valid.txt", "workspace/tinystories_valid.uint16.bin")

    encode_to_uint16_bin(owt_tok, "data/owt_train.txt", "workspace/owt_train.uint16.bin")
    encode_to_uint16_bin(owt_tok, "data/owt_valid.txt", "workspace/owt_valid.uint16.bin")

    print("Done. Saved uint16 .bin files under workspace/.")

if __name__ == "__main__":
    main()

Running the script produces the following output:

[Figure: console output of the tokenizer experiments script]
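Part (c) comes down to one division: estimated hours = total bytes / throughput / 3600. The sketch below assumes a purely hypothetical throughput of 1 MB/s, not the value measured above. It also shows that `82 * (1024 ** 3)` in `main()` is 82 GiB, which is about 7.4% more bytes than 82 decimal GB. `estimate_hours` is a helper name introduced here for illustration only.

```python
def estimate_hours(total_bytes: int, bytes_per_sec: float) -> float:
    """Convert a corpus size and a measured throughput into wall-clock hours."""
    if bytes_per_sec <= 0:
        raise ValueError("throughput must be positive")
    return total_bytes / bytes_per_sec / 3600


# Hypothetical throughput of 1 MB/s, NOT a value measured by the script above
thr = 1_000_000.0
gib_82 = 82 * 1024**3  # binary units, as in main(): 88,046,829,568 bytes
gb_82 = 82 * 10**9     # decimal units: 82,000,000,000 bytes

print(f"82 GiB: {estimate_hours(gib_82, thr):.2f} h")  # 82 GiB: 24.46 h
print(f"82 GB:  {estimate_hours(gb_82, thr):.2f} h")   # 82 GB:  22.78 h
```

Because the estimate scales linearly, doubling the measured throughput halves the estimated time.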

OK, that completes the full implementation for this BPE Tokenizer assignment.
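`encode_to_uint16_bin` writes raw `uint16` values with no header. This means the files from (d) can be read back lazily with `np.memmap`, provided the dtype matches. Across machines, the byte order must match as well. The minimal sketch below assumes a training loop that samples fixed-length windows of token IDs. `load_token_ids` and `sample_batch` are hypothetical helpers and are not part of the assignment's starter code.

```python
import os
import tempfile

import numpy as np


def load_token_ids(path: str) -> np.memmap:
    """Map a raw uint16 .bin file (no header) into memory without reading it all."""
    return np.memmap(path, dtype=np.uint16, mode="r")


def sample_batch(ids: np.ndarray, batch_size: int, context_len: int, rng: np.random.Generator):
    """Hypothetical helper: draw random (input, target) windows shifted by one token."""
    # Highest valid start is len(ids) - context_len - 1, so the target window fits
    starts = rng.integers(0, len(ids) - context_len, size=batch_size)
    x = np.stack([ids[s : s + context_len] for s in starts]).astype(np.int64)
    y = np.stack([ids[s + 1 : s + 1 + context_len] for s in starts]).astype(np.int64)
    return x, y


# Round trip on a tiny file written the same way encode_to_uint16_bin does (buf.tobytes())
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "toy.uint16.bin")
toy = np.arange(100, dtype=np.uint16)
with open(path, "wb") as f:
    f.write(toy.tobytes())

ids = load_token_ids(path)
x, y = sample_batch(ids, batch_size=4, context_len=8, rng=np.random.default_rng(0))
print(ids.shape, x.shape, y.shape)  # (100,) (4, 8) (4, 8)
```

With memory mapping, only the slices that are actually indexed get paged in. This matters for the multi-GB OWT file.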

Conclusion

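In this article we fully implemented and verified the byte-level BPE Tokenizer required by CS336 Assignment 1. We covered the basics of Unicode/UTF-8 handling and the BPE training, encoding, and decoding pipeline. Guided by the actual data sizes, we also made targeted engineering optimizations for the performance bottlenecks in parallelized pre-tokenization and in the BPE merge stage.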

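Working through the implementation makes it clear that a BPE tokenizer is not a simple algorithm exercise. As the corpus grows from the example data to TinyStories and then to OpenWebText, several factors directly decide whether an implementation is actually usable: memory footprint, parallel granularity, process lifecycle management, and data structure design. Many implementations that look fine on a small dataset quickly break down at real scale.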
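The document sampling in `reservoir_sample_docs` is a small example of this. It streams the file and keeps only `k` documents in memory (reservoir sampling, Algorithm R) instead of loading all of OWT. The sketch below checks empirically why the update rule `j = rnd.randrange(n); if j < k` is correct: every item should end up in the sample with probability k/n. `reservoir_sample` is a standalone re-implementation of the same rule on a toy range. It is not the author's code.

```python
import random
from collections import Counter


def reservoir_sample(items, k: int, rnd: random.Random) -> list:
    """Algorithm R, same update rule as reservoir_sample_docs: keep item n with prob k/n."""
    sample = []
    for n, item in enumerate(items, start=1):
        if len(sample) < k:
            sample.append(item)          # fill the reservoir first
        else:
            j = rnd.randrange(n)         # uniform index in [0, n)
            if j < k:
                sample[j] = item         # replace a random slot with prob k/n
    return sample


rnd = random.Random(0)
trials, n_items, k = 20_000, 10, 3
counts = Counter()
for _ in range(trials):
    counts.update(reservoir_sample(range(n_items), k, rnd))

# Fraction of trials in which each item was selected
freqs = {i: counts[i] / trials for i in range(n_items)}
print(freqs)  # every value should be close to k / n_items = 0.3
```

Memory stays at O(k) no matter how many documents the stream contains. The cost is one random draw per document after the first `k`.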

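The assignment progressed step by step from a naive implementation to parallelized pre-tokenization, and then to a merge strategy based on incremental updates. This progression illustrates the core philosophy of CS336: implementing from scratch is not about reinventing the wheel. It is about understanding the real engineering and algorithmic constraints that modern language model systems must face.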

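With the tokenizer finished, we now have the lowest-level input representation tool needed to train a language model. The upcoming assignments move on to implementing and training a Transformer Language Model, including the attention mechanism, optimizers, the training loop, and model evaluation. Stay tuned🤗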

Source code download link

References
