安装tensorRT:

1、下载与电脑中cuda和cudnn版本对应的tensorRT(比如我的是TensorRT-8.2.1.8.Windows10.x86_64.cuda-11.4.cudnn8.2)

2、打开目录里面有python文件夹,找到对应python版本的whl文件(我的是tensorrt-8.2.1.8-cp38-none-win_amd64.whl)  因为我python安装的是3.8版本

3、终端安装:pip install tensorrt-8.2.1.8-cp38-none-win_amd64.whl

4、结束


import tensorrt as trt
def get_DynEngine(onnx_file_path, engine_file_path,patchsize,max_workspace_size,max_batch_size):
    '''
    Attempts to load a serialized engine if available,
    otherwise build a new TensorRT engine as save it
    '''
    TRT_LOGGER = trt.Logger()
    trt.init_libnvinfer_plugins(TRT_LOGGER, "")
    explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(explicit_batch)
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)
    runtime = trt.Runtime(TRT_LOGGER)
    print("common.EXPLICIT_BATCH:", explicit_batch)
    # 最大内存占用
    # 显存溢出需要重新设置
    config.max_workspace_size = max_workspace_size # 256MB
    config.set_flag(trt.BuilderFlag.FP16)
    print("max_workspace_size:", config.max_workspace_size)
    builder.max_batch_size = max_batch_size  # 推理的时候要保证batch_size<=max_batch_size
    
    if not os.path.exists(onnx_file_path):
        print(f'onnx file {onnx_file_path} not found,please run torch_2_onnx.py first to generate it')
        exit(0)
    print(f'Loading ONNX file from path {onnx_file_path}...')
    with open(onnx_file_path, 'rb') as model:
        print('Beginning ONNX file parsing')
        if not parser.parse(model.read()):
            print('ERROR:Failed to parse the ONNX file')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    print("input", inputs)

    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print("out:", outputs)
    print("Network Description")
    for input in inputs:
        # 获取当前转化之前的 输入的 batch_size
        batch_size = input.shape[0]
        print("Input '{}' with shape {} and dtype {} . ".format(input.name, input.shape, input.dtype))
    for output in outputs:
        print("Output '{}' with shape {} and dtype {} . ".format(output.name, output.shape, output.dtype))
    # Dynamic input setting 动态输入在builder的profile设置
    # 为每个动态输入绑定一个profile
    profile = builder.create_optimization_profile()
    print("network.get_input(0).name:", network.get_input(0).name)
    profile.set_shape(network.get_input(0).name, (1,1, *patchsize), (1, 1,*patchsize),
                      (max_batch_size, 1, *patchsize))  # 最小的尺寸,常用的尺寸,最大的尺寸,推理时候输入需要在这个范围内
    config.add_optimization_profile(profile)
    print('Completed parsing the ONNX file')
    print(f'Building an engine from file {onnx_file_path}; this may take a while...')
    
    engine = builder.build_serialized_network(network, config)
    print('Completed creating Engine')
    with open(engine_file_path, 'wb') as f:
        f.write(engine)
    return engine



if __name__ == "__main__":

     get_DynEngine("1.onnx", "2.engine",[96,160,160],5*(1<<30),2)
                    
    

    

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐