目前移动端常见的部署神经网络模型格式有onnx,也有ncnn,或者直接使用torch运行时执行pt文件,今天简单的讲一下如何直接部署onnx格式的yolov8模型,相比ncnn格式的转换,onnx模式不需要繁琐的手动下载第三方库,吐槽一下ncnn准备步骤实在是又臭又长,对萌新极不友好。

先声明本文不包括java版的NMS数据清理,如有需要,自己照下面文章的python代码自行转换

首先,我们已经有了yolov8模型的onnx格式,转换详情见我下面的文章pytorch下yolov8打包onnx模型并使用NMS对数据清洗(最后有完整代码)_onnx runtime + yolo + pytroch-CSDN博客https://blog.csdn.net/qq_64809150/article/details/140436333?spm=1001.2014.3001.5501

目录

一、导入第三方依赖:

二、创建onnx模型处理类

三、创建图像处理类

四、创建activity进行测试

五、结果展示


 

一、导入第三方依赖:

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'

 将上述代码添加至build.gradle中的dependencies下面

c8de03fecd2f4fc686d260a8fae0df43.png5639e524046b462eb522b4555b10b22f.png

二、创建onnx模型处理类

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtException;

import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class ModelInference {
    private OrtEnvironment env;
    private OrtSession session;

    public ModelInference(String modelPath) throws OrtException {
        env = OrtEnvironment.getEnvironment();//创建环境
        session = env.createSession(modelPath);//创建会话
    }
    //运行模型推理
    public float[][] runInference(float[] inputData) throws OrtException {
        long[] inputShape = {1, 3, 640, 640};
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputShape);

        Result output = session.run(Collections.singletonMap("images", inputTensor));
        float[][][] outputArray = (float[][][]) output.get(0).getValue();

        return parseAndTransposeOutput(outputArray);
    }
    //数据处理
    private static float[][] parseAndTransposeOutput(float[][][] outputArray) {
        int numDetections = outputArray[0][0].length;  // 8400
        int detectionLength = outputArray[0].length;   // 7

        List<float[]> validDetections = new ArrayList<>(numDetections / 2);

        for (int i = 0; i < numDetections; i++) {
            if (outputArray[0][4][i] > 0.5) {
                float[] detection = new float[detectionLength];
                for (int j = 0; j < detectionLength; j++) {
                    detection[j] = outputArray[0][j][i];
                }
                validDetections.add(detection);
            }
        }

        float[][] transposedArray = new float[validDetections.size()][detectionLength];
        for (int i = 0; i < validDetections.size(); i++) {
            transposedArray[i] = validDetections.get(i);
        }

        return transposedArray;
    }

    public void close() throws OrtException {
        if (session != null) {
            session.close();
        }
        if (env != null) {
            env.close();
        }
    }
}
import android.content.Context;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;


public class ReadAssets {
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

}

先创建onnx环境和会话,我的模型都放在assets目录下面,这个目录需要自己创建,上面给出获取方法了,使用environment.createSession创建会话,顺便提一嘴,options用于为onnx模型会话配置特定的参数和设置,影响会话的行为和性能,也可以不写,会自动使用默认值,下面是该方法的官方文档截图(参照官方文档你怎么处理都可以)

777087fcc9284e04b2375e52a5529ff6.png

0ac87ca76db24776a62524372c2ef8ba.png

上述runInference方法对输入数据进行推理,再使用parseAndTransposeOutput方法将数据进行转置操作便于后续数据清洗,如果你看过我上一篇文章就可以猜到这个方法其实是为后面的NMS清洗数据噪声做准备。

三、创建图像处理类

处理输入图像格式,onnx官方给出的案例模板中已经有现成的kotlin代码,但是萌新使用java进行安卓开发时不会使用kotlin,所以我写了一份java版的处理代码给大家参考。

import android.graphics.Bitmap;
import android.graphics.Color;

public class ImageUtils {

    public static Bitmap resizeImage(Bitmap bitmap, int width, int height) {
        return Bitmap.createScaledBitmap(bitmap, width, height, true);
    }

    public static float[] preprocessImage(Bitmap bitmap, int inputSize) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        float resizeRatio = (float) inputSize / Math.max(width, height);
        Bitmap resizedBitmap = resizeImage(bitmap, (int) (width * resizeRatio), (int) (height * resizeRatio));
        int resizedWidth = resizedBitmap.getWidth();
        int resizedHeight = resizedBitmap.getHeight();

        int[] pixels = new int[resizedWidth * resizedHeight];
        resizedBitmap.getPixels(pixels, 0, resizedWidth, 0, 0, resizedWidth, resizedHeight);

        float[] floatValues = new float[inputSize * inputSize * 3];
        int channelOffset = inputSize * inputSize;
        int pixelIndex = 0;

        for (int i = 0; i < inputSize; ++i) {
            for (int j = 0; j < inputSize; ++j) {
                if (i < resizedHeight && j < resizedWidth) {
                    int pixelValue = pixels[pixelIndex++];
                    floatValues[i * inputSize + j] = Color.red(pixelValue) / 255.0f;
                    floatValues[channelOffset + i * inputSize + j] = Color.green(pixelValue) / 255.0f;
                    floatValues[channelOffset * 2 + i * inputSize + j] = Color.blue(pixelValue) / 255.0f;
                } else {
                    floatValues[i * inputSize + j] = 0;
                    floatValues[channelOffset + i * inputSize + j] = 0;
                    floatValues[channelOffset * 2 + i * inputSize + j] = 0;
                }
            }
        }
        return floatValues;
    }
}

四、创建activity进行测试

xml文件就一个ImageView,自己重命名就可以了,我自己写了一个NMS算法对推理后的数据进行处理,最后可以完美显示结果,这部分代码我已经用python的形式在上一篇文章里面给出来了,链接见本文开头。

NMS这部分还是辛苦大家自己参照python代码写吧,我代码功能是根据需求定制的所以内容又多又杂,最近抽不出时间精简,对不住大伙了,望理解。此外没有其他缺失代码了。

import static com.example.growth.modelUtil.yolo.NMS.drawBoundingBoxes;
import static com.example.growth.modelUtil.yolo.NMS.nms;

import android.annotation.SuppressLint;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.widget.CompoundButton;
import android.widget.ImageView;
import android.widget.Switch;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import com.example.growth.modelUtil.ReadAssets;
import com.example.growth.modelUtil.yolo.ImageUtils;
import com.example.growth.modelUtil.yolo.ModelInference;
import com.example.growth.modelUtil.yolo.NMS;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.stream.IntStream;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class Yolov8Test extends AppCompatActivity {
    private static final String MODEL_PATH = "yolov8_best.onnx"; // 模型文件路径
    private static final int INPUT_SIZE = 640;
    private ImageView imageView;

    @SuppressLint("MissingInflatedId")
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_yolov8_test);

        imageView = findViewById(R.id.imageView);
        open();
    }
    private void open(){
        try {
            // 加载图像
            InputStream inputStream = getAssets().open("0304.JPG");
            Bitmap bitmap = BitmapFactory.decodeStream(inputStream);

            // 预处理图像
            Bitmap resizedBitmap = ImageUtils.resizeImage(bitmap, INPUT_SIZE, INPUT_SIZE);
            float[] inputData = ImageUtils.preprocessImage(resizedBitmap, INPUT_SIZE);

            // 加载模型并进行推理
            ModelInference modelInference = new ModelInference(ReadAssets.assetFilePath(this,MODEL_PATH));//getAssetFilePath(MODEL_PATH)
            float[][] new_array=modelInference.runInference(inputData);
            List<float[]> nmsDetections = nms(new_array, 0.1f);
            float[] list=drawBoundingBoxes(resizedBitmap, nmsDetections);
            float max=findMaxValueUsingStream(list);

            // 显示结果图像
            imageView.setImageBitmap(resizedBitmap);
        } catch (IOException | OrtException e) {
            e.printStackTrace();
        }
    }

五、结果展示

a8e993c6ffd24f15ac1ec06c1a783ac7.png

 

 

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐