深入探索Mamba模型架构与应用 - 商品搜索 - 京东

 DeepSeek大模型高性能核心技术与多模态融合开发 - 商品搜索 - 京东

本节将演示一个使用Mamba2模型完成文本生成任务的示例。在这个过程中,我们将充分利用已有的数据集来训练和优化Mamba2模型,以实现高质量的文本生成效果。

12.2.1  文本生成Mamba2模型的完整实现

类似于前面使用Mamba完成文本生成任务,我们这里将使用Mamba2作为主干网格设计文本生成模型。在具体应用上,使用Mamba2模型直接替代原有的Mamba模型即可。完整代码如下:

class Mamba2LMHeadModel(nn.Module):
    def _ _init_ _(self, args: Mamba2Config, device: Device = None):
        super()._ _init_ _()
        self.args = args
        self.device = device

        self.backbone = nn.ModuleDict(
            dict(
                embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
                layers=nn.ModuleList(
                    [
                        nn.ModuleDict(
                            dict(
                                mixer=Mamba2(args, device=device),
                                norm=RMSNorm(args.d_model, device=device),
                            )
                        )
                        for _ in range(args.n_layer)
                    ]
                ),
                norm_f=RMSNorm(args.d_model, device=device),
            )
        )
        self.lm_head = nn.Linear(
            args.d_model, args.vocab_size, bias=False, device=device
        )
        self.lm_head.weight = self.backbone.embedding.weight

    def forward(
        self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
    ) -> tuple[LongTensor, list[InferenceCache]]:

        seqlen = input_ids.shape[1]

        if h is None:
            h = [None for _ in range(self.args.n_layer)]

        x = self.backbone.embedding(input_ids)
        for i, layer in enumerate(self.backbone.layers):
            y, h[i] = layer.mixer(layer.norm(x), h[i])
            x = y + x

        x = self.backbone.norm_f(x)
        logits = self.lm_head(x)
        return logits[:, :seqlen], cast(list[InferenceCache], h)

可以看到,这里的核心在于使用我们之前完成的Mamba2模型作为特征提取的主干网络,以抽取和计算特征。至于其他部分,如输出分类层和返回值,读者可以参考Mamba模型的实现。

12.2.2  基于Mamba2的文本生成

最后,我们将完成基于Mamba2的文本生成实战任务。对于这部分内容,读者可以参考我们在实现Mamba文本生成时所准备的训练框架。完整代码如下:

from model import Mamba2,Mamba2Config
import math
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader

device = "cuda"
mamba_cfg = Mamba2Config(d_model=384)
mamba_cfg.chunk_size = 4
model = mamba_model  = Mamba2(mamba_cfg,device=device)
model.to(device)
save_path = "./saver/mamba_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)

BATCH_SIZE = 192
seq_len = 64
import get_data_emotion
#import get_data_emotion_2 as get_data_emotion
train_dataset = get_data_emotion.TextSamplerDataset(get_ data_emotion.token_list,seq_len=seq_len)
train_loader = (DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True))

optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max = 1200,eta_min=2e-7,last_epoch=-1)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(48):
    pbar = tqdm(train_loader,total=len(train_loader))
    for token_inp,token_tgt in pbar:
        token_inp = token_inp.to(device)
        token_tgt = token_tgt.to(device)
        logits = model(token_inp)
        loss = criterion(logits.view(-1, logits.size(-1)), token_tgt.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # 执行优化器
        pbar.set_description(f"epoch:{epoch +1}, train_loss:{loss.item():.5f}, lr:{lr_scheduler.get_last_lr()[0]*1000:.5f}")

    torch.save(model.state_dict(), save_path)

简单给一下代码,具体生成结果,读者可以自行验证。

本书节选自《深入探索Mamba模型架构与应用》,获出版社和作者授权发布。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐