一文读懂循环神经网络—门控循环单元
重置门负责筛选历史信息,帮助模型关注 "当前相关" 的历史内容。更新门负责平衡新旧信息,帮助模型在 "记忆" 和 "遗忘" 之间找到平衡点。
·
目录
候选隐藏状态(Candidate Hidden State)
重置门(Reset Gate)
定义
重置门决定了如何将新的输入信息与之前的隐藏状态相结合。它可以 "重置" 历史隐藏状态的部分信息,允许模型有选择地遗忘过去。
作用
- 捕获短期依赖:通过控制过去信息的保留程度,帮助模型关注最近的输入。
- 防止梯度消失:允许梯度在需要时更有效地流动。
数学公式
是重置门输出
是上一时刻隐藏状态
为当前时刻输入
为权重矩阵
表示 sigmoid 函数
表示拼接操作
为偏置项
- 功能:决定如何 "重置" 历史隐藏状态,控制上一时刻的隐藏状态
对当前候选状态的影响程度。
- 输出范围:
,其中 0 表示完全忽略历史,1 表示保留全部历史。
更新门(Update Gate)
定义
更新门决定了新的隐藏状态中有多少来自过去的隐藏状态,以及多少来自当前输入的新信息。它类似于 LSTM 中的遗忘门和输入门的组合。
作用
- 捕获长期依赖:通过控制信息的更新程度,允许模型保留长期信息。
- 减少冗余计算:只有当更新门指示需要更新时,模型才会处理新输入。
数学公式
是更新门的输出(范围在 0 到 1 之间)- 其他符号含义与重置门相同
- 功能:决定新的隐藏状态中,有多少来自候选状态
,多少来自历史状态
。
- 输出范围:
,其中 0 表示完全使用新信息,1 表示完全保留历史信息。
隐藏状态(Hidden State)
定义
隐藏状态 是 RNN 在时间步 t 的内部表示,它融合了 历史信息 和 当前输入,并作为后续时间步的上下文。
数学表示
在标准 RNN 中:
:当前输入
:上一时刻的隐藏状态
:权重矩阵
:激活函数(如 tanh 或 ReLU)
核心作用
- 记忆功能:通过
传递历史信息,使模型能够处理序列中的长期依赖。
- 上下文整合:将历史信息与当前输入结合,形成对序列的动态理解。
候选隐藏状态(Candidate Hidden State)
定义
候选隐藏状态 是 临时计算的中间状态,用于生成下一时刻的实际隐藏状态
。它在门控循环单元(如 LSTM、GRU)中尤为重要。
数学表示(以 GRU 为例)
:重置门输出
:元素级乘法
:激活函数,将输出约束在 \([-1, 1]\)
核心作用
- 信息筛选:通过重置门
选择性地保留历史信息,避免无关信息干扰。
- 生成新状态:
作为 "候选",需要经过更新门的调控才能成为最终的隐藏状态
。
隐藏状态 vs 候选隐藏状态
| 对比项 | 隐藏状态 ( |
候选隐藏状态 ( |
|---|---|---|
| 角色 | 最终的上下文表示,传递到下一时刻 | 生成新隐藏状态的中间计算结果 |
| 是否门控 | 是(通过更新门 |
是(通过重置门 |
| 信息来源 | 整合了历史状态 |
基于当前输入 |
| 范围 | 由更新门 |
由 |
直观理解
详解
-
候选隐藏状态
: 可以看作是 "建议更新内容",它根据当前输入和部分历史信息提出一个 "候选",但需要经过更新门的批准才能生效。
-
隐藏状态
: 可以看作是 "历史记忆 + 新信息的融合",它通过更新门权衡历史与当前的重要性,决定最终保留哪些信息。
- 重置门:类似于 "遗忘开关",决定是否忽略历史隐藏状态。当
≈ 0时,模型几乎完全忽略历史,专注于当前输入。 - 更新门:类似于 "记忆开关",决定是否保留历史隐藏状态。当
时,模型主要使用新信息;当≈ 0
时,主要保留历史信息。≈ 1
为什么需要候选状态?
门控机制(如 GRU 的重置门和更新门)的核心目的是 选择性地记忆和遗忘:
- 重置门 通过
控制历史信息的哪些部分参与生成 \(\tilde{h}_t\),帮助模型关注短期信息。
- 更新门 通过
控制对
的影响程度,帮助模型保留长期信息。
这种设计使 RNN 能够有效处理 梯度消失 和 长期依赖 问题。
可视化流程(GRU 为例)
plaintext
输入序列: x_1 → x_2 → x_3 → ... → x_t
1. 计算重置门:
r_t = σ(W_r·[h_{t-1}, x_t] + b_r)
2. 计算候选隐藏状态:
h̃_t = tanh(W·[r_t⊙h_{t-1}, x_t] + b) # 基于部分历史和当前输入
3. 计算更新门:
z_t = σ(W_z·[h_{t-1}, x_t] + b_z)
4. 更新隐藏状态:
h_t = (1-z_t)⊙h̃_t + z_t⊙h_{t-1} # 融合候选状态和历史状态
完整代码
"""
文件名: 9.1
作者: 墨尘
日期: 2025/7/15
项目名: dl_env
备注: 基于GRU(门控循环单元)的字符级文本生成模型,以《时间机器》文本为训练数据
"""
# 基础工具库
import collections # 用于统计词频
import random # 随机抽样
import re # 文本清洗(正则表达式)
import requests # 下载数据集
from pathlib import Path # 文件路径处理
from d2l import torch as d2l # 深度学习工具库
import math # 数学运算
import torch # PyTorch框架
from torch import nn # 神经网络模块
from torch.nn import functional as F # 函数式API
# 图像显示相关库(解决中文和符号显示问题)
import matplotlib.pyplot as plt
import matplotlib.text as text
# -------------------------- 核心解决方案:解决文本显示问题 --------------------------
def replace_minus(s):
"""
解决Matplotlib中Unicode减号(U+2212)显示为方块的问题
原理:将特殊减号替换为普通ASCII减号('-')
"""
if isinstance(s, str): # 仅处理字符串
return s.replace('\u2212', '-') # 替换特殊减号
return s # 非字符串直接返回
# 重写matplotlib的Text类的set_text方法,全局生效
original_set_text = text.Text.set_text # 保存原始方法
def new_set_text(self, s):
s = replace_minus(s) # 先处理减号
return original_set_text(self, s) # 调用原始方法设置文本
text.Text.set_text = new_set_text # 应用重写后的方法
# -------------------------- 字体配置(确保中文和数学符号正常显示)--------------------------
plt.rcParams["font.family"] = ["SimHei"] # 设置中文字体(支持中文显示)
plt.rcParams["text.usetex"] = True # 使用LaTeX渲染文本(提升数学符号美观度)
plt.rcParams["axes.unicode_minus"] = True # 确保负号正确显示(避免方块)
plt.rcParams["mathtext.fontset"] = "cm" # 数学符号使用Computer Modern字体(LaTeX标准字体)
d2l.plt.rcParams.update(plt.rcParams) # 让d2l库的绘图工具继承上述配置
# -------------------------- 1. 读取数据 --------------------------
def read_time_machine():
"""下载并读取《时间机器》数据集,返回清洗后的文本行列表"""
data_dir = Path('./data') # 数据存储目录
data_dir.mkdir(exist_ok=True) # 目录不存在则创建
file_path = data_dir / 'timemachine.txt' # 数据集文件路径
# 检查文件是否存在,不存在则下载
if not file_path.exists():
print("开始下载时间机器数据集...")
# 从d2l官方地址下载文本
response = requests.get('http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt')
# 写入文件(utf-8编码)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(response.text)
print(f"数据集下载完成,保存至: {file_path}")
# 读取文件并清洗文本
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.readlines() # 按行读取
print(f"文件读取成功,总行数: {len(lines)}")
if len(lines) > 0:
print(f"第一行内容: {lines[0].strip()}") # 打印首行验证
# 清洗规则:保留字母,其他字符替换为空格,转小写,去除首尾空格
cleaned_lines = [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines if line.strip()]
print(f"清洗后有效行数: {len(cleaned_lines)}") # 清洗后非空行数量
return cleaned_lines
# -------------------------- 2. 词元化与词表构建 --------------------------
def tokenize(lines, token='char'):
"""
将文本行转换为词元列表(词元是文本的最小处理单位)
参数:
lines: 清洗后的文本行列表(如["abc def", "ghi jkl"])
token: 词元类型('char'字符级/'word'单词级)
返回:
词元列表(如字符级:[['a','b','c',' ','d','e','f'], ...])
"""
if token == 'char':
# 字符级词元化:将每行拆分为单个字符列表
return [list(line) for line in lines]
elif token == 'word':
# 单词级词元化:按空格拆分每行(需确保文本已用空格分隔单词)
return [line.split() for line in lines]
else:
raise ValueError('未知词元类型:' + token)
class Vocab:
"""词表类:实现词元与索引的双向映射,用于将文本转换为模型可处理的数字序列"""
def __init__(self, tokens, min_freq=0, reserved_tokens=None):
"""
构建词表
参数:
tokens: 词元列表(可嵌套,如[[token1, token2], [token3]])
min_freq: 最低词频阈值(低于此值的词元不加入词表)
reserved_tokens: 预留特殊词元(如分隔符、填充符等)
"""
if reserved_tokens is None:
reserved_tokens = [] # 默认为空
# 统计词频:展平嵌套列表,用Counter计数
counter = collections.Counter([token for line in tokens for token in line])
# 按词频降序排序(便于后续按频率筛选)
self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
# 初始化词表:<unk>(未知词元)固定在索引0, followed by预留词元
self.idx_to_token = ['<unk>'] + reserved_tokens
# 构建词元到索引的映射(字典)
self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
# 按词频添加词元(过滤低频词)
for token, freq in self.token_freqs:
if freq < min_freq:
break # 低频词不加入词表
if token not in self.token_to_idx: # 避免重复添加预留词元
self.idx_to_token.append(token)
self.token_to_idx[token] = len(self.idx_to_token) - 1 # 索引为当前长度-1
def __len__(self):
"""返回词表大小(词元总数)"""
return len(self.idx_to_token)
def __getitem__(self, tokens):
"""
词元→索引(支持单个词元或词元列表)
未知词元返回<unk>的索引(0)
"""
if not isinstance(tokens, (list, tuple)):
# 单个词元:查字典,默认返回<unk>的索引
return self.token_to_idx.get(tokens, self.unk)
# 词元列表:递归转换每个词元
return [self.__getitem__(token) for token in tokens]
def to_tokens(self, indices):
"""索引→词元(支持单个索引或索引列表)"""
if not isinstance(indices, (list, tuple)):
# 单个索引:直接查列表
return self.idx_to_token[indices]
# 索引列表:递归转换每个索引
return [self.idx_to_token[index] for index in indices]
@property
def unk(self):
"""返回<unk>的索引(固定为0)"""
return 0
# -------------------------- 3. 数据迭代器(随机抽样) --------------------------
def seq_data_iter_random(corpus, batch_size, num_steps):
"""
随机抽样生成批量子序列(生成器),用于模型训练的批量输入
原理:从语料中随机截取多个长度为num_steps的子序列,组成批次
参数:
corpus: 词元索引序列(1D列表,如[1,3,5,2,...])
batch_size: 批量大小(每个批次包含的子序列数)
num_steps: 子序列长度(时间步,即模型一次处理的序列长度)
返回:
生成器,每次返回(X, Y):
X: 输入序列(batch_size, num_steps)
Y: 标签序列(batch_size, num_steps),是X右移一位的结果
"""
# 检查数据是否足够生成至少一个子序列(子序列长度+1,因Y是X右移1位)
if len(corpus) < num_steps + 1:
raise ValueError(f"语料库长度({len(corpus)})不足,需至少{num_steps+1}")
# 随机偏移起始位置(0到num_steps-1),增加数据随机性
corpus = corpus[random.randint(0, num_steps - 1):]
# 计算可生成的子序列总数:(语料长度-1) // num_steps(-1是因Y需多1个元素)
num_subseqs = (len(corpus) - 1) // num_steps
if num_subseqs < 1:
raise ValueError(f"无法生成子序列(语料库长度不足)")
# 生成所有子序列的起始索引(间隔为num_steps)
initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
random.shuffle(initial_indices) # 打乱起始索引,实现随机抽样
# 计算可生成的批次数:子序列总数 // 批量大小
num_batches = num_subseqs // batch_size
if num_batches < 1:
raise ValueError(f"子序列数量({num_subseqs})不足,需至少{batch_size}个")
# 生成批量数据
for i in range(0, batch_size * num_batches, batch_size):
# 当前批次的起始索引(从打乱的索引中取batch_size个)
indices = initial_indices[i: i + batch_size]
# 输入序列X:每个子序列从indices[j]开始,取num_steps个元素
X = [corpus[j: j + num_steps] for j in indices]
# 标签序列Y:每个子序列从indices[j]+1开始,取num_steps个元素(X右移1位)
Y = [corpus[j + 1: j + num_steps + 1] for j in indices]
# 转换为张量返回(便于模型处理)
yield torch.tensor(X), torch.tensor(Y)
# -------------------------- 4. 数据加载函数(关键修复:返回可重置的迭代器) --------------------------
def load_data_time_machine(batch_size, num_steps):
"""
加载《时间机器》数据,返回数据迭代器生成函数和词表
修复点:返回迭代器生成函数(而非一次性迭代器),确保训练时可重复生成数据
参数:
batch_size: 批量大小
num_steps: 子序列长度(时间步)
返回:
data_iter: 迭代器生成函数(调用时返回新的迭代器)
vocab: 词表对象
"""
lines = read_time_machine() # 读取清洗后的文本行
tokens = tokenize(lines, token='char') # 字符级词元化(每个字符为词元)
vocab = Vocab(tokens) # 构建词表
# 将所有词元转换为索引(展平为1D序列)
corpus = [vocab[token] for line in tokens for token in line]
print(f"语料库长度: {len(corpus)}(词元索引总数)")
# 定义迭代器生成函数:每次调用生成新的随机抽样迭代器
def data_iter():
return seq_data_iter_random(corpus, batch_size, num_steps)
return data_iter, vocab # 返回生成函数和词表
# -------------------------- 5. GRU模型核心实现 --------------------------
def get_params(vocab_size, num_hiddens, device):
"""
初始化GRU模型参数
包含:更新门、重置门、候选隐状态、输出层的权重和偏置
参数:
vocab_size: 词表大小(输入/输出维度)
num_hiddens: 隐藏层维度(隐状态维度)
device: 计算设备(CPU/GPU)
返回:
参数列表:[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
"""
num_inputs = num_outputs = vocab_size # 输入/输出维度=词表大小
# 正态分布初始化参数(均值0,标准差0.01)
def normal(shape):
return torch.randn(size=shape, device=device) * 0.01
# 生成三组参数(权重1、权重2、偏置),用于门控机制
def three():
return (normal((num_inputs, num_hiddens)), # 输入→隐藏权重
normal((num_hiddens, num_hiddens)), # 隐藏→隐藏权重
torch.zeros(num_hiddens, device=device)) # 偏置(初始化为0)
# 更新门(Update Gate)参数:W_xz(输入→更新门)、W_hz(隐藏→更新门)、b_z(偏置)
W_xz, W_hz, b_z = three()
# 重置门(Reset Gate)参数:W_xr(输入→重置门)、W_hr(隐藏→重置门)、b_r(偏置)
W_xr, W_hr, b_r = three()
# 候选隐状态(Candidate Hidden State)参数:W_xh(输入→候选隐状态)、W_hh(隐藏→候选隐状态)、b_h(偏置)
W_xh, W_hh, b_h = three()
# 输出层参数:W_hq(隐藏→输出)、b_q(偏置)
W_hq = normal((num_hiddens, num_outputs))
b_q = torch.zeros(num_outputs, device=device)
# 所有参数附加梯度(允许反向传播更新)
params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
for param in params:
param.requires_grad_(True)
return params
def init_gru_state(batch_size, num_hiddens, device):
"""
初始化GRU的隐藏状态(全零向量)
返回元组形式,便于扩展(如LSTM有两个状态)
参数:
batch_size: 批量大小
num_hiddens: 隐藏层维度
device: 计算设备
返回:
隐藏状态元组:(H,),其中H形状为(batch_size, num_hiddens)
"""
return (torch.zeros((batch_size, num_hiddens), device=device), )
def gru(inputs, state, params):
"""
GRU前向传播(逐时间步计算)
参数:
inputs: 输入序列(num_steps, batch_size, vocab_size),已转换为one-hot编码
state: 初始隐藏状态(batch_size, num_hiddens)
params: GRU参数列表(见get_params)
返回:
outputs: 所有时间步的输出(num_steps*batch_size, vocab_size)
state: 最终隐藏状态(batch_size, num_hiddens)
"""
# 解析参数
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
H, = state # 初始隐藏状态(从元组中取出)
outputs = [] # 存储每个时间步的输出
# 逐时间步计算
for X in inputs: # X形状:(batch_size, vocab_size)(当前时间步的输入)
# 1. 计算更新门 Z_t = σ(X_t·W_xz + H_{t-1}·W_hz + b_z)
Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
# 2. 计算重置门 R_t = σ(X_t·W_xr + H_{t-1}·W_hr + b_r)
R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
# 3. 计算候选隐状态 Ĥ_t = tanh(X_t·W_xh + (R_t ⊙ H_{t-1})·W_hh + b_h)
H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
# 4. 计算最终隐状态 H_t = Z_t ⊙ H_{t-1} + (1-Z_t) ⊙ Ĥ_t
H = Z * H + (1 - Z) * H_tilda
# 5. 计算输出 Y_t = H_t·W_hq + b_q
Y = H @ W_hq + b_q
outputs.append(Y) # 保存当前时间步的输出
# 拼接所有时间步的输出(形状:(num_steps*batch_size, vocab_size)),返回输出和最终状态
return torch.cat(outputs, dim=0), (H,)
# -------------------------- 6. RNN模型包装类 --------------------------
class RNNModelScratch: #@save
"""从零实现的RNN模型包装类,统一模型调用接口"""
def __init__(self, vocab_size, num_hiddens, device,
get_params, init_state, forward_fn):
"""
参数:
vocab_size: 词表大小(输入/输出维度)
num_hiddens: 隐藏层维度
device: 计算设备
get_params: 参数初始化函数(如get_params)
init_state: 状态初始化函数(如init_gru_state)
forward_fn: 前向传播函数(如gru)
"""
self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
self.params = get_params(vocab_size, num_hiddens, device) # 模型参数
self.init_state, self.forward_fn = init_state, forward_fn # 状态初始化和前向传播函数
def __call__(self, X, state):
"""
模型调用接口(前向传播入口)
参数:
X: 输入序列(batch_size, num_steps),元素为词元索引
state: 初始隐藏状态
返回:
y_hat: 输出(num_steps*batch_size, vocab_size)
state: 最终隐藏状态
"""
# 处理输入:
# 1. X.T:转置为(num_steps, batch_size)(便于逐时间步处理)
# 2. F.one_hot:转换为one-hot编码(num_steps, batch_size, vocab_size)
# 3. type(torch.float32):转换为浮点型(适配后续矩阵运算)
X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
# 调用前向传播函数
return self.forward_fn(X, state, self.params)
def begin_state(self, batch_size, device):
"""获取初始隐藏状态(调用初始化函数)"""
return self.init_state(batch_size, self.num_hiddens, device)
# -------------------------- 7. 预测函数(文本生成) --------------------------
def predict_ch8(prefix, num_preds, net, vocab, device): #@save
"""
根据前缀生成后续字符(文本生成)
参数:
prefix: 前缀字符串(如"time traveller")
num_preds: 要生成的字符数
net: 训练好的GRU模型
vocab: 词表
device: 计算设备
返回:
生成的字符串(前缀+预测字符)
"""
# 初始化状态(批量大小为1,因仅生成一条序列)
state = net.begin_state(batch_size=1, device=device)
# 记录输出索引:初始为前缀首字符的索引
outputs = [vocab[prefix[0]]]
# 辅助函数:获取当前输入(最后一个输出的索引,形状(1,1))
def get_input():
return torch.tensor([outputs[-1]], device=device).reshape((1, 1))
# 预热期:用前缀更新模型状态(不生成新字符,仅让模型"记住"前缀)
for y in prefix[1:]:
_, state = net(get_input(), state) # 前向传播,更新状态(忽略输出)
outputs.append(vocab[y]) # 记录前缀字符的索引
# 预测期:生成num_preds个字符
for _ in range(num_preds):
y, state = net(get_input(), state) # 前向传播,获取输出和新状态
# 取概率最大的字符索引(贪婪采样)
outputs.append(int(y.argmax(dim=1).reshape(1)))
# 将索引转换为字符,拼接成字符串返回
return ''.join([vocab.idx_to_token[i] for i in outputs])
# -------------------------- 8. 梯度裁剪(防止梯度爆炸) --------------------------
def grad_clipping(net, theta): #@save
"""
裁剪梯度(将梯度L2范数限制在theta内),防止梯度爆炸
参数:
net: 模型(自定义模型或nn.Module)
theta: 梯度阈值
"""
# 获取需要梯度更新的参数
if isinstance(net, nn.Module):
# 若为PyTorch官方Module,直接取parameters
params = [p for p in net.parameters() if p.requires_grad]
else:
# 若为自定义模型,取params属性
params = net.params
# 计算所有参数梯度的L2范数
norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
if norm > theta: # 若范数超过阈值,按比例裁剪
for param in params:
param.grad[:] *= theta / norm
# -------------------------- 9. 训练函数 --------------------------
def train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter):
"""
训练一个周期(单轮遍历数据集)
参数:
net: GRU模型
train_iter_fn: 迭代器生成函数(调用后返回新迭代器)
loss: 损失函数(如CrossEntropyLoss)
updater: 优化器(如SGD)
device: 计算设备
use_random_iter: 是否使用随机抽样(影响状态处理)
返回:
ppl: 困惑度(perplexity,衡量语言模型性能,越低越好)
speed: 训练速度(词元/秒)
"""
state, timer = None, d2l.Timer() # 初始化状态和计时器
metric = d2l.Accumulator(2) # 累加器:(总损失, 总词元数)
batches_processed = 0 # 记录处理的批次数量
# 关键修复:每次训练都通过函数生成新的迭代器(避免迭代器被提前消费)
train_iter = train_iter_fn()
# 遍历批量数据
for X, Y in train_iter:
batches_processed += 1
# 初始化状态:
# - 首次迭代时需初始化
# - 随机抽样时,每个批次的状态独立,需重新初始化
if state is None or use_random_iter:
state = net.begin_state(batch_size=X.shape[0], device=device)
else:
# 非随机抽样时,分离状态(切断梯度回流到之前的批次,避免梯度计算依赖过长)
if isinstance(net, nn.Module) and not isinstance(state, tuple):
state.detach_() # 单个状态直接detach
else:
for s in state: # 多个状态(如LSTM)逐个detach
s.detach_()
# 处理标签:
# Y.T.reshape(-1):转置后展平为(num_steps*batch_size,)(与输出形状匹配)
y = Y.T.reshape(-1)
# 将输入和标签移到目标设备
X, y = X.to(device), y.to(device)
# 前向传播:获取输出和新状态
y_hat, state = net(X, state)
# 计算损失(mean()是因损失函数可能返回每个样本的损失)
l = loss(y_hat, y.long()).mean()
# 反向传播与参数更新:
if isinstance(updater, torch.optim.Optimizer):
# 若为PyTorch优化器(如SGD)
updater.zero_grad() # 清零梯度
l.backward() # 反向传播
grad_clipping(net, 1) # 裁剪梯度(阈值1)
updater.step() # 更新参数
else:
# 若为自定义优化器
l.backward()
grad_clipping(net, 1)
updater(batch_size=1) # 假设批量大小为1的更新
# 累加总损失和总词元数(用于计算平均损失)
metric.add(l * y.numel(), y.numel())
# 检查是否有批次被处理(避免空迭代)
if batches_processed == 0:
print("警告:没有处理任何训练批次!")
return float('inf'), 0
# 计算困惑度(perplexity = exp(平均损失))和训练速度(词元/秒)
return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()
def train_ch8(net, train_iter_fn, vocab, lr, num_epochs, device, use_random_iter=False):
"""
训练模型(多周期)
参数:
net: GRU模型
train_iter_fn: 迭代器生成函数
vocab: 词表
lr: 学习率
num_epochs: 训练周期数
device: 计算设备
use_random_iter: 是否使用随机抽样(默认False)
"""
loss = nn.CrossEntropyLoss() # 交叉熵损失(适用于分类任务,此处为字符预测)
# 动画器:可视化训练过程(困惑度随周期变化)
animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',
legend=['train'], xlim=[10, num_epochs])
# 初始化优化器:
if isinstance(net, nn.Module):
# 若为PyTorch Module,使用SGD优化器
updater = torch.optim.SGD(net.parameters(), lr)
else:
# 若为自定义模型,使用d2l的sgd函数
updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
# 定义预测函数:根据前缀"time traveller"生成50个字符
predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)
# 多周期训练
for epoch in range(num_epochs):
# 训练一个周期,返回困惑度和速度
ppl, speed = train_epoch_ch8(
net, train_iter_fn, loss, updater, device, use_random_iter)
# 每10个周期打印一次预测结果(观察生成文本质量变化)
if (epoch + 1) % 10 == 0:
print(f"epoch {epoch+1} 预测: {predict('time traveller')}")
animator.add(epoch + 1, [ppl]) # 记录困惑度
# 训练结束后输出最终结果
print(f'最终困惑度 {ppl:.1f}, 速度 {speed:.1f} 词元/秒 {device}')
print(f"time traveller 预测: {predict('time traveller')}")
print(f"traveller 预测: {predict('traveller')}")
# -------------------------- 主程序 --------------------------
if __name__ == '__main__':
# 超参数设置
batch_size, num_steps = 32, 35 # 批量大小=32,时间步=35
# 加载数据:获取迭代器生成函数和词表
train_iter, vocab = load_data_time_machine(batch_size, num_steps)
# 模型参数
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu() # 词表大小、隐藏层维度、自动选择GPU/CPU
num_epochs, lr = 500, 0.12 # 训练周期=500,学习率=0.12
# 初始化GRU模型
model = RNNModelScratch(len(vocab), num_hiddens, device, get_params,
init_gru_state, gru)
# 开始训练
train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show(block=True) # 显示训练过程的动画图(阻塞模式,确保图不闪退)
实验结果

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)