前言

学习斯坦福 CS336 课程,本篇文章记录课程第二讲:pytorch 和资源核算,记录下个人学习笔记,仅供自己参考😄

website:https://stanford-cs336.github.io/spring2025

video:https://www.youtube.com/playlist?list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_

materials:https://github.com/stanford-cs336/spring2025-lectures

course material:https://stanford-cs336.github.io/spring2025-lectures/?trace=var/traces/lecture_02.json

1. Overview

上次讲座我们对语言模型进行了概述,还谈到了分词,这将是第一次作业的前半部分,今天的讲座将实际构建一个模型,我们将讨论 PyTorch 中所需的底层功能,我们将从张量开始,构建模型、优化器,并构建一个训练循环,我们也将密切关注效率,特别是如何利用资源,包括内存和计算

2. Motivating question

为了激发思考,这些有一些问题需要一起来探讨下,这些问题可以用快速估算来回答

Question 1:在 15 万亿个 token 上,使用 1024 张 H100 显卡,训练一个 700 亿参数的密集型 Transformer 模型需要多久?

首先计算训练所需的总浮点运算次数(FLOPs):

total_flops = 6 * 70e9 * 15e12  # @inspect total_flops

计算公式是:6 * 参数量 * token 数量,这个公式是怎么来的呢?这就是我们将在本次讲座中讨论的内容

接着你可以查看 H100 的每秒浮点运算次数(FLOP/s):

assert h100_flop_per_sec == 1979e12 / 2

MFU,这是我们稍后会讨论的内容,我们先将其设为 0.5:

mfu = 0.5

然后你可以查看在此特定 MFU 下,你的硬件每天所能提供的浮点运算次数,也就是 1024 张 H100 一天的性能:

flops_per_day = h100_flop_per_sec * mfu * 1024 * 60 * 60 * 24  # @inspect flops_per_day

接着你只需要用训练模型所需的总浮点运算次数除以你每天能获得的浮点运算次数就能估算训练时长,大约是 144 天:

days = total_flops / flops_per_day  # @inspect days

在这里插入图片描述

这是非常简单的计算,稍后我们将更详细地探讨这些数字的来源,特别是 6 * 参数量 * token 数量 是怎么得来的

Question 2:如果你不采取什么特别的优化手段,使用 AdamW 优化器,你能在 8 张 H100 上训练的最大模型是多大?

H100 有 80GB 的 HBM 显存:

h100_bytes = 80e9  # @inspect h100_bytes

每个参数需要 16 字节,包括权重、梯度和优化器状态,我们稍后会详细解释这个数字是怎么来的:

bytes_per_parameter = 4 + 4 + (4 + 4)  # parameters, gradients, optimizer state  @inspect bytes_per_parameter

参数数量基本上就是总显存量除以每个参数所需要的字节数,这样算下来大约是 400 亿参数:

num_parameters = (h100_bytes * 8) / bytes_per_parameter  # @inspect num_parameters

在这里插入图片描述

这只是一个非常粗略的计算,因为它没有考虑激活值,而激活值取决于 batch size 和 sequence length,这部分我们不会细讲,但这对于作业 1 会非常重要

这可能是大家不太习惯做的事情,大家通常只是实现一个模型、训练它,然后看结果,但请记住,效率是关键。要做到高效,你必须清楚地知道你实际消耗了多少资源,因为当这些数字很大时,它们直接就变成了成本,而你肯定希望这笔费用越低越好,所以我们后面会更详细地讨论这些数字是如何得出的

3. What knowledge to take away

我们不会在本次讲座详细讲解 transformer,下次讲座 Tatsu 会讲解 transformer 的概念性概述,如果你还没了解过 transformer,有很多途径可以学习:

如果你做了作业 1,就肯定知道 transformer 是什么了,网上还有很多资料可以查阅,大家可以自行学习

在这里,我们将从更简单的模型入手,讲解并掌握以下知识:

  • 实现机制:简单直接(仅使用PyTorch)
  • 思维模式:资源核算
  • 直觉:大体轮廓(不涉及大模型)

我们会重点讲解原语(primitive)和资源核算(resource accounting)部分。本讲的实现部分将只涉及 PyTorch,理解 PyTorch 在基础层面是如何工作的

思维模式部分是关于资源核算,这不难,只是你必须动手实践。而直觉部分,我们现在只能大致讲一下,实际上这部分内容并不多。本次讲座更多是关于实现机制和思维模式的讲解

4. Memory accounting

我们先从内存核算(memory accounting)开始讲,然后再讲计算核算(compute accounting),我们将自上而下地逐步讲解

4.1 Tensors basics

OK,最好的起点是张量,张量是深度学习中存储一切的基础构建块,包括:参数、梯度、优化器状态、数据、激活值等等,你可以阅读很多关于它们的文档资料 [PyTorch docs on tensors]

你可能对如何创建张量非常熟悉了,这里有很多种创建张量的方式:

x = torch.tensor([[1., 2, 3], [4, 5, 6]])  # @inspect x
x = torch.zeros(4, 8)  # 4x8 matrix of all zeros @inspect x
x = torch.ones(4, 8)  # 4x8 matrix of all ones @inspect x
x = torch.randn(4, 8)  # 4x8 matrix of iid Normal(0, 1) samples @inspect x

你可以创建张量但不初始化,如果需要,也可以对张量使用一些特殊的初始化方法:

x = torch.empty(4, 8)  # 4x8 matrix of uninitialized values @inspect x
nn.init.trunc_normal_(x, mean=0, std=1, a=-2, b=2)  # @inspect x

这是关于张量的一些讨论

4.2 Tensors memory

接下来我们来聊聊内存以及张量会占用多少内存,我们可能感兴趣的是每个张量都是浮点数,表示浮点数的方式有很多种,默认是 float32

float32

[Wikipedia]

在这里插入图片描述

float32 有 32 位,其中 1 位用于符号位,8 位用于指数位,23 位用于尾数位,如上图所示。指数位提供了动态范围,而尾数位提供了不同的数值精度。

float32 也被称为 FP32 或 单精度,这是计算领域的黄金标准,当然也有些人将 float32 称为全精度

我们来看看内存占用,内存占用计算非常简单,它取决于张量中元素的数量以及每个元素的数据类型:

def get_memory_usage(x: torch.Tensor):
    return x.numel() * x.element_size()

x = torch.zeros(4, 8)  # @inspect x
assert x.dtype == torch.float32  # Default type
assert x.numel() == 4 * 8
assert x.element_size() == 4  # Float is 4 bytes
assert get_memory_usage(x) == 4 * 8 * 4  # 128 bytes

如果你创建一个 4x8 矩阵的 torch 张量,默认情况下,它会给你一个 float32 类型。矩阵大小是 4x8,元素数量是 32,每个元素大小是 4 个字节,内存用量就是元素个数乘以每个元素的大小,这样你就会得到 128 字节

为了给大家一点直观感受,我们拿 GPT-3 前馈层中的一个矩阵来举例:

assert get_memory_usage(torch.empty(12288 * 4, 12288)) == 2304 * 1024 * 1024  # 2.3 GB

结果是 2.3GB 大小,这还只是一个矩阵,所以如果我们采用默认的 float32 类型的话,这些矩阵占用的内存会变得非常大

所以很自然地,你会想把它们弄小点,以便使用更少的内存,而且事实证明如果你把它们弄小点,它们运行起来也会很快

所以另一种表示方式叫做 float16

float16

[Wikipedia]

在这里插入图片描述

正如它的名字所示,它是 16 位的,其中指数部分从 8 位缩减到 5 位,小数部分从 23 位缩减到 10 位,它也被称为 半精度,它能把内存用量减少一半

x = torch.zeros(4, 8, dtype=torch.float16)  # @inspect x
assert x.element_size() == 2

这都很好,只是 float16 的动态范围不太理想:

x = torch.tensor([1e-8], dtype=torch.float16)  # @inspect x
assert x == 0  # Underflow!

举个例子,如果你试图用 float16 来表示 1e-8 这样的数,它基本上会发生下溢,四舍五入变为 0,所以 float16 不太适合表示非常小的数,事实上,也不适合表示非常大的数

所以如果你用 float16 训练小模型可能问题不大,但对于大模型,当你有很多矩阵时就可能出现不稳定、下溢或上溢,然后就会出问题

不过有一件比较好的事情,那就是出现了另一种浮点数表示方式叫做 bfloat16(Brain Floating Point)

bfloat16

[Wikipedia]

在这里插入图片描述

bfloat16 诞生于 2018 年,由 Google 开发,目的是解决深度学习中的一个问题。我们实际上更看重动态范围,而不是这个小数部分的精度,所以 bfloat16 给指数部分分配了更多位,而给小数部分分配了更少位

它具有与 floa16 相同的内存占用,且拥有 FP32 的动态范围,听起来非常不错,但实际上,问题在于 resolution 它由小数部分决定,表现会比较差,但这对于深度学习来说没那么重要

现在,如果你尝试用 1e-8 和 BF16 创建一个张量,那么你会得到一个非零的值:

x = torch.tensor([1e-8], dtype=torch.bfloat16)  # @inspect x
assert x != 0  # No underflow!

在这里插入图片描述

你可以深入了解更多的细节,查看所有不同浮点运算的实际完整规格:

float32_info = torch.finfo(torch.float32)  # @inspect float32_info
float16_info = torch.finfo(torch.float16)  # @inspect float16_info
bfloat16_info = torch.finfo(torch.bfloat16)  # @inspect bfloat16_info

在这里插入图片描述

所以 BF16 基本上是你通常会用来进行计算的,因为它对于计算来说已经足够好了。但结果发现,对于存储优化器状态和权重参数,你仍然需要 FP32,否则你的训练会变得不稳定

现在我们有了一种叫做 FP8 的格式

fp8

https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html

在这里插入图片描述

这是由 NVIDIA 在 2022 年开发的,如上图所示,它有两种变体,取决于你想要更高的分辨率还是更大的动态范围,FP8 在 H100 上支持,但它在上一代产品上并不真正可用 [Micikevicius+ 2022]

从宏观上看,使用 FP32 进行训练是我们通常会做的,它相对稳定,但它需要更多内存。你也可以使用 FP8 或 BF16,但可能会出现一些数值不稳定的情形,总的来说,目前你大概率不会想在深度学习中使用 float16

你可以通过查看你 pipeline 中的特定环节来变得更加精细化,无论是前向传播、反向传播、优化器还是梯度累积,并真正弄清楚在这些特定环节你需要多低的精度,这就涉及到了混合精度训练

例如,有些人喜欢对注意力机制使用 float32 以确保它不会出错,而对于涉及矩阵乘法的简单前向传播,使用 BF16 是没问题的

5. Compute accounting

OK,刚才讲的是内存,现在我们来谈谈计算

5.1 tensors on gpus

计算能力显然取决于你的硬件是什么,默认情况下,张量是存储在 CPU 里面的:

x = torch.zeros(32, 32)
assert x.device == torch.device("cpu")

例如,如果你只是在 PyTorch 里面写 x = torch.zeros(32, 32),那么它就会被放到你的 CPU 上,存储在你的 CPU 内存里。当然这不太好,因为如果你不使用 GPU 的话,你的速度会慢好几个数量级

所以你需要在 PyTorch 中明确指定,要把数据移到 GPU 上:

在这里插入图片描述

上图展示了这一过程,图示左边是一个 CPU,它有内存(RAM),数据必须从这里移动到右边的 GPU 上,数据传输是有成本的,需要花费一些时间

所以无论何时你在 PyTorch 中定义一个张量时,你应该始终记住:它当前位于哪里?是位于 CPU 上还是 GPU 上?因为仅仅看变量名或代码本身,你并不总是能确定,如果你想谨慎地处理计算和数据移动,你就真的需要知道它在哪里

你可以做一些事情,比如在代码的各个位置上通过断言(assert)来检查它在哪里,确保它在正确的位置,符合你的预期

我们可以通过如下代码来获取你的 GPU 硬件信息:

import torch

num_gpus = torch.cuda.device_count()  # @inspect num_gpus
for i in range(num_gpus):
    properties = torch.cuda.get_device_properties(i)  # @inspect properties

memory_allocated = torch.cuda.memory_allocated()  # @inspect memory_allocated

在这里插入图片描述

从上图可知,这块 GPU 有 8GB 的高带宽内存,还提供了缓存大小等信息

请记住我们默认创建的张量 x 是在 CPU 上的,你可以通过 to() 这个 PyTorch 的通用函数来移动它:

y = x.to("cuda:0")
assert y.device == torch.device("cuda", 0)

你也可以直接在 GPU 上创建一个张量,这样你就完全不需要移动它了:

z = torch.zeros(32, 32, device="cuda:0")

如果一切顺利的话,我们还可以查看分配显存前后的情况:

new_memory_allocated = torch.cuda.memory_allocated()  # @inspect new_memory_allocated
memory_used = new_memory_allocated - memory_allocated  # @inspect memory_used
assert memory_used == 2 * (32 * 32 * 4)  # 2 32x32 matrices of 4-byte floats

在这里插入图片描述

前后差异应该正好是两个 32x32 的浮点矩阵,矩阵中的每个元素占 4 个字节,所以结果是 8192

5.2 tensor operations

OK,现在你的张量已经在 GPU 上了,接下来做什么呢?接下来很多操作是你完成第一次作业时会需要的,并且在深度学习应用中非常普遍

大多数张量都是通过对其他张量执行操作来创建的,每个操作都会产生一定的内存和计算开销,我们要确保理解这一点

5.2.1 tensor storage

那么首先来思考张量在 PyTorch 中到底是什么呢?张量是一个数学对象,而在 PyTorch 中,它们实际上是指向已分配内存的一些指针 [PyTorch docs]

在这里插入图片描述

x = torch.tensor([
    [0., 1, 2, 3],
    [4, 5, 6, 7],
    [8, 9, 10, 11],
    [12, 13, 14, 15],
])

如上图所示,假设你有一个 4x4 的矩阵,它在内存中实际看起来像是一个一维长数组,而张量对象本身拥有的是元数据,这些元数据指定了如何在数组中找到特定的地址

在当前示例中这些元数据是两个数字(两个步长),每个维度一个步长(stride),或者说,张量每个维度对应一个数字。在本例中,因为有两个维度所以对应步长 0 和步长 1 即 strides[0]strides[1]

assert x.stride(0) == 4

strides[0] 指的是,如果你在维度 0 要移动到下一行,也就是增加该维度的索引,你需要跳过 4 个元素,所以步长 0 即 strides[0] 是 4

assert x.stride(1) == 1

而要移动到下一列,你仅需要跳过 1 个元素,也就是步长 1 即 strides[1] 是 1

r, c = 1, 2
index = r * x.stride(0) + c * x.stride(1)  # @inspect index
assert index == 6

有了这些信息,要找到一个元素,例如索引为 (1,2) 的元素,只需要将各维度索引与对应步长相乘并相加就能得到它在数组中的索引,这里是 6

这基本上就是张量的底层机制,这点非常重要,因为你可以让多个张量共享同一块存储空间

5.2.2 tensor slicing

假设现在你有一个 2 行 3 列的矩阵,如下所示:

x = torch.tensor([[1., 2, 3], [4, 5, 6]])  # @inspect x

张量的很多操作实际上并不会创建一个新的张量,它们只会创建一个不同的视图,而不会进行复制,所以你必须确保你了解张量修改带来的影响

如果你开始修改一个张量,这会导致另一个张量也发生修改,举个例子:

def same_storage(x: torch.Tensor, y: torch.Tensor):
    return x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()

y = x[0]  # @inspect y
assert torch.equal(y, torch.tensor([1., 2, 3]))
assert same_storage(x, y)

如果你只获取 x 的第 1 行,并将它赋给张量 y,并通过 same_storage 函数来判断查看底层存储的两个张量是否共享同一块内存空间。答案是同一块,所以 x[0] 这个操作可能不会复制张量,它只是创建了一个视图

y = x[:, 1]  # @inspect y
assert torch.equal(y, torch.tensor([2, 5]))
assert same_storage(x, y)

同样你也可以获取第 1 列,这也不会复制张量

y = x.view(3, 2)  # @inspect y
assert torch.equal(y, torch.tensor([[1, 2], [3, 4], [5, 6]]))
assert same_storage(x, y)

你还可以调用 view 函数,它可以接受任何张量,并根据不同的维度来查看它,一个 2 行 3 列的张量可以被视为一个 3 行 2 列的张量,它也不会改变数据或进行任何复制

y = x.transpose(1, 0)  # @inspect y
assert torch.equal(y, torch.tensor([[1, 4], [2, 5], [3, 6]]))
assert same_storage(x, y)

你可以进行转置操作,这也不会复制数据

x[0][0] = 100  # @inspect x, @inspect y
assert y[0][0] == 100

就像我们之前说的,如果你开始修改 x,那么 y 实际上也会随之被修改,因为 xy 都是指向同一块底层存储空间的指针

x = torch.tensor([[1., 2, 3], [4, 5, 6]])  # @inspect x
y = x.transpose(1, 0)  # @inspect y
assert not y.is_contiguous()
try:
    y.view(2, 3)
    assert False
except RuntimeError as e:
    assert "view size is not compatible with input tensor's size and stride" in str(e)

需要注意的是,有些视图是连续的,这意味着如果你遍历这个张量就像只是平滑地遍历存储中的一维数组一样。而有些视图则不是连续的,特别是如果你对它进行转置(仅仅交换访问的步长,并没有真正在数据层面进行转置)后,如果你遍历这个张量会发现它其实是跳跃访问的(转置后原先的 strides[0] 变成了 strides[1]

如果你有一个非连续的张量 y,并且你试图以不同方式进一步创建其视图,这是行不通的

y = x.transpose(1, 0).contiguous().view(2, 3)  # @inspect y
assert not same_storage(x, y)
text("Views are free, copying take both (additional) memory and compute.")

所以在某些情况下,如果你有一个非连续的张量,你可以先把它变成连续的,然后你就可以应用任何你想要的视图操作了。在这种情况下,xy 将不会再具有相同的存储空间,因为在这种情况下,让它连续的这个操作(.contiguous())会创建一个新的副本

这是关于 tensor slice 切片相关的内容,要注意的是张量视图并不是真正的一个新副本,所以尽管使用它们,定义不同的视图变量可以让你的代码可读性更好,并且它们不会分配任何内存,但请记住 contiguous 或 reshape 等操作可能会创建一个新的副本

5.2.3 tensor elementwise

下面有一些操作会创建新的张量:

x = torch.tensor([1, 4, 9])
assert torch.equal(x.pow(2), torch.tensor([1, 16, 81]))
assert torch.equal(x.sqrt(), torch.tensor([1, 2, 3]))
assert torch.equal(x.rsqrt(), torch.tensor([1, 1 / 2, 1 / 3]))  # i -> 1/sqrt(x_i)

assert torch.equal(x + x, torch.tensor([2, 8, 18]))
assert torch.equal(x * 2, torch.tensor([2, 8, 18]))
assert torch.equal(x / 0.5, torch.tensor([2, 8, 18]))

特别是逐元素操作(element-wise)显然都会创建新的张量,因为你需要找个别的地方来存放这个新值

x = torch.ones(3, 3).triu()  # @inspect x
assert torch.equal(x, torch.tensor([
    [1, 1, 1],
    [0, 1, 1],
    [0, 0, 1]],
))

当你想创建一个注意力掩码时上面的三角运算会派上用场,这个在你的作业中会用到

5.2.4 tensor matmul

OK,接下来我们来聊聊矩阵乘法,深度学习的核心就是矩阵乘法

x = torch.ones(16, 32)
w = torch.ones(32, 2)
y = x @ w
assert y.size() == torch.Size([16, 2])

当你拿一个 16x32 的矩阵乘以一个 32x2 的矩阵时,你会得到一个 16x2 的矩阵

在这里插入图片描述

但通常来说,当我们在做机器学习应用时,所有操作你都会想在一个 batch 中完成,对于语言模型来说,这通常意味着对于 batch 中的每个序列,你都会想做点什么

x = torch.ones(4, 8, 16, 32)
w = torch.ones(32, 2)
y = x @ w
assert y.size() == torch.Size([4, 8, 16, 2])

一般来说,你不会只有一个矩阵,而是会有一个张量(x),它维度通常是 batch,sequence,然后是你想要处理的任何内容。在这种情况下,w 是你数据集中每个 token 对应的一个矩阵。当你拿这个四维张量和这个矩阵相乘时,实际发生的情况是对于每个批次、每次序列、每个 token,你都在进行两个矩阵的乘法,结果是你会为前两个维度中的每一个得到一个结果矩阵

5.3 tensor einops

接下来稍微跑下题,讲讲 einops

5.3.1 einops motivation

einops 的引入动机如下:

x = torch.ones(2, 2, 3)  # batch, sequence, hidden  @inspect x
y = torch.ones(2, 2, 3)  # batch, sequence, hidden  @inspect y
z = x @ y.transpose(-2, -1)  # batch, sequence, sequence  @inspect z

在 PyTorch 中,你会定义一些张量,然后看到上面这样的操作,z 的计算是 xy.transpose(-2, -1)batch matmul,也就是对每个 batch 做 2x3 @ 3x2,最终得到的 shape 是 2x2x2

其中 y.transpose(-2, -1)- 表示反向索引,但这真的很容易弄错,如果你做得好,会写很多注释,但是这样一来注释可能会和代码脱节,然后在调试时遇到麻烦,所以这里的解决方案是使用 Einops

这是受到爱因斯坦求和标记法的启发,而这个想法是,在操作张量时我们只需命名张量的所有维度,而不像以前那样本质上只依赖于索引 [Einops tutorial]

5.3.2 jaxtyping basics

有一个叫做 Jaxtyping 的库它能在类型中指定维度:

x = torch.ones(2, 2, 1, 3)  # batch seq heads hidden  @inspect x

通常在 pytorch 中,你只需要写你的代码,然后注释 x 的维度是 [batch, seq, heads, hidden]

x: Float[torch.Tensor, "batch seq heads hidden"] = torch.ones(2, 2, 1, 3)  # @inspect x

但如果你使用 jaxtyping 就会有上面这种表示法,你用字符串的形式写下维度 "batch seq heads hidden",这是一种稍微更自然的文档记录方式

5.3.3 einops einsum

Einsum 基本上就是加了 “料” 的矩阵乘法,它的可读性更好,这里有个例子:

x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones(2, 3, 4)  # @inspect x
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones(2, 3, 4)  # @inspect y

我们定义了两个张量,每个张量有 3 个维度,分别代表 batch、seq、hidden,我们原来是通过如下的方式进行的 batch matmul:

z = x @ y.transpose(-2, -1)  # batch, sequence, sequence  @inspect z

现在采取的替代方法是:

z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")  # @inspect z

你可以先写下这两个张量各个维度的名称,例如 batch seq1 hidden, batch seq2 hidden,然后写下哪些维度应该出现在输出结果中,例如 batch seq1 seq2,注意这里并没有写隐藏层维度 hidden,任何未在输出结果中命名的维度都会被求和

一旦你习惯了这种方法,它会非常有用,如果你是第一次看到这种方法,它可能看起来比较奇怪,但相信我,一旦你习惯了它,会比 (-2, -1) 这种方式更好

如果你更灵活,还可以使用 ... 来表示对任意数量的维度进行广播:

z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")  # @inspect z

在这种情况下,我们可以用 ... 代替写 batch 维度

5.3.4 einops reduce

我们接着来看看 reduce 操作

reduce 操作只作用于一个张量,它会对张量的一个或多个维度进行聚合:

x: Float[torch.Tensor, "batch seq hidden"] = torch.ones(2, 3, 4)  # @inspect x

如果你需要对张量 x 的最后一个维度求和,以前的方式你可能会使用 mean

y = x.mean(dim=-1)  # @inspect y

现在我们可以使用 reduce

y = reduce(x, "... hidden -> ...", "sum")  # @inspect y

在这里插入图片描述

可以看到前后维度的变化,hidden 维度消失了,这意味着你正在聚合那个维度。你也可以检查下 reduce 在这里确实是有效的

5.3.5 einops rearrange

OK,关于 einops 的最后一个例子可能是,有时在一个张量中,一个维度实际上代表多个维度,并且你想解包它,然后操作其中一个,再把它打包回去

x: Float[torch.Tensor, "batch seq total_hidden"] = torch.ones(2, 3, 8)  # @inspect x

在这种情况下,假设你有 batch、seq 和 8 维的 total_hidden 三个维度,而 8 维度的 total_hidden 向量实际上是一个 heads * hidden1 即头数乘以某个隐藏层维度的展平表示

然后你有一个向量需要对那个隐藏维度进行操作:

w: Float[torch.Tensor, "hidden1 hidden2"] = torch.ones(4, 4)

你可以使用 einops 非常优雅地做到这一点,通过调用 rearrange 函数:

x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)  # @inspect x

你可以回想一下我们之前看过 view,它有点像 view,只不过是一个更花哨的版本,它基本上访问的是相同的数据,但方式不同。在这里实际上要把 (heads hidden1) 这个维度给它分解成两个维度,并且你必须在这里指定头数,因为有多种方式可以将一个数字分成两个

x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")  # @inspect x
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")  # @inspect x

对于给定的 x,你可以使用 einsum 执行你的转换,... hidden 1 对应于 xhidden1 hidden2 对应于 w。并且你还可以 rearrange 回来,这正是分解操作的逆过程,也就是说你有两个维度现在需要把它们组合成一个,这就是一个展平操作,在确保其他维度保持不变的情况下进行的

5.4 tensor operation flops

现在我们来谈谈张量操作的计算成本,我们上面介绍了一系列张量操作,它们的成本是多少呢?

张量操作本质上是浮点操作,而浮点操作(floating-point operation,FLOP)是指任何涉及浮点数的运算,比如加法或乘法,而这些是在计算 FLOP 数量时会主要考虑的

当你说 flops 时,其实不清楚你具体指什么,你可能指带有小写字母 s 的 FLOPs,它代表 浮点操作次数,衡量的是你完成的计算量。或者你可能指的是带有大写字母 S 的 FLOPS,它代表 每秒浮点运算次数,衡量的是硬件的速度,在这门课上我们不会使用带有大写字母 S 的写法,因为那样会令人困惑,而是只写 /s 表示每秒浮点运算次数

为了让大家对 FLOPs 有个直观了解,下面列举了一些常见模型的 FLOPs:

而 A100 的峰值性能为 312 teraFLOP/s [spec]

assert a100_flop_per_sec == 312e12

而 H100 的峰值性能带稀疏性时为 1979 teraFLOP/s [spec],不带稀疏性时为 50%

assert h100_flop_per_sec == 1979e12 / 2

如果你查看 NVIDIA H100 的规格书时你会发现,如果你使用 FP32,性能实际上非常差,它比你使用 FP16 要差好几个数量级,如果你愿意降到 FP8,那么速度会更快:

在这里插入图片描述

注意规格书的左下角有个说明,*With Sparsity 意味着带有稀疏性,通常我们在这门课里讲的很多矩阵都是稠密的,而对于稠密矩阵而言你只能得到其一半的性能

所以你可以进行计算了,8 个 H100 用两周的 FLOPs 是多少呢?

total_flops = 8 * (60 * 60 * 24 * 7) * h100_flop_per_sec  # @inspect total_flops

在这里插入图片描述

这是一周的浮点运算次数,大概是 4.788e21,你可以用其他模型的计算量来衡量 FLOPs 的数量,比如使用 8 块 H100 训练一个 GPT3 大概需要 3.14e23 / 4.788e21 ≈ 65.6 个周,也就是 459 天的样子

我们来看一个简单的例子,我们不会涉及 Transformer 模型,但即使是线性模型(Linear Model)也能为我们提供许多基础 block 构建方法和直观理解

假设我们有 B 个点,每个点的维度是 D 维,而线性模型只是将每个 D 维向量映射成一个 K 维向量,我们来设定一下点数 B、维度 D 以及输出的维度 K

if torch.cuda.is_available():
    B = 16384  # Number of points
    D = 32768  # Dimension
    K = 8192   # Number of outputs
else:
    B = 1024
    D = 256
    K = 64

然后创建我们的数据矩阵 X 和权重矩阵 W,并计算得到 Y

device = get_device()
x = torch.ones(B, D, device=device)
w = torch.randn(D, K, device=device)
y = x @ w

线性模型本质上就是一种映射,所以这里没什么太复杂的操作,问题来了,这总共需要多少次浮点运算(FLOPs)呢?

要计算这个,你可以这样想:当你进行矩阵乘法时基本上对于每一个 (i,j,k) 组合,你都需要将 x[i][j]w[j][k] 两个数相乘,并且还需要将乘积加到总和中,所以实际的 FLOPs 总数为:

actual_num_flops = 2 * B * D * K  # @inspect actual_num_flops

也就是 2 乘以所有相关维度的乘积,所以如果你在进行矩阵乘法,这一点你应该记住:浮点运算次数(FLOPs)是三个维度乘积的 2 倍

Note:在 2x2 和 2x3 的矩阵乘法运算中,每个结果元素(6 个元素)需要进行 2 次乘法运算和 1 次求和运算(3 次 FLOP),因此需要 18 次 FLOP,而不是 24 次(2x2x3 的 2 倍),因此正确的浮点运算次数应该等于 (2D - 1) * B * K

其他操作的浮点运算次数通常与矩阵或张量的大小呈线性关系,总的来说,在深度学习中,对于足够大的矩阵,你遇到的其他任何操作都不如矩阵乘法耗时,这就是为什么我们主要只关注模型执行的矩阵乘法。

当然,在某些情况下,如果你的矩阵足够小,那么其他操作的成本就会开始占据主导地位,但这并不是你想要处于的一种有利状态,因为硬件是为大型矩阵乘法量身定制的。这有点像是在兜圈子,但在本次课程中我们只考虑那些矩阵乘法是主要计算开销的模型

对于上面这个特定的线性模型,前向传播所需的浮点运算次数(FLOPs)是 2 * B * D * K,实际上这可以推广到 Transformer 模型,但还需要考虑序列长度和其他因素,如果你的序列长度不太大,这个估计大致是准确的

这仅仅是浮点运算的次数,那实际运行的时间是多少呢?🤔 这大概才是你真正关心的问题,那么我们来测一下时间:

def time_matmul(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the number of seconds required to perform `a @ b`."""

    # Wait until previous CUDA threads are done
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    def run():
        # Perform the operation
        a @ b

        # Wait until CUDA threads are done
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Time the operation `num_trials` times
    num_trials = 5
    total_time = timeit.timeit(run, number=num_trials)

    return total_time / num_trials

actual_time = time_matmul(x, w)  # @inspect actual_time
actual_flop_per_sec = actual_num_flops / actual_time  # @inspect actual_flop_per_sec

在这里插入图片描述

这里有一个 time_matmul 的函数,它会运行 5 次矩阵乘法操作,并统计这 5 次的平均耗时,现在我们可以得到 1 次矩阵乘法实际运行时间了,也就是 actual_time,这个矩阵乘法操作花了 0.0028s 以及实际的每秒浮点运算次数(FLOP/s)是 1.17e10

def get_promised_flop_per_sec(device: str, dtype: torch.dtype) -> float:
    """Return the peak FLOP/s for `device` operating on `dtype`."""
    if not torch.cuda.is_available():
        text("No CUDA device available, so can't get FLOP/s.")
        return 1
    properties = torch.cuda.get_device_properties(device)

    if "A100" in properties.name:
        # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf")
        if dtype == torch.float32:
            return 19.5e12
        if dtype in (torch.bfloat16, torch.float16):
            return 312e12
        raise ValueError(f"Unknown dtype: {dtype}")

    if "H100" in properties.name:
        # https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet")
        if dtype == torch.float32:
            return 67.5e12
        if dtype in (torch.bfloat16, torch.float16):
            return 1979e12 / 2  # 1979 is for sparse, dense is half of that
        raise ValueError(f"Unknown dtype: {dtype}")

    raise ValueError(f"Unknown device: {device}")

promised_flop_per_sec = get_promised_flop_per_sec(device, x.dtype)  # @inspect promised_flop_per_sec

现在我们可以将其与 A100 和 H100 的规格书进行比较。我们查看规格书可以知道 FLOP/s 取决于数据类型,对于 H100 来说,我们看到的承诺的每秒浮点运算次数是 67.5e12,这是针对 float32 数据类型,正如我们之前看到的一样

额外提一句,有一个非常有用的概念叫做 模型浮点运算利用率(Model Flops Utilization,MFU),也就是实际浮点次数除以承诺的浮点运算次数

mfu = actual_flop_per_sec / promised_flop_per_sec  # @inspect mfu

所以你会听到人们谈论他们的 MFU,而 MFU 大于 0.5 通常被认为是好的,假设你的 MFU 只有 0.05,那情况就会非常糟糕。通常 MFU 很难接近 90% 或 100%,因为它在某种程度上忽略了各种通信和开销,它只是字面意义上的浮点计算。如果矩阵乘法占主导地位,MFU 通常会高得多

你也可以用 BF16 操作进行同样的操作:

x = x.to(torch.bfloat16)
w = w.to(torch.bfloat16)
bf16_actual_time = time_matmul(x, w)  # @inspect bf16_actual_time
bf16_actual_flop_per_sec = actual_num_flops / bf16_actual_time  # @inspect bf16_actual_flop_per_sec
bf16_promised_flop_per_sec = get_promised_flop_per_sec(device, x.dtype)  # @inspect bf16_promised_flop_per_sec
bf16_mfu = bf16_actual_flop_per_sec / bf16_promised_flop_per_sec  # @inspect bf16_mfu

所以总结一下:

  • 矩阵乘法在计算中占主导地位,通常的经验法则是:浮点运算次数是维度乘积的两倍
  • 每秒浮点运算次数(FLOP/s)取决于你的硬件以及数据类型,你拥有的硬件越好,这个数值就越高,数据类型越小,通常速度就越快
  • MFU 是一个有用的概念,可以用来衡量你如何有效地榨取你的硬件性能

5.5 gradient basics

下面我们来聊聊梯度,我们之前只关注了矩阵乘法,换句话说,基本上是前向传播的浮点运算次数(FLOPs),但还有一部分计算量来自于计算梯度,我们想弄清楚这部分 FLOPs 有多少

那我们来看一个简单的例子,一个简单的线性模型:y = 0.5 (x * w - 5)^2,你取一个线性模型的预测值,然后计算它与 5 的均方差(MSE)。这并不是一个很有趣的损失函数,但它对于理解梯度计算很有启发性

x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)  # Want gradient
pred_y = x @ w
loss = 0.5 * (pred_y - 5).pow(2)

在前向传播时,你有输入 x 和权重 w,你想计算相对于它们的梯度。通过线性乘积得到预测结果 pred_y,然后计算损失

loss.backward()
assert loss.grad is None
assert pred_y.grad is None
assert x.grad is None
assert torch.equal(w.grad, torch.tensor([1, 2, 3]))

在反向传播时,你只需要调用 loss.backward() 方法,在这种情况下,梯度,也就是附着在各个张量上的变量(.grad)就是你想要的

5.6 gradient flops

大家以前都在 PyTorch 中计算过梯度了,那我们来看看计算梯度需要多少浮点运算次数

我们来看一个稍微复杂一点的模型:

if torch.cuda.is_available():
    B = 16384  # Number of points
    D = 32768  # Dimension
    K = 8192   # Number of outputs
else:
    B = 1024
    D = 256
    K = 64
device = get_device()
x = torch.ones(B, D, device=device)
w1 = torch.randn(D, D, device=device, requires_grad=True)
w2 = torch.randn(D, K, device=device, requires_grad=True)

h1 = x @ w1
h2 = h1 @ w2
loss = h2.pow(2).mean()

现在是一个两层线性模型,其中有输入 x,维度是 (B, D),乘以 w1,维度是 D, D,这是第一层。然后你取隐藏层的激活值 h1,再通过另一个线性层 w2 得到一个 K 维向量,然后计算损失。总的数据流是 x --w1--> h1 --w2--> h2 -> loss

回顾一下我们在前向传播时的计算量,你需要做的是:

  • Multiply x[i][j] * w1[j][k]
  • Add to h1[i][k]
  • Multiply h1[i][j] * w2[j][k]
  • Add to h2[i][k]

所以总的计算量依旧是 2 乘以所有维度的乘积:

num_forward_flops = (2 * B * D * D) + (2 * B * D * K)  # @inspect num_forward_flops

换句话说,在前向传播这种情况下 FLOPs 是总参数量的两倍

那反向传播呢?这部分会稍微复杂一些,在反向传播中,你需要计算很多梯度,包括:

  • h1.grad = d loss / d h1
  • h2.grad = d loss / d h2
  • w1.grad = d loss / d w1
  • w2.grad = d loss / d w2

也就是说,损失函数 d loss 关于 h1h2w1w2 这些变量的每一个的梯度你都需要计算,那么计算这些需要多少时间呢?

num_backward_flops = 0  # @inspect num_backward_flops

我们现在只看 w2,所有与 w2 相关的计算,你都可以通过链式法则来完成:w2.grad[j,k] = sum_i h1[i,j] * h2.grad[i,k],所以 d loss / d w2 的梯度就是你将 h1 求和乘以损失函数关于 h2 的梯度,这只是针对 w2 的链式法则

assert w2.grad.size() == torch.Size([D, K])
assert h1.size() == torch.Size([B, D])
assert h2.grad.size() == torch.Size([B, K])

而且所有梯度的大小都相同,与对应的底层向量一样,所以这计算起来本质上看起来像一个矩阵乘法,因此同样的方法也适用,也就是说,它是所有维度的乘积的两倍:

num_backward_flops += 2 * B * D * K  # @inspect num_backward_flops

这只是关于 w2 的梯度,我们还需要计算相对于 h1 的梯度,因为我们必须一直反向传播到 w1 才行,h1 的梯度会是 h2 的梯度乘以 w2:h1.grad[i,j] = sum_k h2.grad[i,k] * w2[j,k]

assert h1.grad.size() == torch.Size([B, D])
assert w2.size() == torch.Size([D, K])
assert h2.grad.size() == torch.Size([B, K])
num_backward_flops += 2 * B * D * K  # @inspect num_backward_flops

所以结果基本上也看起来像矩阵乘法,而且计算每个梯度的浮点运算次数是一样的,当你把二者相加时,这只是为了计算 w2

你对 w1 也做同样的事情,而它有 D * D 个参数,当你把它们全部加起来时,对于 w2,计算量是 4 * B * D * K,对于 w1,它是 4 * B * D * D,因为 w1 的维度是 D * D:

num_backward_flops += (2 + 2) * B * D * D  # @inspect num_backward_flops

我们来看一下图示解释,它来自于一篇博客文章 [article]

在这里插入图片描述

上图演示的是线性层在反向传播时两类梯度(权重、输入)的来源和代价,关于具体的细节,大家可以自己琢磨下

从更高层面来说,前向传播的计算量是参数数量的 2 倍,反向传播的计算量是参数数量的 4 倍,我们可以通过上面的链式法则轻松计算出来

简单总结一下,对于上面这个特定的两层线性模型:

  • 前向传播的计算量是数据点数量乘以参数数量的 2 倍
  • 反向传播的计算量则是数据点数量乘以参数数量的 4 倍
  • 总计算量是数据点数量乘以参数数量的 6 倍

这就解释了为什么我们在一开始提出那个启发性问题时,答案中会有个 6。那么,这只是针对一个简单的线性模型,但事实证明,对于许多模型来说,这基本上是计算量的主要部分

6. Models

到目前为止,关于资源核算的部分我们基本上就讲完了,我们回顾了张量,讨论了张量上的一些计算,并且看了张量在进行各种操作时会消耗多少浮点运算,现在我们开始构建不同的模型

6.1 module parameters

在 PyTorch 中,参数被存储为 nn.Parameter 对象:

input_dim = 16384
output_dim = 32

w = nn.Parameter(torch.randn(input_dim, output_dim))
assert isinstance(w, torch.Tensor)  # Behaves like a tensor
assert type(w.data) == torch.Tensor  # Access the underlying tensor

我们来谈谈参数的初始化,你的 w 参数是输入维度乘以隐藏维度的矩阵,接着我们放入一个输入,然后让它通过 w(模型)得到输出:

x = nn.Parameter(torch.randn(input_dim))
output = x @ w  # @inspect output
assert output.size() == torch.Size([output_dim])

在这里插入图片描述

当你这样做时,如果你看输出,你会得到一些相当大的数字,这是因为数值增长的量级本质上与隐藏维度的平方根成正比

当你拥有大模型时,这将会导致爆炸(数值发散),并且训练会变得非常不稳定。所以通常你想做的是以一种方式进行初始化,这种方式对隐藏维度具有某种不变性或者至少保证它不会爆炸

一个简单的方法就是重新缩放,通过除以输入数量的平方根:

w = nn.Parameter(torch.randn(input_dim, output_dim) / np.sqrt(input_dim))
output = x @ w  # @inspect output

在这里插入图片描述

我们只是简单地将 w 除以输入维度的平方根,现在当你让输入 x 通过 w 得到输出时,你会发现你得到的结果在某个值附近,是稳定的,它实际上会集中在类似 Normal(0, 1) 的分布附近,

这基本上在深度学习文献中已经被广泛地探索过了,它被称为 Xavier 初始化 [paper][stackexchange]

如果你想要更加安全,你不会信任正态分布,因为它的尾部是无界的,你可能会说:“我把它截断到 [-3, 3]”,这样就不会得到异常大的值

6.2 custom model

那么我们来构建一个简单的模型,它的维度是 D,有两层:

class Cruncher(nn.Module):
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([
            Linear(dim, dim)
            for i in range(num_layers)
        ])
        self.final = Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply linear layers
        B, D = x.size()
        for layer in self.layers:
            x = layer(x)

        # Apply final head
        x = self.final(x)
        assert x.size() == torch.Size([B, 1])

        # Remove the last dimension
        x = x.squeeze(-1)
        assert x.size() == torch.Size([B])

        return x

D = 64  # Dimension
num_layers = 2
model = Cruncher(dim=D, num_layers=num_layers)

Cruncher 是一个自定义模型,是一个有 num_layers 层的深度线性网络,而且每一层都是一个线性模型,本质上就是一个矩阵乘法

def get_num_parameters(model: nn.Module) -> int:
    return sum(param.numel() for param in model.parameters())

assert param_sizes == [
    ("layers.0.weight", D * D),
    ("layers.1.weight", D * D),
    ("final.weight", D),
]
num_parameters = get_num_parameters(model)
assert num_parameters == (D * D) + (D * D) + D

这个模型的参数有三层,对于第一层它是一个 D * D 矩阵,第二层也是一个 D * D 矩阵,然后还有一个最终层,如果我们来计算这个模型的参数量,那么它将是 D * D + D * D + D

接着我们要把它移到 GPU 上,这样它就能跑得更快,我们会生成一些随机数据,然后送入模型中进行处理:

device = get_device()
model = model.to(device)

B = 8  # Batch size
x = torch.randn(B, D, device=device)
y = model(x)
assert y.size() == torch.Size([B])

前向传播过程就是依次通过各层,最后应用 head 层(final

6.3 note about randomness

有了这个模型后,我们要用这个模型来做一些操作,有一个题外话,随机性有时可能会有点烦人。比如,如果你试图重现一个 bug,随机性出现在很多地方:初始化、dropout 等等,所以这边建议你传入一个固定的随机种子,这样你就能重新复现你的问题

为每个随机来源设置不同的随机种子是一个很好的做法,因为这样你就可以固定初始化或者其它包含随机性的操作,尤其在你进行代码调试时,这是非常关键的

# Torch
seed = 0
torch.manual_seed(seed)

# NumPy
import numpy as np
np.random.seed(seed)

# Python
import random
random.seed(seed)

有很多地方可以使用随机性,你需要清楚你正在使用哪种,如果你想保险起见,你可以将这些都设置一个随机种子

6.4 data loading

数据加载这部分我们快速过下,这对你的作业会很有用,在语言建模中,数据通常只是一个整数序列,是由分词器(tokenizer)输出的

你可以将它们序列化成 numpy 数组:

orig_data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.int32)
orig_data.tofile("data.npy")

一件有用的事情是你不希望一次性将所有数据加载到内存中,因为它们都太大了,例如 LLaMA 的数据有 2.8T

你可以通过使用 np.memmap 这个方便的函数来假装加载它,它本质上为你提供了一个映射到文件的变量:

data = np.memmap("data.npy", dtype=np.int32)
assert np.array_equal(data, orig_data)

当你尝试访问数据时,它实际上会按需加载文件。然后使用它你可以创建一个数据加载器:

B = 2  # Batch size
L = 4  # Length of sequence
x = get_batch(data, batch_size=B, sequence_length=L, device=get_device())
assert x.size() == torch.Size([B, L])

其中 get_batch 是从你的原始数据中采样 batch 大小的数据,时间关系,我们跳过这部分

6.5 optimizer

接下来我们来谈谈优化器(optimizer),上面我们已经定义了我们的模型,如下所示,现在该定义优化器了

B = 2
D = 4
num_layers = 2
model = Cruncher(dim=D, num_layers=num_layers).to(get_device())

这里有很多的优化器,我们将介绍其中一些背后的思想:

  • momentum = SGD + exponential averaging of grad
  • AdaGrad = SGD + averaging by grad^2
  • RMSProp = AdaGrad + exponentially averaging of grad^2
  • Adam = RMSProp + momentum

最常见的可能就是随机梯度下降(Stochastic Gradient Descent),SGD 优化器,计算一批数据的梯度,然后毫不犹豫地朝着那个方向迈出一步

有一个叫做动量(momentum)的想法,它源自经典的 Nesterov 优化方法,你会保留一个梯度的滑动平均值,然后根据这个滑动平均值来更新参数,而不是根据你当前的瞬时梯度

然后是 AdaGrad 算法,它根据历史梯度平方的平均值来缩放梯度。还有 RMSProp,它是 AdaGrad 的一个改进版本,它使用指数滑动平均,而不是简单的平均

最后是 Adam 算法,它是在 2014 年提出的,它本质上是结合了 RMSProp 和动量法,所以说在 Adam 中,你会同时维护梯度的滑动平均值和梯度平方的滑动平均值

在作业 1 中需要实现 Adam,这里就不演示了,作为替代,我们来实现 AdaGrad 算法 https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

class AdaGrad(torch.optim.Optimizer):
    def __init__(self, params: Iterable[nn.Parameter], lr: float = 0.01):
        super(AdaGrad, self).__init__(params, dict(lr=lr))

    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                # Optimizer state
                state = self.state[p]
                grad = p.grad.data

                # Get squared gradients g2 = sum_{i<t} g_i^2
                g2 = state.get("g2", torch.zeros_like(grad))

                # Update optimizer state
                g2 += torch.square(grad)
                state["g2"] = g2

                # Update parameters
                p.data -= lr * grad / torch.sqrt(g2 + 1e-5)

optimizer = AdaGrad(model.parameters(), lr=0.01)

state = model.state_dict()  # @inspect state

在 PyTorch 中实现一个优化器的方法是重写优化器类

x = torch.randn(B, D, device=get_device())
y = torch.tensor([4., 5.], device=get_device())
pred_y = model(x)
loss = F.mse_loss(input=pred_y, target=y)
loss.backward()

optimizer.step()
state = model.state_dict()  # @inspect state

我们先定义一些数据,计算前向传播得到损失,然后计算梯度,注意当你调用 optimizer.step() 时,优化器就在这里真正起作用了

def step(self):
    for group in self.param_groups:
        lr = group["lr"]
        for p in group["params"]:
            # Optimizer state
            state = self.state[p]
            grad = p.grad.data
            
            pass

step 函数中看起来就像是你的参数被分组了,例如,你可以按层分组,比如第 0 层、第 1 层以及最后的权重。你可以访问一个状态(state),它是一个字典,键是参数,值是你想要存储的任何优化器状态信息

def step(self):
    for group in self.param_groups:
        lr = group["lr"]
        for p in group["params"]:
            # Optimizer state
            state = self.state[p]
            grad = p.grad.data

            # Get squared gradients g2 = sum_{i<t} g_i^2
            g2 = state.get("g2", torch.zeros_like(grad))

            # Update optimizer state
            g2 += torch.square(grad)
            state["g2"] = g2

            # Update parameters
            p.data -= lr * grad / torch.sqrt(g2 + 1e-5)

假设这个参数的梯度已经通过反向传播计算好了,现在你就可以做一些事情了,比如,在 AdaGrad 算法中你会存储梯度平方的累加和

你可以获取 g2 变量,并根据当前梯度的平方来更新它,接着对梯度进行逐元素的平方操作,然后把它放回状态(state)中

优化器负责更新参数:

# Update parameters
p.data -= lr * grad / torch.sqrt(g2 + 1e-5)

这就是你更新参数的方式,即学习率乘以梯度再除以缩放值,这个状态 state 在优化器的多次调用之间会保留

在优化器的步骤结束时,你可以释放内存:

optimizer.zero_grad(set_to_none=True)

这在你了解模型并行时会更加重要

接下来我们谈谈优化器状态的内存需求:

# Parameters
num_parameters = (D * D * num_layers) + D  # @inspect num_parameters
assert num_parameters == get_num_parameters(model)

你需要知道,这个模型的参数数量是 D * D 乘以层数,再加上最终输出层的 D,也就是 (D * D * num_layers) + D

# Activations
num_activations = B * D * num_layers  # @inspect num_activations

激活值的数量这一点我们之前没有计算过,但对于这个简单的模型来说,计算起来相当容易,它是 B * D 乘以模型的层数。对于每一层、每一个数据点、每一个维度,你都必须存储激活值

# Gradients
num_gradients = num_parameters  # @inspect num_gradients

对于梯度,它的数量与参数数量相同

# Optimizer states
num_optimizer_states = num_parameters  # @inspect num_optimizer_states

优化器状态的数量,对于 AdaGrad 而言,我们需要存储梯度平方,相当于参数数量

# Putting it all together, assuming float32
total_memory = 4 * (num_parameters + num_activations + num_gradients + num_optimizer_states)  # @inspect total_memory

在这里插入图片描述

综合来看,总内存量(假设使用 FP32 数据类型,即 4 字节)是 4 乘以参数数量、激活值数量、梯度数量以及优化器状态数量的总和,这样我们就得到了一个数值,这里是 496,这是一个相当简单的计算

在作业 1 中,需要对 Transformer 模型进行这项计算,这会稍微复杂一些,因为它不仅涉及矩阵乘法,还有很多其他矩阵、注意力机制等等内容,但计算的通用形式是相同的,你有参数、激活值、梯度以及优化器状态

flops = 6 * B * num_parameters  # @inspect flops

这个模型所需的浮点运算量是 token 数量的 6 倍,或者说是数据点数量乘以参数数量的 6 倍

以上就是对这个特定模型的资源计算,顺便说一下,如果你好奇如何为 Transformer 模型计算这些值,你可以参考以下这些文章:

  • Blog post describing memory usage for Transformer training:[article]
  • Blog post descibing FLOPs for a Transformer:[article]

6.6 train loop

到目前为止,我们讲了如何构建张量,然后构建了一个非常小的模型,我们讲了优化器,以及需要多少内存和多少计算资源

下面我们来快速进行一次训练:

def train(name: str, get_batch,
          D: int, num_layers: int,
          B: int, num_train_steps: int, lr: float):
    model = Cruncher(dim=D, num_layers=0).to(get_device())
    optimizer = SGD(model.parameters(), lr=0.01)

    for t in range(num_train_steps):
        # Get data
        x, y = get_batch(B=B)

        # Forward (compute loss)
        pred_y = model(x)
        loss = F.mse_loss(pred_y, y)

        # Backward (compute gradients)
        loss.backward()

        # Update parameters
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

D = 16
true_w = torch.arange(D, dtype=torch.float32, device=get_device())
def get_batch(B: int) -> tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(B, D).to(get_device())
    true_y = x @ true_w
    return (x, true_y)

train("simple", get_batch, D=D, num_layers=0, B=4, num_train_steps=10, lr=0.01)

这是典型的训练循环,在这里你需要定义模型、定义优化器,然后获取数据,前向传播、反向传播、更新参数,这没有什么好说的

6.7 checkpointing

关于检查点(checkpointing)有一点需要说明:训练语言模型需要很长时间,你可能会在某个时候遇到崩溃,因此,你需要定期将模型保存到磁盘,这样你就不会丢失你的训练进度

model = Cruncher(dim=64, num_layers=3).to(get_device())
optimizer = AdaGrad(model.parameters(), lr=0.01)

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, "model_checkpoint.pt")

需要明确的是,你需要保存的东西包括模型和优化器,然后你就可以加载它们了:

loaded_checkpoint = torch.load("model_checkpoint.pt")

6.8 mixed precision training

最后一个注意事项,正如我们之前提到的,关于训练中的混合精度。数据类型(float32,bfloat16,fp8)的选择有不同的权衡,如果精度更好,它更准确、更稳定,但成本更高,低精度则相反

正如我们之前提到的,默认推荐使用 float32,但你可以尝试使用 bfloat16 甚至 fp8。你可以对第五个前向传播使用较低的精度,但对其余部分使用 float32。这个想法可以追溯到 2017 年 [Micikevicius+ 2017],那时人们正在探索混合精度训练

PyTorch 提供了一些工具,可以自动进行混合精度训练(automatic mixed precision,AMP),因为手动指定模型的哪些部分需要使用什么精度可能会有点麻烦,感兴趣的话你可以查看下面的一些资料:

现在还有些论文 [Peng+ 2023] 表明你可以全程使用 FP8,当然,一个挑战是,当精度较低时,数值会变得非常不稳定,但你可以采用各种技巧来控制模型在训练过程中的数值稳定性,这样你就不会陷入那些不好的状况

所以,系统设计和模型架构在这里是相辅相成的,因为很多模型设计都受到硬件的制约,你在设计模型时要考虑到这一点。即使是上次我们提到的 Transformer 模型,也是因为有了 GPU 才能实现

现在我们注意到,某些芯片具有这样的特性:如果使用较低精度,比如像 INT4 这样的精度,那么你就能获得巨大的速度提升,你的模型也会更高效

还有另外一件事,我们以后会谈到,那就是通常来说你会在训练模型时使用更常规的浮点精度,但到了推理阶段,你就可以大胆尝试了

你把你训练好的模型拿来,然后可以对它进行量化,并从非常激进的量化中获得很多性能提升。但不知何故,用低精度训练要困难得多,但一旦你有了训练好的模型,将其转换为低精度就容易得多

7. Summary

最后总结一下,我们讨论了训练模型的不同基本要素,从张量一直到训练循环的整个构建过程,我们讨论了一些简单模型的内存消耗和 FLOPs 计算,希望你完成了作业 1 之后,这些概念都能真正牢固掌握,因为你会将这些概念应用于实际的 Transformer 模型

OK,以上就是本次讲座的全部内容了

结语

第二讲我们主要讲了资源的利用,主要包括内存和计算两部分,并从张量开始构建了模型、优化器,完成了一个训练循环。

内存核算小节我们学习了张量,还了解了各种表示浮点数的数据类型例如 float32、float16、bfloat16、fp8 等等,并讨论了各种数据类型下的张量会占用多少存储空间。

在计算核算小节我们学习了 tensor 的各种操作,需要注意的是张量的很多操作实际上并不会创建一个新的张量,它们只会创建一个不同的视图,例如 transposeslicing 等等操作。

在计算核算小节还重点讨论了张量操作的计算成本,它通常用浮点运算次数 FLOPs 来衡量,一个经验法则是:对于矩阵乘法而言,浮点运算次数是维度乘积的两倍。之后我们还学习了梯度的 FLOPs 计算,我们以一个特定的线性模型为例进行了讲解,得到的结论是前向传播的计算量是参数数量的 2 倍,反向传播的计算量是参数数量的 4 倍,加起来总共是 6 倍。

模型小节我们完成了张量一直到训练循环的整个过程的构建,讲解了优化器,并讨论了其需要的内存和计算资源,最后简单聊了聊混合精度训练。

整个讲解非常通俗易懂,大家感兴趣的可以看看

下一讲我们将深入探讨关于语言模型架构和训练的一些细节,敬请期待🤗

参考

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐