
YOLOv5 Model INT8 Quantization based on OpenVINO™ 2022.1 POT API

Stephanie_Maluso

Author: Xiake Sun, AI Frameworks Engineer, OpenVINO™ Developer Tools, Intel

1. Overview

Ultralytics YOLOv5 is one of the most popular object detection networks among AI developers and is widely used in industrial applications thanks to its strong engineering and documentation support. In this article, we introduce how to use the OpenVINO™ 2022.1 Post-training Optimization Tool (POT) API to quantize the YOLOv5 model to INT8, achieving model compression and better inference performance. In addition, we show how to calculate the accuracy of the FP32 and INT8 models, introduce the OpenVINO Benchmark App for performance evaluation, and present a YOLOv5 INT8 object detection demo with the OpenVINO backend.

Please refer to the OpenVINO notebook 220-yolov5-accuracy-check-and-quantization for the complete source code.

2. Introduction to POT Tool

For pre-trained deep learning models, we can use the Post-training Optimization Tool (POT) in OpenVINO to map the model weights and activations from the FP32/FP16 value domain to the INT8 value domain. INT8 quantization reduces the computational resources and memory bandwidth required for inference, which helps improve the model's overall performance. Unlike the Quantization-aware Training (QAT) method, POT requires no retraining or even fine-tuning to obtain INT8 models with good accuracy. Therefore, POT is widely used as a best practice for quantization.
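To make the FP32-to-INT8 mapping concrete, here is a minimal NumPy sketch of per-tensor asymmetric (affine) quantization driven by min/max statistics, the kind of statistics POT collects on the calibration dataset. It assumes uint8 storage with a single scale and zero point per tensor and is only an illustration, not POT's internal implementation; the function names are hypothetical.

```python
import numpy as np


def quantize_minmax(x: np.ndarray, num_bits: int = 8):
    """Asymmetric per-tensor quantization from observed min/max (illustrative only)."""
    qmin, qmax = 0, 2**num_bits - 1
    # Extend the range to include 0 so that zero is exactly representable.
    x_min = min(float(x.min()), 0.0)
    x_max = max(float(x.max()), 0.0)
    # Real-valued step between two adjacent integer codes (guard against all-zero input).
    scale = (x_max - x_min) / (qmax - qmin) or 1.0
    zero_point = int(round(qmin - x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return q, scale, zero_point


def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map integer codes back to approximate FP32 values."""
    return (q.astype(np.float32) - zero_point) * scale


x = np.array([-0.3, 0.0, 0.7, 1.5, 2.4], dtype=np.float32)
q, scale, zp = quantize_minmax(x)
x_hat = dequantize(q, scale, zp)
# The reconstruction error per element is at most about half a quantization step.
print(q, scale, zp, np.abs(x - x_hat).max())
```

The error introduced is bounded by the step size, which is why a calibration dataset that reflects the real value ranges matters for the accuracy of the quantized model.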

Fig.1 shows the OpenVINO optimization workflow with POT, including the following elements:

  • A floating-point model, FP32 or FP16, converted into the OpenVINO Intermediate Representation (IR) format that can run on CPU with OpenVINO.
  • A representative calibration dataset reflecting the use case scenario, for example, 300 samples.
  • If there are accuracy constraints, a validation dataset and accuracy metrics should also be available.

Fig.1 OpenVINO optimization workflow with POT

2.1. POT Quantization Algorithms

POT provides the following two main quantization algorithms:

  1. Default Quantization (DQ) is a fast quantization method that produces a quantized model with good accuracy in most cases. DQ is suitable as a baseline for model INT8 quantization.
  2. Accuracy-aware Quantization (AAQ) is an iterative quantization algorithm built on Default Quantization. The model quantized by DQ serves as the baseline. If the baseline accuracy is not within the predefined accuracy range, AAQ reverts the layer with the greatest impact on accuracy from INT8 back to FP32 precision. It then re-evaluates the model accuracy and repeats the process until the model reaches the expected accuracy range.
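The AAQ fallback loop described above can be sketched in plain Python to show its control flow. This toy version assumes a hypothetical `evaluate(fp32_layers)` callback that returns the model accuracy when the given layers are kept in FP32, plus a precomputed per-layer sensitivity ranking; the real AAQ measures layer impact itself and has more options, so treat this only as an illustration.

```python
from typing import Callable, Dict, List, Set


def accuracy_aware_fallback(
    layer_sensitivity: Dict[str, float],
    evaluate: Callable[[Set[str]], float],
    fp32_accuracy: float,
    maximal_drop: float,
) -> List[str]:
    """Revert the most accuracy-sensitive INT8 layers to FP32 until the drop is acceptable.

    Returns the layers reverted to FP32, in the order they were reverted.
    """
    fp32_layers: Set[str] = set()
    reverted: List[str] = []
    # Start from the Default Quantization baseline: every layer is INT8.
    accuracy = evaluate(fp32_layers)
    # Candidates ordered from most to least harmful when quantized.
    candidates = sorted(layer_sensitivity, key=layer_sensitivity.get, reverse=True)
    for layer in candidates:
        if fp32_accuracy - accuracy <= maximal_drop:
            break  # accuracy is within the predefined range
        fp32_layers.add(layer)
        reverted.append(layer)
        accuracy = evaluate(fp32_layers)  # re-evaluate after each fallback
    return reverted


# Toy model: each INT8 layer costs a fixed, made-up amount of accuracy.
sensitivity = {"conv1": 0.004, "conv2": 0.012, "conv3": 0.001, "head": 0.009}


def toy_evaluate(fp32_layers: Set[str]) -> float:
    return 0.70 - sum(v for k, v in sensitivity.items() if k not in fp32_layers)


print(accuracy_aware_fallback(sensitivity, toy_evaluate, fp32_accuracy=0.70, maximal_drop=0.01))
```

With a 1% budget, the toy loop reverts the two most harmful layers and keeps the rest in INT8, which mirrors the trade-off AAQ makes between accuracy and performance.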

2.2. How to Run POT

POT can be run in two ways, through the command-line interface or through the API:

  1. Command-line interface: run POT from the command line with a configuration file to use the predefined DataLoader, Metric, Adapter, and pre/post-processing modules of the OpenVINO Accuracy Checker Tool. This method is intended for INT8 quantization of models supported by the OpenVINO Open Model Zoo, or similar models.
  2. API method: provides a full set of interfaces and helpers, including base classes such as DataLoader and Metric, that allow users to implement a custom optimization pipeline for various types of DL models. This approach is more flexible and can be applied to custom quantization pipelines for different models. Fig.2 shows the general workflow of INT8 quantization based on the POT API.

Fig.2 General workflow of INT8 quantization based on POT API

YOLOv5 uses custom pre- and post-processing modules, such as letterbox resizing and non-maximum suppression (NMS), which differ from the predefined pre- and post-processing modules of the OpenVINO Accuracy Checker Tool. Therefore, we implement a customized YOLOv5 INT8 quantization pipeline with custom DataLoader and Metric classes based on the POT API.


3. YOLOv5 INT8 Quantization Based on POT API

3.1. Setup YOLOv5 and OpenVINO Development Environment

First, download the YOLOv5 source code and install the YOLOv5 and OpenVINO Python dependencies.

git clone https://github.com/ultralytics/yolov5.git -b v6.1
cd yolov5 && pip install -r requirements.txt && \
pip install openvino==2022.1.0 openvino-dev==2022.1.0

Then, convert the pre-trained PyTorch model to the OpenVINO FP32 IR model via export.py provided by YOLOv5. Fig.3 shows the output results of the model conversion process.

python export.py --weights yolov5m/yolov5m.pt --imgsz 640 \
--batch-size 1 --include openvino

Fig.3 PyTorch model to OpenVINO FP32 IR model conversion output
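As a quick sanity check on the exported IR, the expected shape of the detection output can be worked out by hand. The sketch below assumes the default YOLOv5 detection head (strides 8, 16 and 32 with three anchors per grid cell) and the 80 COCO classes; each prediction row holds 4 box values, 1 objectness score and one score per class. The function name is hypothetical.

```python
def yolov5_output_shape(imgsz: int = 640, num_classes: int = 80,
                        strides=(8, 16, 32), anchors_per_cell: int = 3, batch: int = 1):
    """Expected shape of the concatenated YOLOv5 detection output."""
    # One prediction per anchor per grid cell at every detection scale.
    num_predictions = sum(anchors_per_cell * (imgsz // s) ** 2 for s in strides)
    # x, y, w, h, objectness, then one score per class.
    values_per_prediction = 5 + num_classes
    return (batch, num_predictions, values_per_prediction)


print(yolov5_output_shape())  # (1, 25200, 85)
```

For a 640x640 input this gives 3 x (80² + 40² + 20²) = 25200 predictions, which can be compared with the output shape of the converted model.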

We implement the customized quantization pipeline based on the POT API in the following steps:

  • Create the YOLOv5DataLoader class: define data and annotation loading and pre-processing
  • Create the COCOMetric class: define the model post-processing and accuracy calculation method
  • Set the quantization algorithm and related parameters, then define and run the quantization pipeline

3.2. Create YOLOv5DataLoader Class

First, we define a custom YOLOv5DataLoader class by inheriting from the POT DataLoader base class. In the excerpt below, the _init_dataloader(self) function calls YOLOv5's create_dataloader() function to read the dataset and resize the input images with letterbox. Each call to __getitem__(self, item) fetches the next batch of images and annotations from the YOLOv5 dataloader (restarting it when it is exhausted), normalizes the images, and returns the item index, the annotation, and the preprocessed image.

# Imports for this excerpt; assumes the yolov5 (v6.1) repository is on sys.path.
from openvino.tools.pot import DataLoader
from utils.datasets import create_dataloader


class YOLOv5DataLoader(DataLoader):
    """ Inherit from the POT DataLoader class and implement it for YOLOv5.
    """
    # __init__ (which sets the self._* attributes used below) and __len__
    # are omitted from this excerpt.

    def _init_dataloader(self):
        dataloader = create_dataloader(self._data_source['val'],
                                       imgsz=self._imgsz, batch_size=self._batch_size,
                                       stride=self._stride, single_cls=self._single_cls,
                                       pad=self._pad, rect=self._rect, workers=self._workers)[0]
        return dataloader

    def __getitem__(self, item):
        try:
            batch_data = next(self._data_iter)
        except StopIteration:
            # Restart the iterator once the dataset has been exhausted.
            self._data_iter = iter(self._data_loader)
            batch_data = next(self._data_iter)
        im, target, path, shape = batch_data
        # Convert uint8 pixels to float and normalize to [0, 1].
        im = im.float()
        im /= 255
        nb, _, height, width = im.shape
        img = im.cpu().detach().numpy()
        target = target.cpu().detach().numpy()
        # Pack everything COCOMetric.update() needs into the annotation.
        annotation = dict()
        annotation['image_path'] = path
        annotation['target'] = target
        annotation['batch_size'] = nb
        annotation['shape'] = shape
        annotation['width'] = width
        annotation['height'] = height
        annotation['img'] = img
        return (item, annotation), img
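The letterbox step performed inside YOLOv5's create_dataloader() can be illustrated with a simplified NumPy version: scale the image to fit the target size while preserving its aspect ratio, then pad the borders with a constant gray value. This sketch assumes a square target, nearest-neighbor resizing (YOLOv5 itself uses OpenCV interpolation, and with rect=True it pads only to a stride multiple) and centered padding with the value 114; the function name is hypothetical.

```python
import numpy as np


def letterbox_nn(img: np.ndarray, new_size: int = 640, pad_value: int = 114):
    """Resize an HxWxC uint8 image to fit new_size x new_size, keeping aspect ratio, then pad."""
    h, w = img.shape[:2]
    r = min(new_size / h, new_size / w)  # scale ratio; the shorter side gets padded
    new_h, new_w = int(round(h * r)), int(round(w * r))
    # Nearest-neighbor resize via index sampling (stand-in for cv2.resize).
    rows = np.minimum((np.arange(new_h) / r).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / r).astype(int), w - 1)
    resized = img[rows][:, cols]
    # Split the padding between the two sides.
    pad_h, pad_w = new_size - new_h, new_size - new_w
    top, left = pad_h // 2, pad_w // 2
    out = np.full((new_size, new_size, img.shape[2]), pad_value, dtype=img.dtype)
    out[top:top + new_h, left:left + new_w] = resized
    return out, r, (left, top)


img = np.zeros((480, 320, 3), dtype=np.uint8)  # dummy 480x320 (HxW) image
padded, ratio, (pad_left, pad_top) = letterbox_nn(img)
# Normalization as in __getitem__: scale uint8 pixels to [0, 1].
x = padded.astype(np.float32) / 255.0
print(padded.shape, ratio, pad_left, pad_top)
```

The returned ratio and padding are what later allow predicted boxes to be mapped back to the original image coordinates, which is the role of scale_coords() in the metric code.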

3.3. Create COCOMetric Class

Next, we create a COCOMetric class by inheriting from the POT Metric base class. It integrates YOLOv5's native NMS post-processing function and its mAP calculation method for the COCO dataset format. The update(self, output, target) function takes two inputs: output, the raw model inference result, and target, the annotation of the input image. The raw output is post-processed by YOLOv5 NMS and then compared with the annotation to obtain object detection accuracy statistics for the input image. Finally, _process_stats(self, stats) takes the accumulated statistics of all images as input and calculates the model accuracy AP@0.5 and AP@0.5:0.95.

# Imports for this excerpt; assumes the yolov5 (v6.1) repository is on sys.path.
import numpy as np
import torch
from openvino.tools.pot import Metric
from utils.general import non_max_suppression, scale_coords, xywh2xyxy
from utils.metrics import ap_per_class
from val import process_batch


class COCOMetric(Metric):
    """ Inherit from the POT Metric class and implement it for YOLOv5.
    """
    # __init__ and the remaining Metric members (e.g. avg_value, reset() and
    # get_attributes()) are omitted from this excerpt.

    def _process_stats(self, stats):
        mp, mr, map50, map = 0.0, 0.0, 0.0, 0.0
        stats = [np.concatenate(x, 0) for x in zip(*stats)]
        if len(stats) and stats[0].any():
            tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=False, save_dir=None, names=self._class_names)
            ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5 and AP@0.5:0.95 per class
            mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()

        return mp, mr, map50, map

    def update(self, output, target):
        """ Calculates and updates metric value
        Contains postprocessing part from Ultralytics YOLOv5 project
        :param output: model output
        :param target: annotations
        """

        annotation = target[0]["target"]
        width = target[0]["width"]
        height = target[0]["height"]
        shapes = target[0]["shape"]
        im = target[0]["img"]

        iouv = torch.linspace(0.5, 0.95, 10).to(self._device)  # iou vector for mAP@0.5:0.95
        niou = iouv.numel()
        stats = []
        # NMS
        annotation = torch.Tensor(annotation)
        annotation[:, 2:] *= torch.Tensor([width, height, width, height]).to(self._device)  # to pixels
        lb = []
        out = torch.Tensor(output[0]).to(self._device)
        out = non_max_suppression(out, self._conf_thres, self._iou_thres, labels=lb,
                                  multi_label=True, agnostic=self._single_cls)
        # Metrics
        for si, pred in enumerate(out):
            labels = annotation[annotation[:, 0] == si, 1:]
            nl = len(labels)
            tcls = labels[:, 0].tolist() if nl else []  # target class
            shape = shapes[si][0]

            if len(pred) == 0:
                if nl:
                    # No predictions: record the missed targets so they count against recall.
                    missed = (torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)
                    stats.append(missed)
                    self._stats.append(missed)
                continue

            # Predictions
            if self._single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

            # Evaluate
            if nl:
                tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
                scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                correct = process_batch(predn, labelsn, iouv)
            else:
                correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
            self._stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
        self._last_stats = stats
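The non_max_suppression() call in update() relies on YOLOv5's own implementation, which also handles confidence filtering and the multi-label and class-agnostic modes. Its core step, greedy IoU-based suppression, can be sketched in NumPy for a single class. This is a simplified illustration rather than YOLOv5's code, and the function names and default threshold are assumptions.

```python
import numpy as np


def box_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """IoU between one box and an array of boxes, all in (x1, y1, x2, y2) format."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def greedy_nms(boxes: np.ndarray, scores: np.ndarray, iou_thres: float = 0.45) -> list:
    """Return indices of kept boxes, highest score first."""
    order = np.argsort(-scores)  # process boxes from highest to lowest score
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        # Drop the remaining boxes that overlap the kept box too much.
        ious = box_iou(boxes[best], boxes[order[1:]])
        order = order[1:][ious <= iou_thres]
    return keep


boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(greedy_nms(boxes, scores))  # [0, 2]
```

The second box overlaps the first with an IoU of 81/119 (about 0.68), so it is suppressed, while the distant third box survives.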

3.4. Setting POT Quantization Parameters

After creating the YOLOv5DataLoader and COCOMetric classes, we use get_config() to set the parameters of the model, engine, dataset, metric, and algorithms in the POT quantization pipeline.
Below we excerpt the algorithms section of the config. We choose the "DefaultQuantization" algorithm, which quickly produces a quantized model with good performance. In addition, because YOLOv5 uses non-ReLU activations (SiLU), we set "preset": "mixed", which quantizes the model weights symmetrically and the activations asymmetrically to improve the accuracy of the quantized model. "stat_subset_size": 300 sets the number of samples used to collect calibration statistics.

def get_config():
    """ Set the configuration of the model, engine,
    dataset, metric and quantization algorithm.
    """
    config = dict()
    # The "model", "engine", "dataset" and "metric" sections are also set
    # here; they are omitted from this excerpt.
    algorithms = [
        {
            "name": "DefaultQuantization",  # or AccuracyAwareQuantization
            "params": {
                "target_device": "CPU",
                "preset": "mixed",
                "stat_subset_size": 300
            }
        }
    ]
    config["algorithms"] = algorithms

    return config
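To see why the "mixed" preset quantizes activations asymmetrically, consider SiLU, whose outputs are bounded below by about -0.28 but can be large and positive. The NumPy sketch below compares the rounding error of symmetric and asymmetric 8-bit fake quantization on such skewed values. It works under simplified assumptions (per-tensor scale, min/max ranges, synthetic data) and does not reproduce POT's exact quantization scheme; the function names are hypothetical.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))


def fake_quant_symmetric(x: np.ndarray, num_bits: int = 8) -> np.ndarray:
    """Signed grid centered on zero: integers in [-127, 127] times scale."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = np.abs(x).max() / qmax
    return np.clip(np.round(x / scale), -qmax, qmax) * scale


def fake_quant_asymmetric(x: np.ndarray, num_bits: int = 8) -> np.ndarray:
    """Unsigned grid [0, 255] shifted by a zero point to cover [min, max]."""
    qmax = 2 ** num_bits - 1
    lo, hi = min(x.min(), 0.0), max(x.max(), 0.0)
    scale = (hi - lo) / qmax
    zero_point = np.round(-lo / scale)
    q = np.clip(np.round(x / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale


rng = np.random.default_rng(0)
acts = silu(rng.normal(0.0, 2.0, size=10_000))  # skewed: about [-0.28, large positive)
err_sym = np.abs(acts - fake_quant_symmetric(acts)).mean()
err_asym = np.abs(acts - fake_quant_asymmetric(acts)).mean()
print(f"symmetric: {err_sym:.5f}, asymmetric: {err_asym:.5f}")
```

The symmetric grid wastes almost half of its codes on negative values that SiLU never produces, so its step is roughly twice as large. Weights, by contrast, are typically distributed around zero, which is why symmetric quantization is usually adequate for them.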

3.5. Define and Run the Quantization Pipeline

In this section, we demonstrate step by step how to define and run the quantization pipeline with the following code. It covers model loading; initialization of the DataLoader, Metric, Engine, and Pipeline required for quantization; FP32 model accuracy evaluation; INT8 model quantization; saving the quantized OpenVINO INT8 IR model locally; and INT8 model accuracy evaluation.

from pathlib import Path

from openvino.tools.pot import (IEEngine, compress_model_weights, create_pipeline,
                                load_model, save_model)
from openvino.tools.pot.utils.logger import get_logger, init_logger
from utils.general import increment_path  # YOLOv5 helper; assumes yolov5 is on sys.path

# Download the dataset and set the config.
config = get_config()
init_logger(level='INFO')
logger = get_logger(__name__)
save_dir = increment_path(Path("./yolov5/yolov5m/yolov5m_openvino_model/"), exist_ok=True)  # increment run
save_dir.mkdir(parents=True, exist_ok=True)  # make dir

# Step 1: Load the model.
model = load_model(config["model"])

# Step 2: Initialize the data loader.
data_loader = YOLOv5DataLoader(config["dataset"])

# Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric.
metric = COCOMetric(config["metric"])

# Step 4: Initialize the engine for metric calculation and statistics collection.
engine = IEEngine(config=config["engine"], data_loader=data_loader, metric=metric)

# Step 5: Create a pipeline of compression algorithms.
pipeline = create_pipeline(config["algorithms"], engine)

# Check the FP32 model accuracy.
metric_results_fp32 = pipeline.evaluate(model)

logger.info("FP32 model metric_results: {}".format(metric_results_fp32))

# Step 6: Execute the pipeline to collect statistics (e.g. min/max values) and quantize the model.
compressed_model = pipeline.run(model)

# Step 7 (Optional): Compress model weights to quantized precision
# in order to reduce the size of final .bin file.
compress_model_weights(compressed_model)

# Step 8: Save the compressed model to the desired path.
optimized_save_dir = Path(save_dir).joinpath("optimized")
save_model(compressed_model, Path(Path.cwd()).joinpath(optimized_save_dir), config["model"]["model_name"])

# Step 9 (Optional): Evaluate the compressed model. Print the results.
metric_results_i8 = pipeline.evaluate(compressed_model)

logger.info("Save quantized model in {}".format(optimized_save_dir))
logger.info("Quantized INT8 model metric_results: {}".format(metric_results_i8))
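Assuming, as in POT's samples, that pipeline.evaluate() returns a dictionary mapping metric names to values, the logged FP32 and INT8 results can also be compared programmatically. The helper below is hypothetical and not part of POT; it reports the absolute and relative drop for every metric present in both dictionaries.

```python
def accuracy_drop(fp32_results: dict, int8_results: dict) -> dict:
    """Per-metric absolute and relative (%) accuracy drop of INT8 vs. FP32 (hypothetical helper)."""
    report = {}
    for name, fp32_value in fp32_results.items():
        int8_value = int8_results[name]
        absolute = fp32_value - int8_value
        # Relative drop in percent of the FP32 value (undefined if FP32 is zero).
        relative = 100.0 * absolute / fp32_value if fp32_value else float("nan")
        report[name] = {"absolute": absolute, "relative_percent": relative}
    return report


# Usage with the results logged in the pipeline code:
# for name, drop in accuracy_drop(metric_results_fp32, metric_results_i8).items():
#     logger.info("%s drop: %.4f (%.2f%%)", name, drop["absolute"], drop["relative_percent"])
```

Reporting both forms avoids ambiguity when an accuracy target such as "within 1%" is stated, since an absolute drop of one point and a relative drop of 1% are different criteria.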

3.6. YOLOv5m FP32 and INT8 Model Accuracy Comparison

Fig.4 shows the accuracy results of the YOLOv5m FP32 and INT8 models. Compared with the FP32 model, the accuracy drop of the INT8 model quantized with the "DefaultQuantization" algorithm stays within 1%, which is generally acceptable for object detection applications. If higher INT8 accuracy is required, developers can try the "AccuracyAwareQuantization" algorithm, which improves the model accuracy iteratively.

Fig.4 YOLOv5m FP32 and INT8 Model Accuracy Comparison
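Switching to Accuracy-aware Quantization mainly means changing the algorithms section returned by get_config(), and AAQ also needs the COCOMetric instance passed to the engine, as the pipeline code already does. The fragment below is a sketch that assumes POT's AccuracyAwareQuantization accepts a "maximal_drop" parameter for the maximum tolerated accuracy drop (0.01 here); please verify the parameter names and defaults against the POT documentation for your version. The function name get_aaq_algorithms is hypothetical.

```python
def get_aaq_algorithms(maximal_drop: float = 0.01, stat_subset_size: int = 300) -> list:
    """Algorithms section for AccuracyAwareQuantization (sketch; verify names against POT docs)."""
    return [
        {
            "name": "AccuracyAwareQuantization",
            "params": {
                "target_device": "CPU",
                "preset": "mixed",
                "stat_subset_size": stat_subset_size,
                # Maximum accuracy drop AAQ tries to stay within before
                # reverting further layers to floating point.
                "maximal_drop": maximal_drop,
            },
        }
    ]


# Inside get_config(): config["algorithms"] = get_aaq_algorithms()
```

Because AAQ re-evaluates the model after each fallback, it takes considerably longer than Default Quantization, so it is usually worth trying only when the DQ result misses the accuracy target.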


3.7. OpenVINO Benchmark App Introduction

OpenVINO provides a performance evaluation tool, Benchmark App, that lets developers quickly evaluate the performance of OpenVINO models on different hardware platforms. Here we briefly introduce Benchmark App usage and the related parameters. For more parameters, please refer to the official Benchmark App documentation.

benchmark_app -m \
./yolov5/yolov5m/yolov5m_openvino_model/optimized/yolov5m.xml \
-i ./yolov5/data/images/bus.jpg -d CPU -hint throughput

OpenVINO Benchmark App is available in both Python and C++ versions. We use the Python version, which is installed with the openvino-dev Python package, with the following parameters for performance evaluation:

  • -m: Path to the OpenVINO model .xml file. Here we set it to the path of the quantized YOLOv5 INT8 model.
  • -i: Path to the input data file or folder for performance evaluation. Here we use the bus.jpg image as input. If -i is not given, benchmark_app automatically generates random data matching the model input shape.
  • -d: Target device for performance evaluation. Here we choose CPU as the inference device; developers may choose other target hardware supported by OpenVINO. If -d is not given, CPU is used by default.
  • -hint: Performance hint that sets the performance priority policy, so that proper parameters for performance optimization are selected automatically. Here we choose throughput mode to improve overall system throughput. If the application is latency-sensitive, latency mode is recommended to reduce inference latency.

Developers can choose parameters appropriate to their hardware platform and use case when evaluating the performance of the YOLOv5 FP32 and INT8 models.

3.8. YOLOv5 INT8 Model Inference Demo

Finally, we run the YOLOv5m INT8 model inference demo with the OpenVINO backend using the following command:

cd yolov5 && python detect.py \
--weights ./yolov5m/yolov5m_openvino_model/optimized/yolov5m.xml

Fig.5 shows the input image and the inference results of the INT8 model. The INT8 model detects all the cars and pedestrians in the image and labels their bounding boxes with high confidence.

Fig.5 Input image (left) and object detection result (right) of the YOLOv5 INT8 model inference demo with OpenVINO backend
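detect.py handles post-processing internally. For readers who want to consume the raw INT8 model output directly, the first decoding steps can be sketched with NumPy: compute each class confidence as objectness times class score, keep the best class per row, drop low-confidence rows, convert boxes from center (x, y, w, h) to corner (x1, y1, x2, y2) format, and undo the letterbox scaling and (left, top) padding. This assumes the exported YOLOv5 output layout, where each row is already decoded to input-image pixels with sigmoid applied; the function name and inputs are hypothetical, and the surviving boxes would still need NMS.

```python
import numpy as np


def decode_yolov5(pred: np.ndarray, conf_thres: float = 0.25, ratio: float = 1.0,
                  pad: tuple = (0.0, 0.0)):
    """Decode one image's predictions of shape (N, 5 + num_classes).

    Returns (boxes_xyxy in original-image pixels, scores, class_ids).
    """
    obj = pred[:, 4]
    cls_scores = pred[:, 5:] * obj[:, None]  # confidence = objectness * class score
    class_ids = cls_scores.argmax(axis=1)  # best class only (no multi-label)
    scores = cls_scores.max(axis=1)
    mask = scores > conf_thres
    xywh = pred[mask, :4]
    boxes = np.empty_like(xywh)
    boxes[:, 0] = xywh[:, 0] - xywh[:, 2] / 2  # x1
    boxes[:, 1] = xywh[:, 1] - xywh[:, 3] / 2  # y1
    boxes[:, 2] = xywh[:, 0] + xywh[:, 2] / 2  # x2
    boxes[:, 3] = xywh[:, 1] + xywh[:, 3] / 2  # y2
    # Undo letterbox: remove the padding offset, then rescale to the original image.
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad[0]) / ratio
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad[1]) / ratio
    return boxes, scores[mask], class_ids[mask]


# Two fake predictions with 3 classes: only the first passes the threshold.
pred = np.array([
    [100.0, 120.0, 40.0, 60.0, 0.9, 0.1, 0.8, 0.1],
    [300.0, 300.0, 10.0, 10.0, 0.2, 0.5, 0.3, 0.2],
], dtype=np.float32)
boxes, scores, classes = decode_yolov5(pred, ratio=2.0, pad=(0.0, 20.0))
print(boxes, scores, classes)
```

For the first row, the best class score is 0.9 x 0.8 = 0.72 for class 1, and its box maps back to (40, 35, 60, 65) in original-image pixels; the second row's best score of 0.1 falls below the threshold.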

4. Summary

In this article, we first converted the pre-trained YOLOv5m PyTorch model into an OpenVINO FP32 IR model. Next, we defined custom DataLoader and Metric classes based on the OpenVINO POT API to reuse YOLOv5's own pre- and post-processing (letterbox, non-maximum suppression) and accuracy calculation modules. Then, we ran the quantization pipeline with the "DefaultQuantization" algorithm to obtain an INT8 model. The INT8 model retains accuracy close to that of the FP32 model, with AP@0.5 and AP@0.5:0.95 drops within 1%. Furthermore, we introduced the OpenVINO benchmark evaluation tool, Benchmark App. Finally, we demonstrated YOLOv5m INT8 model inference with the OpenVINO backend.


Notices & Disclaimers

No product or component can be absolutely secure.

Intel does not control or audit third-party data. You should consult other sources to evaluate accuracy.
You may not use or facilitate the use of this document in connection with any infringement or other legal analysis concerning Intel products described herein. You agree to grant Intel a non-exclusive, royalty-free license to any patent claim thereafter drafted which includes subject matter disclosed herein.

No license (express or implied, by estoppel or otherwise) to any intellectual property rights is granted by this document.

The products described may contain design defects or errors known as errata which may cause the product to deviate from published specifications. Current characterized errata are available on request.
Intel technologies may require enabled hardware, software, or service activation.

© Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.