import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn
import onnx
import onnxruntime
from torch.nn.functional import interpolate

# 模型导出方法:跟踪法和记录法
# 跟踪法只能通过实际运行一遍模型的方法导出模型的静态图,即无法识别出模型中的控制流(如循环);
# 记录法则能通过解析模型来正确记录所有的控制流
class Model(torch.nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
        self.conv = torch.nn.Conv2d(3, 3, 3)
    # 带循环模型
    def forward(self, x):
        for i in range(self.n):
            x = self.conv(x)
        return x

def test_script_and_trace():
    models = [Model(2), Model(3)]
    model_names = ['model_2', 'model_3']

    for model, model_name in zip(models, model_names):
        dummy_input = torch.rand(1, 3, 10, 10)
        dummy_output = model(dummy_input)
        model_trace = torch.jit.trace(model, dummy_input)
        model_script = torch.jit.script(model)

        # 跟踪法与直接 torch.onnx.export(model, ...)等价
        torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output)
        # 记录法必须先调用 torch.jit.sciprt
        torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output)

if __name__ == '__main__':
    test_script_and_trace()

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐