# bert
from transformers import BertModel, BertConfig

config = BertConfig.from_json_file('bert-base/config.json')
bert_model = BertModel(config, add_pooling_layer=True)
pytorch_total_params = sum(p.numel() for p in bert_model.parameters() if p.requires_grad)
print('模型参数量: ', pytorch_total_params)

# gpt
from transformers import GPT2Config, GPT2Model

config = GPT2Config.from_json_file('gpt2-config.json')
model = GPT2Model(config)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('模型参数量: ', pytorch_total_params)

# gpt

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐