【深度学习】Pytorch模型转成Onnx
工作时需要将模型转成onnx使用triton加载,记录将pytorch模型转成onnx的过程。
·
前言
工作时需要将模型转成onnx使用triton加载,记录将pytorch模型转成onnx的过程。
1.转化步骤
1-1.安装依赖库
pip install onnx
pip install onnxruntime
1-2.导入模型
将训练的模型导入
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizer, BertModel,AdamW
import torch.nn as nn
import torch
import pandas as pd
import json
import re
import requests
import json
import numpy as np
def encoder(max_length,text_list):
#将text_list embedding成bert模型可用的输入形式
#加载分词模型
vocab_path = "/ssd/dongzhenheng/Pretrain_Model/Roberta_Large/"
#tokenizer = RobertaTokenizer.from_pretrained(vocab_path)
tokenizer = BertTokenizer.from_pretrained(vocab_path)
input_dict = tokenizer.encode_plus(
text,
add_special_tokens=True, # 添加'[CLS]'和'[SEP]'
max_length=max_length,
truncation=True, # 截断或填充
padding='max_length', # 填充至最大长度
return_attention_mask=True, # 返回attention_mask
return_token_type_ids=True, # 返回token_type_ids
return_tensors='pt',
)
input_ids = input_dict['input_ids']
token_type_ids = input_dict['token_type_ids']
attention_mask = input_dict['attention_mask']
print(input_ids.dtype,token_type_ids.dtype,attention_mask.dtype)
input_ids = input_ids.to(torch.int32)
token_type_ids = token_type_ids.to(torch.int32)
attention_mask = attention_mask.to(torch.int32)
print(input_ids.dtype,token_type_ids.dtype,attention_mask.dtype)
return input_ids,token_type_ids,attention_mask
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizer, BertModel,AdamW
import torch.nn as nn
import torch
import torch.onnx
def encoder(max_length,text_list):
#将text_list embedding成bert模型可用的输入形式
#加载分词模型
vocab_path = '/root/dongzhenheng/ssd/智能客服/models/chinese-roberta-wwm-ext-large/'
tokenizer = BertTokenizer.from_pretrained(vocab_path)
tokenizer = tokenizer(
text_list,
padding = True,
truncation = True,
max_length = max_len,
return_tensors='pt' # 返回的类型为pytorch tensor
)
input_ids = tokenizer['input_ids']
token_type_ids = tokenizer['token_type_ids']
attention_mask = tokenizer['attention_mask']
print(input_ids.dtype,token_type_ids.dtype,attention_mask.dtype)
input_ids = input_ids.to(torch.int32)
token_type_ids = token_type_ids.to(torch.int32)
attention_mask = attention_mask.to(torch.int32)
print(input_ids.dtype,token_type_ids.dtype,attention_mask.dtype)
return input_ids,token_type_ids,attention_mask
class FeedbackBertClassificationModel(nn.Module):
def __init__(self):
super(FeedbackBertClassificationModel, self).__init__()
#加载预训练模型
pretrained_weights = '/root/dongzhenheng/ssd/models/chinese-roberta-wwm-ext-large/'
self.bert = BertModel.from_pretrained(pretrained_weights)
for param in self.bert.parameters():
param.requires_grad = True
self.dropout = nn.Dropout(0.3)
self.pri_dense_1 = nn.Linear(1024, 3)
# 添加批量归一化层
self.bn = nn.BatchNorm1d(3)
def forward(self, input_ids,token_type_ids,attention_mask):
#得到bert_output
bert_output = self.bert(input_ids=input_ids, token_type_ids= token_type_ids,attention_mask=attention_mask)
bert_cls_hidden_state = bert_output[1]
# 应用 dropout
bert_cls_hidden_state = self.dropout(bert_cls_hidden_state)
pri_cls_output_1 = self.pri_dense_1(bert_cls_hidden_state)
# 应用批量归一化
pri_cls_output_1 = self.bn(pri_cls_output_1)
return pri_cls_output_1
FeedBack_classifier_model_path = '/root/dongzhenheng/ssd/智能客服/models/Chinese-roberta-wwm-ext-large_feedback_model_20240904.pth'
FeedBack_classifier_model = FeedbackBertClassificationModel()
# 然后加载模型的状态字典
FeedBack_classifier_model.load_state_dict(torch.load(FeedBack_classifier_model_path,map_location=torch.device('cpu')))
# 设置模型为评估模式
FeedBack_classifier_model.eval()
# 导出模型
max_len = 100
text = '你好'
input_ids, token_type_ids, attention_mask = encoder(max_len,text)
1-3 转成onnx格式
torch.onnx.export(FeedBack_classifier_model, # 模型
(input_ids, token_type_ids, attention_mask), # 模型输入
"/root/dongzhenheng/Work/Triton/智能客服/Onnx_model/model_repository/Feedback_classifition_onnx/1/model.onnx", # 输出文件名
export_params=True, # 是否导出参数
opset_version=11, # ONNX版本
verbose=True,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input_ids", "token_type_ids", "attention_mask"], # 输入名
output_names=["pri_cls_output"], # 输出名
dynamic_axes={
"input_ids": {0: "batch_size",1: "seq_length"},
"token_type_ids": {0: "batch_size",1: "seq_length"},
"attention_mask": {0: "batch_size",1: "seq_length"},
"pri_cls_output": {0: "batch_size"}
}
) # 动态维度
model :需要导出的pytorch模型
args:模型的输入参数,需要和模型接收到的参数一致。
path:输出的onnx模型的位置和名称。
export_params:输出模型是否可训练。default=True,表示导出trained model,否则untrained。opset_version :ONNX版本
verbose:是否打印模型转换信息。default=False。
input_names:输入节点名称。default=None。
output_names:输出节点名称。default=None。
do_constant_folding:是否使用常量折叠,默认即可。default=True。
dynamic_axes:模型的输入输出有时是可变的。
1-4 config.pbtxt
name: "Feedback_classifition_onnx"
platform: "onnxruntime_onnx"
max_batch_size: 512
input [
{
name: "input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "token_type_ids"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "attention_mask"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
{
name: "pri_cls_output"
data_type: TYPE_FP32
dims: [ 3 ]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
gpus: [0]
}
]
version_policy: {
specific: {
versions: [1]
}
}
dynamic_batching {
preferred_batch_size: [32, 64,128,256,512]
max_queue_delay_microseconds: 500
}
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)