MobileNetV3的onnxruntime推理(Python)

import cv2
import numpy as np
import onnxruntime as ort

def load_model(model_path):
    """Create and return an ONNX Runtime inference session for ``model_path``."""
    return ort.InferenceSession(model_path)

def preprocess_image(image_path):
    """Load an image file and turn it into a normalized NCHW float32 batch.

    The image is read with OpenCV, converted from BGR to RGB, resized to
    224x224, scaled to the 0-1 range, standardized per channel, and finally
    rearranged to channel-first layout with a leading batch axis.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A ``float32`` array of shape ``(1, 3, 224, 224)``.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque error inside cvtColor below.
    if image is None:
        raise ValueError(f"Could not read image: {image_path!r}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    image = image.astype(np.float32) / 255.0  # scale pixel values to 0-1

    # Per-channel (R, G, B) standardization parameters; these are the values
    # commonly used with ImageNet-pretrained models.
    mean_val = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std_val = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Standardize; broadcasting applies the parameters along the channel axis.
    image = (image - mean_val) / std_val

    # Reshape to [1, 3, 224, 224]
    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    image = np.expand_dims(image, axis=0)  # add the batch dimension
    return image

def predict(session, input_image):
    """Run ``session`` on ``input_image`` and return the list of fetched outputs.

    Only the session's first declared input and first declared output are
    used; the return value is whatever ``session.run`` returns for that
    single requested output.
    """
    first_input = session.get_inputs()[0]
    first_output = session.get_outputs()[0]
    feeds = {first_input.name: input_image}
    return session.run([first_output.name], feeds)

def main(model_path, image_path, class_names=None):
    """Classify one image with an ONNX model and print the predicted class.

    Args:
        model_path: Path to the ONNX model file.
        image_path: Path to the image to classify.
        class_names: Optional sequence of class labels indexed by the model's
            output position. Defaults to ``['class1', 'class2']``.
    """
    if class_names is None:
        class_names = ['class1', 'class2']

    # Load the model
    session = load_model(model_path)

    # Read and preprocess the image
    img = preprocess_image(image_path)

    # Run inference
    prediction = predict(session, img)

    # Pick the highest-scoring class. argmax works on the flattened output,
    # so this assumes a single-image batch; int() yields a plain Python index.
    predicted_class_index = int(np.argmax(prediction[0]))
    if predicted_class_index < len(class_names):
        predicted_class_name = class_names[predicted_class_index]
        print(f"Predicted class: {predicted_class_name}")
    else:
        print("Predicted class index is out of bounds of the class_names array")

# Placeholder paths; replace with the actual model and image locations.
model_path = 'your_path/model.onnx'
image_path = 'your_path/test/class1/03a79bee1d5ac929f9b18b4b223aad04.jpg'

# Run inference only when executed as a script, so importing this module
# does not load the model as a side effect.
if __name__ == "__main__":
    main(model_path, image_path)

 

Leave a Reply

Your email address will not be published. Required fields are marked *