从零实现强化学习DPO(SimPO)训练代码
实验的话我们就只对DPO进行测试了。使用B站的Index-1___9B-Chat模型在HF上随便找了一个关于unsloth的数据集进行了一下测试,数据集如下:是一个关于unsloth的问答,jsonl格式,每行都有prompt、chosen、rejected字段。训练过程结果如下:loss下降还是可以的,测试了几个用例输出,相较之前回答会不一样,有一点点变好吧。因为我们只有50条数据,且DPO之类
强化学习已经越来越被大家所熟知了,从最开始的PPO到现在的各种DPO及其相关变体。对于SFT,理论上讲是让模型的输出更符合规范,而对于强化学习来说,应该就是让模型知道什么是不可以输出的。
现在我们从零开始,进行一下强化学习DPO的全过程代码实现。
1、加载模型。
这一步就是需要再huggingface加载成熟的模型了,后续使用过程中会用到输入给模型tensor,模型输出logits的过程,下面我们开始加载模型并简单的示例一下输出logits的过程。我们以B站的1.9B小模型为例开始:
了解过DPO的原理会知道DPO训练时会有两个模型,分别是policy model和ref model,前者就是我们要进行训练的模型,而后者就是未训练前的模型,用以规范模型的输出。
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
# 1、加载模型与tokenizer
model_path = 'IndexTeam/Index-1___9B-Chat'
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to('cuda:3')
ref_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to('cuda:3')
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
在DPO的过程中我们需要将输入给模型,使其输出logits,示例如下:
我们随便输入一个2批次的tensor给模型:
data = torch.tensor([[1,200,33,333],[1,2,3,4]]).to("cuda:3")
out = model(data)
print(out.logits.shape)
其输出为一个CausalLMOutputWithPast特有类型,其中的keys包括[‘logits’, ‘past_key_values’],我们需要提取出logits即可,上述代码的输出如下,对应的为batch、sequence_len、vocab_size:
torch.Size([2, 4, 65029])
2、处理数据
我们知道DPO的数据正常会有三个字段,如下:
- prompt
- chosen
- rejected
我们本次使用的数据类似如下,jsonl格式,每一行都有对应的prompt、chosen、rejected:
{"prompt":"What is the primary goal of Unsloth for LLM fine-tuning?","chosen":"The primary goal of Unsloth for LLM fine-tuning is to accelerate the process, achieving a 2x speedup while maintaining 0% accuracy degradation compared to normal QLoRA.","rejected":"The primary goal of Unsloth for LLM fine-tuning is to slow down the process and increase memory usage."}
其中输入给模型的分别是prompt+chosen作为一个chosen的完整字段,prompt+rejected作为一个rejected的完整字段。下面我们开始进行数据的处理,首先进行一下我们dataset的编写:
class RlhfDataset(Dataset):
def __init__(self, file_path, tokenizer):
with open(file_path, "r", encoding="utf-8") as file:
data_list = file.readlines()
self.data_list = data_list
self.tokenizer = tokenizer
def __getitem__(self, item):
data = self.data_list[item]
data = json.loads(data)
prompt = data['prompt']
chosen = data['chosen']
rejected = data['rejected']
chosen_full_text = f"{prompt}\n\n### Response:\n{chosen}"
rejected_full_text = f"{prompt}\n\n### Response:\n{rejected}"
prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
chosen_full_tokens = self.tokenizer.encode(chosen_full_text, add_special_tokens=False)
rejected_full_tokens = self.tokenizer.encode(rejected_full_text, add_special_tokens=False)
input = {
"prompt": prompt_tokens,
"chosen": chosen_full_tokens,
"rejected": rejected_full_tokens,
}
return input
def __len__(self):
return len(self.data_list)
这里我就直接内置了一个chat template。
然后再进行collate的编写:
def data_collate(batch, pad_token_id, device, max_length=None, if_mask_prompt=True):
batch_data = {
"prompt": [],
"chosen": [],
"rejected": [],
"rejected_mask": [],
"chosen_mask": []
}
# 判断长度及padding
max_length_common = 0
for key in ["chosen", "rejected"]:
current_max = max(len(item[key]) for item in batch)
max_length_common = max(max_length_common, current_max)
# 转为torch tensor并padding,决定是否对prompt进行mask
for item in batch:
prompt = torch.tensor(item['prompt'])
batch_data['prompt'].append(prompt)
for key in ["chosen", "rejected"]:
out = item[key]
out_padding = out + [pad_token_id] * (max_length_common - len(out))
mask = torch.ones(len(out_padding)).bool()
# padding部分的mask设置为 IGNORE_INDEX
mask[len(out):] = IGNORE_INDEX
if if_mask_prompt:
mask[:prompt.shape[0] + 2] = IGNORE_INDEX
batch_data[key].append(torch.tensor(out_padding))
batch_data[f"{key}_mask"].append(mask)
# 进行最大长度截断
for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
tensor_stack = torch.stack(batch_data[key])
if max_length is not None:
tensor_stack = tensor_stack[:, :max_length]
# 将tensor移到对应的device
batch_data[key] = tensor_stack.to(device)
return batch_data
最后我们就可以划分数据了:
# 加载数据
data_file = './unsloth_dpo.jsonl'
dataset = RlhfDataset(data_file, tokenizer)
# 划分训练集验证集
train_size = int(len(dataset) * 0.85) # 85% for training
val_size = len(dataset) - train_size # Remaining for validation
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# 设置相关参数
batch_size = 4
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
collate_fn=customized_collate_fn,
shuffle=True,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=1,
collate_fn=customized_collate_fn,
shuffle=False,
drop_last=False
)
3、开始编写DPO、SimPO的损失函数
损失函数基本属于最重要的一部分了。
3.1 DPO loss、SimPo loss计算函数编写
这里我们假设已经知道模型输出的Log probabilities,那么两个loss的计算就可以如下表示。
关于DPO loss公式可见下图:
而对于SimPO,其实是很简单的变体,相较于DPO加了一个gamma参数,且去掉了ref模型制约,并且因为公式中有|y|,也就是说Log probabilities要求平均操作了。公式如下
最后计算loss时的Log probabilities有两种方式,一种是对Log probabilities在每个单词上求平均(SimPO中),另一中就是对每个句子中单词的Log probabilities求和(DPO)。
TRL中的DPO是进行Log probabilities的求和,但我们这里就改一下进行求平均。这里我们先写loss的代码,两种方式的Log probabilities代码我们下一节实现。
class DPOLoss(nn.Module):
"""
DPO Loss
"""
def __init__(self, beta: float = 0.1) -> None:
super().__init__()
self.beta = beta
def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
):
"""
policy_chosen_logps: 模型输出的对数概率。Shape: (batch_size,)
policy_rejected_logps: Shape: (batch_size,)
reference_chosen_logps: Shape: (batch_size,)
reference_rejected_logps: Shape: (batch_size,)
"""
policy_logps = policy_chosen_logps - policy_rejected_logps
reference_logps = reference_chosen_logps - reference_rejected_logps
logits = policy_logps - reference_logps
loss = -F.logsigmoid(self.beta * logits)
# 下面两个用于追踪训练的进度
chosen_rewards = (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = (policy_rejected_logps - reference_rejected_logps).detach()
# 对每个batch进行平均
return loss.mean(), chosen_rewards.mean(), rejected_rewards.mean()
class SimPo(nn.Module):
"""
SimPO Loss
"""
def __init__(self, beta: float = 0.1, gamma: float = 0.5) -> None:
super().__init__()
self.beta = beta
self.gamma = gamma
def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
):
"""
policy_chosen_logps: 模型输出的对数概率。Shape: (batch_size,)
policy_rejected_logps: Shape: (batch_size,)
"""
logits = policy_chosen_logps - policy_rejected_logps
logits = logits - self.gamma
loss = -F.logsigmoid(self.beta * logits)
# 对每个batch进行平均(期望)
return loss.mean()
3.2 Log probabilities计算
下面需要开始实现计算模型的Log probabilities。代码如下,这里的两个输入为logits, label,其中logits为label输入给模型后输出的结果,并且当前的logits是预测label中下一个词的,故需要进行位移操作:
def compute_logprobs(logits, labels, mask=None):
"""
logits: shape (batch_size, sequence_len, vocab_size),即将label输入给模型后输出的结果
labels: shape (batch_size, sequence_len)
"""
# 需要先进行位移操作
# 去掉标签的第一个
labels = labels[:, 1:].clone()
# 去掉模型输出的最后一个
logits = logits[:, :-1, :]
logps = F.log_softmax(logits, dim=-1)
select_logprobs = torch.gather(
input=logps,
dim=-1,
index=labels.unsqueeze(1)
).squeeze(1)
if mask is not None:
mask = mask[:, 1:].clone()
# 进行掩码padding部分
select_logprobs = select_logprobs * mask
# 计算平均
average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)
return average_logprobs
else:
return select_logprobs.mean(-1)
上面是已进行求平均的操作,即SimPO的实现,如果是TRL中DPO求和的操作话只需要将average_logprobs = select_logprobs.sum(-1) / mask.sum(-1)
改为average_logprobs = select_logprobs.sum(-1)
即可。
其实上面这个函数最终的输出取负数就是F.cross_entropy(logits, targets) 交叉熵的输出,只不过添加了mask操作而已,下面我们可以通过这两种不同的方式实现计算,下面是使用F.cross_entropy进行计算的代码:
def compute_logprobs_f_cross(logits, labels, mask=None):
"""
logits: shape (batch_size, sequence_len, vocab_size),即将label输入给模型后输出的结果
labels: shape (batch_size, sequence_len)
"""
# 需要先进行位移操作
# 去掉标签的第一个
labels = labels[:, 1:].clone()
# 去掉模型输出的最后一个
logits = logits[:, :-1, :].clone()
batch_size, sequence_len, vocab_size = logits.shape
cross_entropy_loss = 0
if mask is not None:
mask = mask[:, 1:].clone()
labels.masked_fill_(~mask, -100)
for i in range(batch_size):
cross_entropy_loss += F.cross_entropy(logits[i], labels[i])
else:
for i in range(batch_size):
cross_entropy_loss += F.cross_entropy(logits[i], labels[i])
cross_entropy_loss /= batch_size
return cross_entropy_loss
最终我们进行一下测试即可得到如下结果:
logits = torch.tensor(
[[2.0, 1.0, 0.1, 0.4],
[0.5, 2.5, 0.3, 0.5],
[0.6, 2.5, 0.3, 0.8],
[0.5, 2.5, 0.6, 0.6]], dtype=torch.float32).unsqueeze(0)
mask = torch.tensor([[True, True, False, False]])
targets = torch.tensor([0, 1, 0, 2]).unsqueeze(0)
loss1 = -compute_logprobs(logits, targets, mask)
loss2 = compute_logprobs_f_cross(logits, targets, mask)
print(loss1,loss2)
---------------------------
tensor([1.5419]) tensor(1.5419)
要注意的是,F.cross_entropy中所计算的logits和target一般是不带batch的,例如Shape: (2, 3)与Shape: (2,),如下:
logits = torch.tensor(
[[2.0, 1.0, 0.1],
[0.5, 2.5, 0.3]]) # Shape: (2, 3)
targets = torch.tensor([0, 2]) # Shape: (2,)
- logits:形状为 (2, 3) 的张量,表示两个样本的对数概率(logits)。每个样本有三个类别的对数概率。
- targets:形状为 (2,) 的张量,表示每个样本的真实类别标签。第一个样本的真实类别是0,第二个样本的真实类别是2(大模型中这里真是类别就是其vocab_size)。
所以上述batch输入我们不能直接给进这个函数,需要每个batch进行计算。
3.2 最终batch计算loss
上面我们写好了loss的计算相关代码,下面只需要在batch层面使用上面写好的函数即可,代码如下:
def compute_batch_loss(batch, policy_model, reference_model, beta):
# 决定使用哪个loss
# loss_fn = SimPo(beta, 0.5) SimPO loss
loss_fn = DPOLoss(beta) # DPO loss
policy_chosen_logps = compute_logprobs(
logits=policy_model(batch["chosen"]).logits,
labels=batch["chosen"],
mask=batch["chosen_mask"]
)
policy_rejected_logps = compute_logprobs(
logits=policy_model(batch["rejected"]).logits,
labels=batch["rejected"],
mask=batch["rejected_mask"]
)
reference_chosen_logps = compute_logprobs(
logits=reference_model(batch['chosen']).logits,
labels=batch['chosen'],
mask=batch["chosen_mask"]
)
reference_rejected_logps = compute_logprobs(
logits=reference_model(batch['rejected']).logits,
labels=batch['rejected'],
mask=batch["rejected_mask"]
)
loss, chosen_rewards, rejected_rewards = loss_fn(
policy_chosen_logps=policy_chosen_logps,
policy_rejected_logps=policy_rejected_logps,
reference_chosen_logps=reference_chosen_logps,
reference_rejected_logps=reference_rejected_logps,
)
# SimPO使用如下
# loss = loss_fn(
# policy_chosen_logps=policy_chosen_logps,
# policy_rejected_logps=policy_rejected_logps,
# )
# return loss
return loss, chosen_rewards, rejected_rewards
4、开始训练
下面我们开始训练脚本的编写,不用Trainer确实比较麻烦,需要自己手动的epoch循环之类的,不过如果之前做过CV相关的话应该就不会陌生了。
下面是我们的训练函数
def train_model(
policy_model, reference_model, train_loader, val_loader,
optimizer, num_epochs, beta,
eval_freq, eval_iter):
tracking = {
"train_losses": [],
"train_chosen_rewards": [],
"train_rejected_rewards": [],
"val_losses": [],
"val_chosen_rewards": [],
"val_rejected_rewards": [],
"tokens_seen": []
}
tokens_seen, global_step = 0, -1
# 训练
for epoch in range(num_epochs):
# policy 模型需要训练
policy_model.train()
for idx, batch in enumerate(train_loader):
optimizer.zero_grad()
loss, chosen_rewards, rejected_rewards = compute_batch_loss(
batch=batch,
policy_model=policy_model,
reference_model=reference_model,
beta=beta
)
loss.backward()
optimizer.step()
global_step += 1
tokens_seen += batch["chosen"].numel()
# 验证
if global_step % eval_freq == 0:
res = evaluate_loss_dataloader(
policy_model=policy_model,
reference_model=reference_model,
train_loader=train_loader,
val_loader=val_loader,
beta=beta,
eval_iter=eval_iter
)
tracking["train_losses"].append(res["train_loss"])
tracking["train_chosen_rewards"].append(res["train_chosen_reward"])
tracking["train_rejected_rewards"].append(res["train_rejected_reward"])
tracking["val_losses"].append(res["val_loss"])
tracking["val_chosen_rewards"].append(res["val_chosen_reward"])
tracking["val_rejected_rewards"].append(res["val_rejected_reward"])
tracking["tokens_seen"].append(tokens_seen)
train_reward_margin = res["train_chosen_reward"] - res["train_rejected_reward"]
val_reward_margin = res["val_chosen_reward"] - res["val_rejected_reward"]
print(
f"Ep {epoch + 1} (Step {global_step:06d}): "
f"Train loss {res['train_loss']:.3f}, Val loss {res['val_loss']:.3f}, "
f"Train reward margins {train_reward_margin:.3f}, "
f"Val reward margins {val_reward_margin:.3f}"
)
return tracking
训练函数已经写好了,我们开始训练:
def main():
torch.manual_seed(42)
start_time = time.time()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
num_epochs = 3
tracking = train_model(
policy_model=model,
reference_model=ref_model,
train_loader=train_loader,
val_loader=val_loader,
optimizer=optimizer,
num_epochs=num_epochs,
beta=0.1, # value between 0.1 and 0.5
eval_freq=2,
eval_iter=2
)
end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
训练脚本中我就没有设置保存模型的代码了,大家可以自行设置或者在jupter中运行后接着进行推理测试。
5、总结:实验结果
5.1 实验
实验的话我们就只对DPO进行测试了。
使用B站的Index-1___9B-Chat模型在HF上随便找了一个关于unsloth的数据集进行了一下测试,数据集如下:
是一个关于unsloth的问答,jsonl格式,每行都有prompt、chosen、rejected字段。
训练过程结果如下:
loss下降还是可以的,测试了几个用例输出,相较之前回答会不一样,有一点点变好吧。因为我们只有50条数据,且DPO之类的强化学习主要是减少bad的输出,而不是学习新知识,故提升不大也在合理范围内。
5.2 总结
上述只是我们简单的进行了DPO SimPO的loss实现及训练代码编写,只是一个demo示例,并没有增加分布式训练、模型chat template适配等等。仅供学习原理使用吧。本文的全部代码已保存至github下:DPO_example
如果想要使用DPO或者Simpo、CPO等强化学习方法真正训练的话,
可以使用本项目中构建的强化学习框架,支持deepspeed的单机多卡Lora、Dora、Qlora、全量参数训练,并自动适配模型的chat template:RLHF训练

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)