
The Example of Deploying YOLOv7 Pre-trained Model Based on the OpenVINO™ 2022.1 C++ API


Author: Ethan Yang

Task Background

As one of the most common tasks in visual applications, object detection has long been a focus of competition among new model architectures, the most prominent of which is the YOLO family. YOLO stands for "You Only Look Once": the network identifies both the category and the location of objects in an image in a single pass, i.e., in one stage. Recently, the YOLO team released a new version, YOLOv7, which surpasses previous variants in both speed and accuracy. This article shares how to deploy the official YOLOv7 pre-trained model with the OpenVINO™ 2022.1 toolkit. The C++/Python source code and usage instructions are attached.

Code repository: https://github.com/OpenVINO-dev-contest/YOLOv7_OpenVINO_cpp-python

OpenVINO™ Introduction

Powered by oneAPI, the Intel® Distribution of OpenVINO™ toolkit for high-performance deep learning was developed to help users deploy accurate models into production systems across a variety of Intel platforms, from edge to cloud. With a simple development workflow, OpenVINO™ empowers developers to deploy high-performance applications and algorithms in the real world.

On the inference back end, thanks to the "program once, deploy anywhere" capability of the OpenVINO™ toolkit, a converted model can run on different Intel hardware platforms without rebuilding, which greatly simplifies building and migration. In addition, to support more heterogeneous acceleration units, the OpenVINO™ Runtime API uses a plug-in architecture based on MKL-DNN and oneDNN, optimized for common instruction sets such as AVX-512. A complete set of high-performance operator libraries is implemented for the different hardware execution units, improving the overall performance of the model during inference.
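
To illustrate this portability, the device a model runs on is selected with a single string argument when compiling it, while the rest of the pipeline stays unchanged. The snippet below is a minimal sketch; the model path is a placeholder, and "AUTO" simply lets the runtime pick an available device:

#include <openvino/openvino.hpp>

int main()
{
    ov::Core core;
    // placeholder model path; any IR or ONNX model works the same way
    std::shared_ptr<ov::Model> model = core.read_model("yolov7.onnx");

    // The same network can be compiled for different Intel devices just by
    // changing the device string passed to compile_model().
    ov::CompiledModel cpu_model  = core.compile_model(model, "CPU");
    ov::CompiledModel auto_model = core.compile_model(model, "AUTO");
    return 0;
}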

YOLOv7 Introduction

At the same model size, the official YOLOv7 is more accurate and 120% faster (FPS) than YOLOv5, 180% faster than YOLOX, 1,200% faster than Dual-Swin-T, 550% faster than ConvNeXt, and 500% faster than SWIN-L. In the range of 5 FPS to 160 FPS, YOLOv7 exceeds all currently known detectors in both speed and accuracy. Tested on an NVIDIA V100 GPU, the model with 56.8% AP reaches a detection rate above 30 FPS (batch=1); it is currently the only detector that exceeds 30 FPS at such high accuracy.

Task Development Process

Let's first look at the overall input and output structure of YOLOv7. The input image is resized to 640x640 and fed into the backbone network; the head network then outputs three layers of feature maps of different sizes together with the prediction results.

Taking the COCO dataset as an example, there are 80 output categories, and each prediction (x, y, w, h, o) consists of the box coordinates plus the confidence that an object exists; 3 is the number of anchors per grid cell. Each output layer therefore has (80 + 5) x 3 = 255 channels, which, multiplied by the size of the feature map, gives the final output. The entire development process can be divided into four parts: data processing module definition, pre-processing tasks, inference tasks, and post-processing tasks.
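
As a quick sanity check of this arithmetic, the sketch below computes the expected feature-map sizes and channel counts for a 640x640 input at the three strides used later (8, 16 and 32). The exact tensor layout of the exported model may differ, so treat this purely as an illustration:

#include <cstdio>

int main()
{
    const int input_size = 640;
    const int cls_num = 80;    // COCO classes
    const int anchor_num = 3;  // anchors per grid cell
    const int strides[3] = {8, 16, 32};

    for (int s : strides)
    {
        int feat = input_size / s;                  // feature-map width/height
        int channels = (cls_num + 5) * anchor_num;  // (80 + 5) x 3 = 255
        // prints: stride  8: 80x80x255, stride 16: 40x40x255, stride 32: 20x20x255
        std::printf("stride %2d: %dx%dx%d values\n", s, feat, feat, channels);
    }
    return 0;
}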

Figure: The Input and Output Structure of the Official Pre-trained Model of YOLOv7

 

1. Data Processing Module

An Object structure is defined to store the model's output data for a single detection: the bounding box, the class label, and the combined confidence that an object exists and belongs to that class.

A class_names vector is defined to store all the labels of the COCO dataset.

struct Object
{
    cv::Rect_<float> rect;
    int label;
    float prob;
};

const std::vector<std::string> class_names = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
"hair drier", "toothbrush"};

The letterbox and scale_box modules are defined next. The former adds a letterbox (padding) to the input data during image pre-processing; the latter undoes the letterbox coordinate transformation during post-processing. In particular, a padd vector is filled while building the letterbox to store its padding size and the scaling ratio relative to the original image; this data is used in the post-processing task to map the results back after removing the letterbox.

cv::Mat letterbox(cv::Mat &src, int h, int w, std::vector<float> &padd)
{
    // Resize and pad the image while meeting stride-multiple constraints
    int in_w = src.cols;
    int in_h = src.rows;
    int tar_w = w;
    int tar_h = h;
    float r = std::min(float(tar_h) / in_h, float(tar_w) / in_w);
    int inside_w = std::round(in_w * r);
    int inside_h = std::round(in_h * r);
    int padd_w = tar_w - inside_w;
    int padd_h = tar_h - inside_h;
    cv::Mat resize_img;

    // resize while keeping the aspect ratio
    cv::resize(src, resize_img, cv::Size(inside_w, inside_h));

    // divide the padding into 2 sides
    padd_w = padd_w / 2;
    padd_h = padd_h / 2;
    padd.push_back(padd_w);
    padd.push_back(padd_h);

    // store the scaling ratio
    padd.push_back(r);
    int top = int(std::round(padd_h - 0.1));
    int bottom = int(std::round(padd_h + 0.1));
    int left = int(std::round(padd_w - 0.1));
    int right = int(std::round(padd_w + 0.1));

    // add the border, filled with the constant color (114, 114, 114)
    cv::copyMakeBorder(resize_img, resize_img, top, bottom, left, right, cv::BORDER_CONSTANT, cv::Scalar(114, 114, 114));
    return resize_img;
}

cv::Rect scale_box(cv::Rect box, std::vector<float> &padd)
{
    // remove the letterbox padding from the box coordinates
    cv::Rect scaled_box;
    scaled_box.x = box.x - padd[0];
    scaled_box.y = box.y - padd[1];
    scaled_box.width = box.width;
    scaled_box.height = box.height;
    return scaled_box;
}

Define the generate_proposals module with the following functions:

  1. Generate proposal boxes for each feature map of the input image according to the predefined anchors;
  2. Adjust the position and size of the proposal boxes according to the output results, and restore them to the coordinate system of the input image as bounding boxes;
  3. Filter out classification results with low confidence.
static void generate_proposals(int stride, const float *feat, float prob_threshold, std::vector<Object> &objects)
{
    // predefined anchors for the three output scales (stride 8, 16, 32), two values per anchor
    float anchors[18] = {12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401};
    int anchor_num = 3;
    int feat_w = 640 / stride;
    int feat_h = 640 / stride;
    int cls_num = 80;
    int anchor_group = 0;
    if (stride == 8)
        anchor_group = 0;
    if (stride == 16)
        anchor_group = 1;
    if (stride == 32)
        anchor_group = 2;

    // output layout: 3 x h x w x (80 + 5)
    for (int anchor = 0; anchor <= anchor_num - 1; anchor++)
    {
        for (int i = 0; i <= feat_h - 1; i++)
        {
            for (int j = 0; j <= feat_w - 1; j++)
            {
                float box_prob = feat[anchor * feat_h * feat_w * (cls_num + 5) + i * feat_w * (cls_num + 5) + j * (cls_num + 5) + 4];
                box_prob = sigmoid(box_prob);

                // filter out bounding boxes with low objectness confidence
                if (box_prob < prob_threshold)
                    continue;

                float x = feat[anchor * feat_h * feat_w * (cls_num + 5) + i * feat_w * (cls_num + 5) + j * (cls_num + 5) + 0];
                float y = feat[anchor * feat_h * feat_w * (cls_num + 5) + i * feat_w * (cls_num + 5) + j * (cls_num + 5) + 1];
                float w = feat[anchor * feat_h * feat_w * (cls_num + 5) + i * feat_w * (cls_num + 5) + j * (cls_num + 5) + 2];
                float h = feat[anchor * feat_h * feat_w * (cls_num + 5) + i * feat_w * (cls_num + 5) + j * (cls_num + 5) + 3];

                double max_prob = 0;
                int idx = 0;

                // find the class id with the maximum confidence
                for (int t = 5; t < 85; ++t)
                {
                    double tp = feat[anchor * feat_h * feat_w * (cls_num + 5) + i * feat_w * (cls_num + 5) + j * (cls_num + 5) + t];
                    tp = sigmoid(tp);
                    if (tp > max_prob)
                    {
                        max_prob = tp;
                        idx = t;
                    }
                }

                // filter out classes with low combined confidence
                float cof = box_prob * max_prob;
                if (cof < prob_threshold)
                    continue;

                // decode the raw outputs into xywh in input-image coordinates
                x = (sigmoid(x) * 2 - 0.5 + j) * stride;
                y = (sigmoid(y) * 2 - 0.5 + i) * stride;
                w = pow(sigmoid(w) * 2, 2) * anchors[anchor_group * 6 + anchor * 2];
                h = pow(sigmoid(h) * 2, 2) * anchors[anchor_group * 6 + anchor * 2 + 1];

                float r_x = x - w / 2;
                float r_y = y - h / 2;

                // store the result
                Object obj;
                obj.rect.x = r_x;
                obj.rect.y = r_y;
                obj.rect.width = w;
                obj.rect.height = h;
                obj.label = idx - 5;
                obj.prob = cof;
                objects.push_back(obj);
            }
        }
    }
}
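
Note that generate_proposals calls a sigmoid helper that is not listed above; the sample repository provides its own implementation. A minimal version, assuming plain single-precision math, could look like this:

#include <cmath>

static inline float sigmoid(float x)
{
    // standard logistic function used to squash raw network outputs to (0, 1)
    return 1.0f / (1.0f + std::exp(-x));
}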

 2. Pre-processing Tasks

The pre-processing mainly includes the following steps:

  1. Use OpenCV to read the image file;
  2. Resize the original image and add a letterbox;
  3. Convert the color channels from BGR to RGB;
  4. Change the data layout (NHWC => NCHW) and normalize it (see the code in the model inference section);

cv::Mat src_img = cv::imread(image_path);
cv::Mat img;

std::vector<float> padd;
cv::Mat boxed = letterbox(src_img, img_h, img_w, padd);
cv::cvtColor(boxed, img, cv::COLOR_BGR2RGB);

3. Inference Tasks


Figure: The Development Process of the OpenVINO™ Toolkit Runtime

The model inference part mainly calls the OpenVINO™ C++ API. The calling sequence of the OpenVINO™ inference interface is shown in the figure above; the second step, Compile Model, can also be split into two steps: model reading and model compilation. During this process, the developer needs to convert the input data layout (NHWC => NCHW) and write it to the data pointer of the corresponding input tensor. For result extraction, since this model has three output feature maps of different scales, their result data pointers need to be obtained one by one.

// -------- Step 1. Initialize OpenVINO Runtime Core --------
ov::Core core;

// -------- Step 2. Read a model --------
std::shared_ptr<ov::Model> model = core.read_model(model_path);

// -------- Step 3. Loading a model to the device --------
ov::CompiledModel compiled_model = core.compile_model(model, device_name);

// Get input port for model with one input
auto input_port = compiled_model.input();
// Create tensor from external memory
// ov::Tensor input_tensor(input_port.get_element_type(), input_port.get_shape(), input_data.data());
// -------- Step 4. Create an infer request --------
ov::InferRequest infer_request = compiled_model.create_infer_request();

// -------- Step 5. Prepare input --------
ov::Tensor input_tensor1 = infer_request.get_input_tensor(0);
// NHWC => NCHW
auto data1 = input_tensor1.data<float>();
for (int h = 0; h < img_h; h++)
{
    for (int w = 0; w < img_w; w++)
    {
        for (int c = 0; c < 3; c++)
        {
            // int in_index = h * img_w * 3 + w * 3 + c;
            int out_index = c * img_h * img_w + h * img_w + w;
            data1[out_index] = float(img.at<cv::Vec3b>(h, w)[c]) / 255.0f;
        }
    }
}

// -------- Step 6. Start inference --------
infer_request.infer();

// -------- Step 7. Process output --------
auto output_tensor_p8 = infer_request.get_output_tensor(0);
const float *result_p8 = output_tensor_p8.data<const float>();
auto output_tensor_p16 = infer_request.get_output_tensor(1);
const float *result_p16 = output_tensor_p16.data<const float>();
auto output_tensor_p32 = infer_request.get_output_tensor(2);
const float *result_p32 = output_tensor_p32.data<const float>();

4. Post-processing Tasks

The post-processing part calls the generate_proposals function defined earlier to decode the results of each feature map and stack them. Finally, the NMSBoxes method in the OpenCV DNN module is used to perform non-maximum suppression on the bounding-box results and obtain the target locations and category information in the original input image.

generate_proposals(8, result_p8, prob_threshold, objects8);
proposals.insert(proposals.end(), objects8.begin(), objects8.end());
generate_proposals(16, result_p16, prob_threshold, objects16);
proposals.insert(proposals.end(), objects16.begin(), objects16.end());
generate_proposals(32, result_p32, prob_threshold, objects32);
proposals.insert(proposals.end(), objects32.begin(), objects32.end());

std::vector<int> classIds;
std::vector<float> confidences;
std::vector<cv::Rect> boxes;

for (size_t i = 0; i < proposals.size(); i++)
{
    classIds.push_back(proposals[i].label);
    confidences.push_back(proposals[i].prob);
    boxes.push_back(proposals[i].rect);
}

std::vector<int> picked;

// do non-maximum suppression on the bounding boxes
cv::dnn::NMSBoxes(boxes, confidences, prob_threshold, nms_threshold, picked);

Finally, the results, which are in the coordinate system of the letterboxed model input, need to be mapped back to the original image size for display:

// for each index kept by non-maximum suppression
for (size_t i = 0; i < picked.size(); i++)
{
    int idx = picked[i];
    cv::Rect box = boxes[idx];
    cv::Rect scaled_box = scale_box(box, padd);
    drawPred(classIds[idx], confidences[idx], scaled_box, padd[2], raw_h, raw_w, src_img, class_names);
}
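
drawPred is a helper from the sample repository and is not listed in this article. A minimal sketch matching the signature used above might look like the following; the division by padd[2] (the letterbox scaling ratio) and the clamping to the raw image size are assumptions based on the surrounding code:

#include <algorithm>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>

void drawPred(int classId, float conf, cv::Rect box, float ratio,
              int raw_h, int raw_w, cv::Mat &image,
              const std::vector<std::string> &classes)
{
    // undo the letterbox scaling to get back to original-image coordinates
    cv::Rect scaled(int(box.x / ratio), int(box.y / ratio),
                    int(box.width / ratio), int(box.height / ratio));
    scaled &= cv::Rect(0, 0, raw_w, raw_h);  // clamp to the image

    // draw the bounding box and a "label: confidence" caption
    cv::rectangle(image, scaled, cv::Scalar(0, 255, 0), 2);
    std::string label = classes[classId] + ": " + cv::format("%.2f", conf);
    cv::putText(image, label, cv::Point(scaled.x, std::max(scaled.y - 5, 0)),
                cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 255, 0), 1);
}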

The Reference Example

This example provides reference implementations in both C++ and Python. The usage is as follows:

1. Dependency Installation

# Download the sample repository
$ git clone https://github.com/OpenVINO-dev-contest/YOLOv7_OpenVINO_cpp-python.git
  • The C++ Environment Dependencies

Since the C++ version of this example relies only on the OpenVINO™ and OpenCV runtimes, the developer needs to install both tools in advance:

OpenVINO™ C++ runtime: https://docs.openvino.ai/latest/openvino_docs_install_guides_installing_openvino_linux.html#install-openvino
OpenCV: https://docs.opencv.org/4.x/d7/d9f/tutorial_linux_install.html

Note: Since the CMakeLists.txt provided in this example uses the default installation path of OpenCV, you need to run the "make install" command after compiling OpenCV.

  • The Python Environment Dependencies
$ pip install -r python/requirements

The installation of the Python environment is relatively simple. You just need to install the dependency through the pip command-line tool.

2. Download the Pre-trained Model

The YOLOv7 pre-trained model weights, trained on the COCO dataset, can be downloaded from the official GitHub repository: https://github.com/WongKinYiu/yolov7

Model        Test Size   AP (test)   AP50 (test)   AP75 (test)
YOLOv7       640         51.4%       69.7%         55.9%
YOLOv7-X     640         53.1%       71.2%         57.8%
YOLOv7-W6    1280        54.9%       72.6%         60.1%
YOLOv7-E6    1280        56.0%       73.5%         61.2%
YOLOv7-D6    1280        56.6%       74.0%         61.8%
YOLOv7-E6E   1280        56.8%       74.4%         62.1%

3. Model Transformation

At present, the OpenVINO™ Runtime can read ONNX models directly. After obtaining the .pt weight file, we can export it as an ONNX model using the official YOLOv7 export.py script. The specific process is as follows:

# Download the YOLOv7 official repository:
$ git clone git@github.com:WongKinYiu/yolov7.git
$ cd yolov7
$ python export.py --weights yolov7.pt

4. Test Run

  • The C++ Example

Compile the C++ sample source code; the yolov7 executable will be generated in the build directory:

$ cd cpp
$ mkdir build && cd build
$ source '~/intel/openvino_2022.1.0.643/bin/setupvars.sh'
$ cmake ..
$ make

Perform the inference task:

 $ yolov7 yolov7.onnx data/horses.jpg 'CPU' 
  • The Python Example

Perform the inference task:

 $ python python/main.py -m yolov7.onnx -i data/horse.jpg

5. Test Results

After running the inference example, an image with bounding boxes and labels will be generated in the local directory. Here we use the horse image included in the official repository for testing. The results are as follows:


Figure: The result of running inference

Benchmark App Introduction

OpenVINO™ provides the Benchmark App, a performance testing tool that allows developers to quickly test the performance of OpenVINO™ models on different hardware platforms. The following example briefly shows how to use the Benchmark App and its main parameters:

  $ benchmark_app -m yolov7.onnx -hint throughput

For more information, please refer to the official documentation of Benchmark App.

-m: Specifies the model path. Since the OpenVINO™ Runtime supports reading ONNX files directly, we pass the exported ONNX model path as the model input.

-hint: Specifies a high-level performance hint so that the underlying performance-related parameters are selected automatically. Here we choose throughput mode to maximize the overall throughput of the system; if the application is latency-sensitive, the latency mode is recommended to reduce inference latency.

Conclusion

Due to its excellent accuracy and performance, YOLOv7 received great attention as soon as it was launched, and the number of stars on GitHub has already exceeded 5K. Using the new C++ API of OpenVINO™ 2022.1, this example implements the deployment of the official YOLOv7 pre-trained model. Finally, we used the OpenVINO™ benchmark_app to further validate the performance of the model.

References

The official repository of YOLOv7:

https://github.com/WongKinYiu/yolov7

The development documentation of OpenVINO™:

https://docs.openvino.ai/latest/openvino_docs_OV_UG_Integrate_OV_with_your_application.html

Notices & Disclaimers

No product or component can be absolutely secure.

Intel does not control or audit third-party data. You should consult other sources to evaluate accuracy.
You may not use or facilitate the use of this document in connection with any infringement or other legal analysis concerning Intel products described herein. You agree to grant Intel a non-exclusive, royalty-free license to any patent claim thereafter drafted which includes subject matter disclosed herein.

No license (express or implied, by estoppel or otherwise) to any intellectual property rights is granted by this document.

The products described may contain design defects or errors known as errata which may cause the product to deviate from published specifications. Current characterized errata are available on request.
Intel technologies may require enabled hardware, software or service activation.

© Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.