MobileNet v3的Java推理

模型是 MobileNet v3 small + 微调分类器,使用 PyTorch 训练后导出为 ONNX 模型。

pom:

<!-- ONNX Runtime Java binding; the inference class imports its ai.onnxruntime package. -->
<dependencies>
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime</artifactId>
        <version>1.17.3</version>
    </dependency>
</dependencies>

代码:

package org.example;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;

public class OnnxModelInference {

    private OrtSession session;

    public OnnxModelInference(String modelPath) throws Exception {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        byte[] modelArray = Files.readAllBytes(Paths.get(modelPath));
        session = env.createSession(modelArray, sessionOptions);
    }

    public float[][][][] preprocessImage(String imagePath) throws IOException {
        // read img
        BufferedImage image = ImageIO.read(new File(imagePath));

        // resize to 224
        BufferedImage resizedImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2d = resizedImage.createGraphics();
        g2d.setColor(Color.WHITE); // 或根据需要选择背景颜色
        g2d.fillRect(0, 0, 224, 224);
        g2d.drawImage(image, 0, 0, 224, 224, null);
        g2d.dispose();

        // ImageNet norm
        float[] mean = {0.485f, 0.456f, 0.406f};
        float[] std = {0.229f, 0.224f, 0.225f};

        float[][][][] floatArray = new float[1][3][224][224];
        for (int y = 0; y < 224; y++) {
            for (int x = 0; x < 224; x++) {
                Color color = new Color(resizedImage.getRGB(x, y));
                floatArray[0][0][y][x] = (color.getRed() / 255.0f - mean[0]) / std[0]; // R
                floatArray[0][1][y][x] = (color.getGreen() / 255.0f - mean[1]) / std[1]; // G
                floatArray[0][2][y][x] = (color.getBlue() / 255.0f - mean[2]) / std[2]; // B
            }
        }
        return floatArray;
    }

    public String predict(String imagePath) throws Exception {
        float[][][][] inputData = preprocessImage(imagePath);

        OnnxTensor inputTensor = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), inputData);
        HashMap<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put("input.1", inputTensor); // "input"是模型输入节点的名称,请根据实际情况调整

        // num of class names
        String[] classes = {"c1", "c2"....};

        // predict
        try (OrtSession.Result results = session.run(inputs)) {
            float[][] output = (float[][]) results.get(0).getValue();
            int maxIndex = 0;
            for (int i = 1; i < output[0].length; i++) {
                maxIndex = output[0][i] > output[0][maxIndex] ? i : maxIndex;
            }
            return classes[maxIndex]; // i -> class name
        }
    }

    public static void main(String[] args) {
        try {
            OnnxModelInference inferencer = new OnnxModelInference("path/tray.onnx");

            // traverse dir with recursive
            String dir = "path/img_test/";
            long start = System.currentTimeMillis();
            long cnt = 0;
            for (File file : new File(dir).listFiles()) {
                if (file.isDirectory()) {
                    for (File subFile : file.listFiles()) {
                        String result = inferencer.predict(subFile.getAbsolutePath());
                        System.out.println(subFile.getAbsolutePath() + " : " + result);
                        cnt++;
                    }
                }
            }
            long end = System.currentTimeMillis();
            System.out.println("Cost Avg: " + (end - start) / cnt + "ms");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

速度还挺快:在我的 i7-12700 上,纯 CPU 推理约 15 ms/张。

Leave a Reply

Your email address will not be published. Required fields are marked *