模型转换支持多维度动态batch设置
import torch
import onnxruntime
import numpy as np

# 模型转换支持多维度动态batch设置
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        x = self.conv(x)
        return x

def test_dynamic_axes():
    model = Model()
    dummy_input = torch.rand(1, 3, 10, 10)
    model_names = ['model_static.onnx',
                   'model_dynamic_0.onnx',
                   'model_dynamic_23.onnx']
    # 第0维动态
    dynamic_axes_0 = {
        'in': {0: 'batch'},
        'out': {0: 'batch'}
    }
    # 第2,3维动态batch
    dynamic_axes_23 = {
        'in': {2: 'batch', 3: 'batch'},
        'out': {2: 'batch', 3: 'batch'}
    }

    torch.onnx.export(model,
                      dummy_input,
                      model_names[0],
                      input_names=['in'],
                      output_names=['out'])

    torch.onnx.export(model,
                      dummy_input,
                      model_names[1],
                      input_names=['in'],
                      output_names=['out'],
                      dynamic_axes=dynamic_axes_0)

    torch.onnx.export(model,
                      dummy_input,
                      model_names[2],
                      input_names=['in'],
                      output_names=['out'],
                      dynamic_axes=dynamic_axes_23) #指定输入输出张量的哪些维度是动态的


def test_dynamic_and_static_model_export():
    model = Model()
    origin_tensor = np.random.rand(1, 3, 10, 10).astype(np.float32)
    mult_batch_tensor = np.random.rand(2, 3, 10, 10).astype(np.float32)
    big_tensor = np.random.rand(1, 3, 20, 20).astype(np.float32)

    model_names = ['model_static.onnx',
                   'model_dynamic_0.onnx',
                   'model_dynamic_23.onnx']
    inputs = [origin_tensor, mult_batch_tensor, big_tensor]
    exceptions = dict()

    for model_name in model_names:
        for i, input in enumerate(inputs):
            try:
                ort_session = onnxruntime.InferenceSession(model_name)
                ort_inputs = {'in': input}
                ort_session.run(['out'], ort_inputs)
            except Exception as e:
                exceptions[(i, model_name)] = e
                print(f'Input[{i}] on model {model_name} error.')
            else:
                print(f'Input[{i}] on model {model_name} succeed.')

if __name__ == '__main__':
    # test_dynamic_axes()
    test_dynamic_and_static_model_export()

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐