ONNX Runtime Inference for MobileNet (C++)

#include <opencv2/opencv.hpp>
#include <onnxruntime/onnxruntime_cxx_api.h>
#include <algorithm>
#include <chrono>
#include <iostream>
#include <string>
#include <vector>

int main() {

    // load onnx model
    Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "test");
    Ort::SessionOptions sessionOptions;
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    const char* modelPath = "your_path/model.onnx";
    Ort::Session session(env, modelPath, sessionOptions);

    auto start = std::chrono::high_resolution_clock::now();  // timing covers preprocessing + inference

    // load and preprocess the image into an NCHW (1x3x224x224) buffer
    std::vector<int64_t> inputTensorShape = {1, 3, 224, 224};
    std::vector<float> inputTensorValues(1 * 3 * 224 * 224);

    cv::Mat image = cv::imread("your_path/test/class2/2fb9979f7ff12889ca5a69e031b93541.jpg");
    cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
    cv::resize(image, image, cv::Size(224, 224));
    image.convertTo(image, CV_32F, 1.0 / 255);
    constexpr float meanVal[3] = {0.485f, 0.456f, 0.406f};
    constexpr float stdVal[3] = {0.229f, 0.224f, 0.225f};

    const int height = image.rows;
    const int width = image.cols;
    const int channels = image.channels();
    // HWC -> CHW, normalizing each channel with the ImageNet mean/std
    for (int c = 0; c < channels; ++c) {
        for (int h = 0; h < height; ++h) {
            for (int w = 0; w < width; ++w) {
                float pixelValue = image.at<cv::Vec3f>(h, w)[c];
                pixelValue = (pixelValue - meanVal[c]) / stdVal[c];
                inputTensorValues[c * (width * height) + h * width + w] = pixelValue;
            }
        }
    }

    // build the input tensor and run inference
    std::vector<Ort::Value> inputTensors;
    std::vector<const char*> inputNames = {"input.1"};
    Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
    inputTensors.emplace_back(Ort::Value::CreateTensor<float>(memoryInfo, inputTensorValues.data(), inputTensorValues.size(), inputTensorShape.data(), inputTensorShape.size()));

    std::vector<const char*> outputNames = {"400"};
    auto outputTensors = session.Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), 1);
    auto* floatArray = outputTensors[0].GetTensorMutableData<float>();
    size_t numItems = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
    std::vector<std::string> classNames = {"class1", "class2"};
    // argmax over the raw output scores
    auto maxIndex = std::distance(floatArray, std::max_element(floatArray, floatArray + numItems));
    std::cout << "Predicted class: " << classNames[maxIndex] << " with confidence " << floatArray[maxIndex] << std::endl;


    auto end = std::chrono::high_resolution_clock::now();    
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cout << "cost: " << duration.count() << " ms" << std::endl;

    return 0;
}
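
One caveat: depending on how the model was exported, MobileNetV3's classification head typically ends in raw logits, so the "confidence" printed above is an unnormalized score rather than a probability. If you want a value in [0, 1], you can push the scores through a softmax first. A minimal sketch (a hypothetical helper, assuming the floatArray / numItems / maxIndex variables from the code above, plus #include <cmath>):

// Convert raw logits to probabilities with a numerically stable softmax.
// (hypothetical helper; `scores` points at `count` floats)
std::vector<float> softmax(const float* scores, size_t count) {
    std::vector<float> probs(count);
    const float maxScore = *std::max_element(scores, scores + count);
    float sum = 0.0f;
    for (size_t i = 0; i < count; ++i) {
        probs[i] = std::exp(scores[i] - maxScore);  // subtract max to avoid overflow
        sum += probs[i];
    }
    for (size_t i = 0; i < count; ++i) {
        probs[i] /= sum;
    }
    return probs;
}

// usage, right after the argmax in main():
// std::vector<float> probs = softmax(floatArray, numItems);
// std::cout << "probability: " << probs[maxIndex] << std::endl;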

The model is MobileNetV3. Names such as the input and output node names ("input.1", "400") are fixed at export time; if you are not sure what they are, you can print them with the following code:

// print model input/output info
Ort::AllocatorWithDefaultOptions allocator;

Ort::AllocatedStringPtr inputName = session.GetInputNameAllocated(0, allocator);
std::vector<char*> inputNodeNames;
inputNodeNames.push_back(inputName.get());
std::cout << "input node name: " << inputName.get() << "\n";

Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(0);
auto input_tensor_info = inputTypeInfo.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType inputNodeDataType = input_tensor_info.GetElementType();
std::vector<int64_t> inputTensorShape2 = input_tensor_info.GetShape();
std::cout << "input node shape: ";
for (size_t i = 0; i < inputTensorShape2.size(); i++)
{
    std::cout << inputTensorShape2[i] << " ";
}
std::cout << std::endl;

size_t outputNodeCount = session.GetOutputCount();
std::cout << "output node count: " << outputNodeCount << "\n";

Ort::AllocatedStringPtr outputName = session.GetOutputNameAllocated(0, allocator);
std::vector<char*> outputNodeNames;
outputNodeNames.push_back(outputName.get());
std::cout << "output node name: " << outputName.get() << "\n";

Update: if you are not sure about the model's node names, you can also inspect them with this online tool: https://netron.app/

 
