pytorch的pth模型转换为onnx
文章目录前言一、ONNX是什么?二、使用方法总结前言将pytorch的pth模型转换为onnx提示:以下是本篇文章正文内容,下面案例可供参考一、ONNX是什么?Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移,ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不
·
文章目录
前言
将pytorch的pth模型转换为onnx
提示:以下是本篇文章正文内容,下面案例可供参考
一、ONNX是什么?
Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移,ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch, MXNet)可以采用相同格式存储模型数据并交互。
二、使用方法
import cv2
import torch
from models.CC import CrowdCounter
from torch.autograd import Variable
import torchvision.transforms as standard_transforms
def pth_to_onnx():
net= CrowdCounter([0], 'res50') #加载神经网络
#注意查看net的结构
pthfile = '/vgg16.pth' #输入pth模型
model_test = torch.load(pthfile) ##加载多卡训练后的pth模型
net.load_state_dict({k.replace(k, 'CCN.' + k): v for k, v in model_test.items()})
net.cuda()
net.eval()
image = cv2.imread('test1.jpg')
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
img = img_transform(img)
img = img.view(1, img.shape[0], img.shape[1], img.shape[2])
with torch.no_grad(): #不需要降低梯度
img = Variable(img).cuda()
input_names = ["feature4"] #只代表输入节点名称
output_names = ["de_pred"] #只代表输出节点名称
#只能输入卷积网络的内容,类似回归损失之类的除外
torch.onnx.export(net.CCN, img, "res50.onnx", verbose=True, input_names=input_names,
output_names=output_names)
```bash
Res50(
# (de_pred): Sequential(
# (0): Conv2d(
# (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
# (relu): ReLU(inplace=True)
# )
# (1): Conv2d(
# (conv): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
# (relu): ReLU(inplace=True)
# )
# )
# (frontend): Sequential(
# (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU(inplace=True)
# (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
# (4): Sequential(
# (0): Bottleneck(
# (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (relu): ReLU(inplace=True)
# (downsample): Sequential(
# (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# )
# )
# (1): Bottleneck(
#
# )
# (2): Bottleneck(
# (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (relu): ReLU(inplace=True)
# )
# )
# (5): Sequential(
# (0): Bottleneck(
# (conv1): Conv2d(256, 128, kernel_size=(1,
总结
上述代码可以很方便的转换pth为onnx结构,需要注意torch.onnx.export中输入的网络
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐

所有评论(0)