基于Mamba2的文本生成实战
本节将演示一个使用Mamba2模型完成文本生成任务的示例。在这个过程中,我们将充分利用已有的数据集来训练和优化Mamba2模型,以实现高质量的文本生成效果。
·
DeepSeek大模型高性能核心技术与多模态融合开发 - 商品搜索 - 京东
本节将演示一个使用Mamba2模型完成文本生成任务的示例。在这个过程中,我们将充分利用已有的数据集来训练和优化Mamba2模型,以实现高质量的文本生成效果。
12.2.1 文本生成Mamba2模型的完整实现
类似于前面使用Mamba完成文本生成任务,我们这里将使用Mamba2作为主干网格设计文本生成模型。在具体应用上,使用Mamba2模型直接替代原有的Mamba模型即可。完整代码如下:
class Mamba2LMHeadModel(nn.Module):
def _ _init_ _(self, args: Mamba2Config, device: Device = None):
super()._ _init_ _()
self.args = args
self.device = device
self.backbone = nn.ModuleDict(
dict(
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
可以看到,这里的核心在于使用我们之前完成的Mamba2模型作为特征提取的主干网络,以抽取和计算特征。至于其他部分,如输出分类层和返回值,读者可以参考Mamba模型的实现。
12.2.2 基于Mamba2的文本生成
最后,我们将完成基于Mamba2的文本生成实战任务。对于这部分内容,读者可以参考我们在实现Mamba文本生成时所准备的训练框架。完整代码如下:
from model import Mamba2,Mamba2Config
import math
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
device = "cuda"
mamba_cfg = Mamba2Config(d_model=384)
mamba_cfg.chunk_size = 4
model = mamba_model = Mamba2(mamba_cfg,device=device)
model.to(device)
save_path = "./saver/mamba_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)
BATCH_SIZE = 192
seq_len = 64
import get_data_emotion
#import get_data_emotion_2 as get_data_emotion
train_dataset = get_data_emotion.TextSamplerDataset(get_ data_emotion.token_list,seq_len=seq_len)
train_loader = (DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True))
optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max = 1200,eta_min=2e-7,last_epoch=-1)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(48):
pbar = tqdm(train_loader,total=len(train_loader))
for token_inp,token_tgt in pbar:
token_inp = token_inp.to(device)
token_tgt = token_tgt.to(device)
logits = model(token_inp)
loss = criterion(logits.view(-1, logits.size(-1)), token_tgt.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step() # 执行优化器
pbar.set_description(f"epoch:{epoch +1}, train_loss:{loss.item():.5f}, lr:{lr_scheduler.get_last_lr()[0]*1000:.5f}")
torch.save(model.state_dict(), save_path)
简单给一下代码,具体生成结果,读者可以自行验证。
本书节选自《深入探索Mamba模型架构与应用》,获出版社和作者授权发布。

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)