YOLOV7 Inference Acceleration With Structural Pruning

Once the neural model has been trained, the next step is to prepare it for use: inference is the stage where the model is expected to shine and demonstrate its efficiency and effectiveness.
Frequently, in addition to quality, speed is an important parameter: in many products, the inability to perform inference economically is a no-go for commercial use. Moreover, the model runtime environment may limit the amount of available RAM, so memory consumption may also need to be reduced.
An efficient model is optimized for memory usage and processing performance without sacrificing quality.
Methods to Improve Neural Model Efficiency
The main ways to increase neural model efficiency are the following:
- Optimized libraries or frameworks (usually from hardware vendors) that speed up operations and computation graphs. Frameworks like OpenVINO (Intel), TensorRT (Nvidia), and Vitis AI (Xilinx/AMD) represent this approach.
- Knowledge distillation: transferring knowledge from a large, complex model to a smaller one with a simpler architecture and fewer parameters. Using this approach, we train a “small” model using a “large” model.
- Quantization: the idea is to convert model weights into small integer types (int8, int4). Modern CPUs, GPUs, and specialized accelerators can process integer values faster, so performance increases and less memory is used.
- Pruning: a method of removing some of the parameters and connections of a neural network. This approach removes the least significant weights and connections from the trained model. The removal can be done either by individual connections or by entire neurons.
The first method is widely used by many developers. The second requires developing effective student and teacher models compatible with each other, which is a complex problem.
Quantization and pruning are relatively well-developed areas for which ready-made tools exist. Often, both approaches are used simultaneously, but we discuss only pruning for this article.
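As a quick illustration of how accessible the quantization tooling has become (we do not use quantization further in this article), here is a minimal PyTorch dynamic-quantization sketch that stores the weights of fully connected layers as int8; the model is a toy example:

import torch
import torch.nn as nn

# Toy model; dynamic quantization converts the weights of Linear layers to int8,
# while activations are quantized on the fly during inference.
model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10))
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(quantized)  # Linear layers are replaced with their dynamically quantized versions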
Pruning Approaches
The idea of pruning appeared as early as 1989, when Yann LeCun proposed it in his work "Optimal Brain Damage." The essence of this idea is that neural networks include redundant parameters that do not significantly affect the result.
There are two widely known approaches to pruning: unstructured and structured.
Unstructured Pruning
This approach is based on zeroing out the parameters with the least influence on the result. It has advantages: first, it is simple to implement because it is sufficient to set the weight of the selected parameter to zero.
For example, the widely used deep learning framework PyTorch provides the corresponding functions (see the sketch after this paragraph).
Another advantage of this approach is precision: only parameters that are genuinely close to zero are zeroed out, which has minimal impact on the inference quality of the model. And since such pruning affects individual connections that are not known in advance, the approach is called unstructured pruning.
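Below is a minimal sketch of unstructured pruning with PyTorch's built-in torch.nn.utils.prune module; the layer is a toy example rather than part of any model discussed here:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# A toy convolutional layer; in practice this would be a layer of a trained model.
conv = nn.Conv2d(16, 32, kernel_size=3)

# Zero out the 30% of weights with the smallest absolute values.
prune.l1_unstructured(conv, name="weight", amount=0.3)

# The tensor keeps its shape: pruned weights are simply zeros, so the number of
# stored values and arithmetic operations does not change.
print(conv.weight.shape)                         # torch.Size([32, 16, 3, 3])
print((conv.weight == 0).float().mean().item())  # ~0.3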
However, this method has a drawback: although the weights are zeroed out, the number of stored weights and arithmetic operations remains the same.
Such pruning results in sparse matrices, but most frameworks and hardware cannot efficiently accelerate calculations on sparse matrices. Thus, no matter how many parameters in a tensor are zeroed, the computation speed does not change.
Nevertheless, hardware vendors are improving computations on sparse matrices: Nvidia's Ampere GPU family and TensorRT 8.9 can take advantage of such a representation.
However, Nvidia tools require a specific sparsity pattern (NVIDIA Sparsity), which introduces particular complexity to the pruning process (although there are code repositories from Nvidia to simplify obtaining such a model) and reduces the benefits of using this approach. Also, according to the above-mentioned article, the performance gain is insignificant.
Structured Pruning
Due to the significant disadvantages of unstructured pruning, many studies are devoted to pruning larger structures, such as whole neurons or their direct counterpart in modern deep convolutional neural networks: convolution filters.
Filter removal operates on relatively small elements of the network architecture, since modern models usually include many convolutional layers, each with hundreds or thousands of filters.
Obviously, networks with fewer convolution filters take up less RAM due to fewer parameters, require less computation, and generate simpler intermediate representations. For these reasons, filter removal is now seen as a priority approach for structured pruning.
The disadvantage of this approach is that one cannot simply remove a filter: since modern neural architectures have a very complex structure, removing a filter in one layer entails changes in the structure of the preceding and following layers. When implemented incorrectly, this can lead to the removal of significant connections or render the entire model dysfunctional.
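As a toy illustration of this coupling (not the procedure used later in this article), consider two consecutive convolutional layers: dropping output filters of the first forces the matching input channels of the second to be dropped as well:

import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 64, kernel_size=3)
conv2 = nn.Conv2d(64, 128, kernel_size=3)

drop = {5, 17}  # filters chosen for removal
keep = torch.tensor([i for i in range(64) if i not in drop])

pruned1 = nn.Conv2d(3, len(keep), kernel_size=3)
pruned1.weight.data = conv1.weight.data[keep]     # keep the remaining output filters
pruned1.bias.data = conv1.bias.data[keep]

pruned2 = nn.Conv2d(len(keep), 128, kernel_size=3)
pruned2.weight.data = conv2.weight.data[:, keep]  # drop the matching input channels
pruned2.bias.data = conv2.bias.data

x = torch.randn(1, 3, 32, 32)
print(pruned2(pruned1(x)).shape)                  # torch.Size([1, 128, 28, 28])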
An important question in structured pruning is how to choose which filters to remove: we no longer evaluate a single parameter, as in unstructured pruning, but a whole group of parameters.
The following image demonstrates the differences between the two approaches using a single convolutional layer as an example.

Structured Pruning Using the YOLOV7 OBB Model as an Example
The repository with the code for the pruning procedure is located on GitHub at: https://github.com/insight-platform/Yolo_V7_OBB_Pruning. If you want to reproduce all the steps described below, first perform all the steps in the Getting Started section.
The YOLO family is widely used in practice because of its accuracy and performance. For pruning, we chose the YOLOV7 model tailored for detecting rotated boxes.
There are YOLO modifications for detecting objects with oriented boxes, which is often necessary in practical tasks where an object must be selected with a minimum of background. As an example, we chose the task of person detection on ultra-wide-angle cameras placed on the ceiling (fisheye 360).
A combined dataset was compiled to train the model. It includes the CEPDOF, HABBOF, MW-R, and WEPDTOF datasets, as well as images from the COCO dataset that contain four or more persons.
To explore and prepare this dataset and to train the YOLOV7 model yourself, you can open JupyterLab and use the notebook "prepare_datasets.ipynb" located in the directory "notebook." For experiments and training, you can create a separate directory for each experiment in the "experiments" directory. In the repository, you will find all experiments that have been performed.
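For illustration, the COCO filtering step could look like the following sketch based on pycocotools (the annotation path is hypothetical; the actual preparation is done by the notebook mentioned above):

from pycocotools.coco import COCO

# Keep only images that contain four or more person annotations.
coco = COCO("annotations/instances_train2017.json")
person = coco.getCatIds(catNms=["person"])

selected = [
    img_id
    for img_id in coco.getImgIds(catIds=person)
    if len(coco.getAnnIds(imgIds=img_id, catIds=person, iscrowd=False)) >= 4
]
print(f"{len(selected)} images with four or more persons")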
If you want to retrain a model or experiment with training the original model, you need to create (or copy) an experiment directory with the following files:
- data.yaml: specifies the paths to the datasets, the number of classes, and the class labels;
- hyp.yaml: the training hyperparameters;
- run.sh: the script that starts training;
- yolov7.yaml: the model architecture.
To train, connect to the Docker container (follow all the steps in the Getting Started section before running further commands) and start the process for the selected experiment (replace "fisheye_person_v1.0.0" with your experiment directory):
docker exec -it yolo_v7_obb_pruning_yolov7obb_1 /bin/bash
bash ./experiments/fisheye_person_v1.0.0/run.sh
Training results and weights files are in the “runs/train” directory. A separate directory is created for each training run to store the results.
You can track and compare learning quality metrics by opening Tensorboard (http://127.0.0.1:6006) in a browser.
The mAP (mean average precision) metric was used to evaluate the models. The initial YOLOV7 model has mAP@0.5 equal to 0.7898 and mAP@0.5:0.95 equal to 0.4119.
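For reference, average precision (AP) is the area under the precision-recall curve of one class, and mAP averages it over the classes; mAP@0.5 matches detections to ground truth at an IoU threshold of 0.5, while mAP@0.5:0.95 additionally averages the metric over IoU thresholds from 0.5 to 0.95 in steps of 0.05:

$$\mathrm{AP}_c = \int_0^1 p_c(r)\,dr, \qquad \mathrm{mAP} = \frac{1}{N}\sum_{c=1}^{N} \mathrm{AP}_c$$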
Pruning Process
We used the torch-pruning library (https://github.com/VainF/Torch-Pruning) for the task. It supports various pruning algorithms, can build graphs of dependencies, and has functions to estimate the significance of convolution filters.
At the time of the experiments, the most promising method implemented in the library was an approach that uses Batch Normalization (BN) layers to estimate which channels have a minor impact on the result (new approaches have been added in recent library updates).
Many neural networks use BN layers with scaling and channel shift parameters after convolutional layers. Therefore, it is possible to directly use the scaling parameters in BN layers to estimate the importance of specific channels.

Figure 2, on the left, shows the original network. After training, when the scaling multiplier for a channel is close to 0, the channel has minimal effect on all subsequent layers and, consequently, on the final result.
This allows for removing such channels and thus reduces the number of parameters in the network. The pruned network is shown in Figure 2 on the right. You can read more about the approach in the article https://arxiv.org/pdf/1708.06519.pdf.
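The idea can be sketched in a few lines (illustrative only; in a freshly initialized layer all scaling factors are equal, whereas in a trained, sparsity-regularized model they differ):

import torch
import torch.nn as nn

# Rank the channels of a single BN layer by the magnitude of their scaling factor.
bn = nn.BatchNorm2d(64)
channel_importance = bn.weight.detach().abs()             # one gamma per channel
least_important = torch.argsort(channel_importance)[:8]   # candidates for removal
print(least_important)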
You must train the original model with an additional regularization term on the BN layers to use this approach. It is implemented in the training script by specifying the --sparsity parameter.
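A minimal sketch of how such BN regularization is commonly applied (the "network slimming" style L1 subgradient on BN scaling factors; the function name is hypothetical, and the repository's --sparsity option may implement it differently):

import torch
import torch.nn as nn

def apply_bn_sparsity(model: nn.Module, sparsity: float = 1e-4) -> None:
    # Add an L1 subgradient on BN scaling factors; call after loss.backward().
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d) and m.weight.grad is not None:
            m.weight.grad.add_(sparsity * torch.sign(m.weight.detach()))

# Usage inside the training loop:
#   loss.backward()
#   apply_bn_sparsity(model, sparsity=1e-4)
#   optimizer.step()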
After the model is trained, we apply the pruning algorithm described above. First, we exclude the last layers of the model by adding them to a list that is later passed to the pruner; this keeps the dimensionality of the model output unchanged:
exclude = (RepConv, IDetect)        # layer types whose filters must not be pruned
exclude_idx = [102, 103, 104, 105]  # indices of the final layers that must stay intact
ignored_layers = []
for i, layer in enumerate(model.model):
    if isinstance(layer, exclude) or i in exclude_idx:
        ignored_layers.append(layer)
Next, we specify a method for filter importance estimation. We use the approach based on the scaling multiplier of the BN layer:
imp = tp.importance.BNScaleImportance()
After that, we initialize a class that implements the pruning algorithm by passing the model, the test input data vector, the importance evaluation function, and other parameters:
pruner = tp.pruner.BNScalePruner(
    model=model,
    example_inputs=example_inputs,              # dummy input used to trace layer dependencies
    importance=imp,                             # BN-scale-based importance estimation
    iterative_steps=400,                        # prune gradually in many small steps
    ch_sparsity=1.0,
    ch_sparsity_dict=ch_sparsity_dict,
    max_ch_sparsity=1,
    ignored_layers=ignored_layers,              # the layers excluded above
    unwrapped_parameters=unwrapped_parameters,
    global_pruning=False,
    model_ref=initial_model,
    round_to=round_to                           # keep channel counts a multiple of this value
)
The last step is to iteratively remove filters until the desired reduction of the number of operations is reached:
base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
current_param_reduce = 1
iter = 0
while current_param_reduce < operation_reduce:
    iter += 1
    pruner.step()
    pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    current_param_reduce = float(base_ops) / pruned_ops
    print(f"iter {iter}: base operations = {base_ops}, operation after pruning={pruned_ops}, decreased operations in={current_param_reduce}")
The function “count_ops_and_params” allows estimating the theoretically required number of arithmetic operations.
You can find the full script in “yolov7obb/pruning.py” (GitHub).
The script takes two parameters:
- -o: how many times we want to reduce the number of operations;
- -r: how many channels to remove at once in the convolutional layers; we set this parameter to 32 so that the layer widths remain multiples of 32.
In addition, the model weights obtained during training (--weights) and the weights from which the training was started (--weights-ref) must be provided:
python ./yolov7obb/pruning.py \
--weights ./runs/train/fisheye_person_v1.0.0/weights/best_145.pt \
--weights-ref ./runs/train/fisheye_person_v1.0.0/weights/init.pt \
-o 1.2 -r 32
Results of Retraining the Pruned Models
After pruning, it is necessary to fine-tune the model or train it again from scratch. We conducted three experiments with different start weights for fine-tuning and training.
In the first case, we took the weights of the initial model from the epoch with the best metrics as the starting weights and performed fine-tuning.
In the second case, the initial model weights from epoch 88 were taken as starting weights, and fine-tuning was carried out. This epoch relates to the middle of the training cycle.
And in the third case, we trained the pruned model from scratch, starting with the same weights as the original model.
The figures below show the plots of the detection quality for mAP@0.5 and mAP@0.5:0.95 metrics on the validation dataset.
Our experiments demonstrated that the best results are achieved when the pruned model starts training anew from the initial weights.
- the dark-gray graph represents the initial model (the best results on epoch 145);
- the yellow graph demonstrates the training of the pruned model from the initial weights;
- the purple plot — the pruned model fine-tuned from epoch 88;
- the green graph — the pruned model fine-tuned from the best weights of the initial model.


Given the above, we conducted all further experiments with the pruned model trained from the initial weights.
Performance Calculation Procedure
We convert the models into ONNX format to evaluate their performance with TensorRT. To export to ONNX, use the following command inside the docker container:
python ./yolov7obb/export.py \
--weights /opt/app/runs/train/fisheye_person_v1.0.0/weights/best_145.pt \
--img-size 640 640 \
--batch-size 1 \
--onnx \
--grid \
--end2end \
--simplify \
--fp16
The parameters of the export script have the following meaning:
- --img-size defines the input image dimensions for the model; in the case of YOLOV7OBB, the inference image size does not have to match the training image size, but the best quality is achieved when they coincide;
- --batch-size sets the model batch size;
- --onnx sets the export format;
- --grid enables combining results from different scales;
- --end2end includes in the graph all computations necessary to obtain the final result;
- --simplify enables the ONNX-simplifier library to simplify the computation graph;
- --fp16 exports the ONNX model with half precision.
After converting the model to ONNX format, we can generate a TensorRT-optimized model and simultaneously measure its inference speed. To do this, from the root directory of the project, run the container with the TensorRT environment:
docker run -it --rm --gpus device=1 \
-v `pwd`/runs:/models nvcr.io/nvidia/tensorrt:22.12-py3
To generate an optimized model and measure its speed on the example of a fully trained YOLOV7OBB model, use the following command:
trtexec --onnx=/models/train/fisheye_person_v1.0.0/weights/best_145.onnx \
--saveEngine=ref_model_best_145.engine \
--fp16 \
--iterations=50 \
--duration=20 \
--avgRuns=20 \
--warmUp=10
After executing the command, find the QPS (queries per second) value in the output: the number of queries the model can handle in one second. If the model processes video, one query corresponds to one frame, so QPS equals FPS.
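If you redirect the trtexec output to a file, the figure can be pulled out with a small helper like the one below (hypothetical; the exact wording of the summary line may differ between TensorRT versions):

import re

def parse_qps(log_text: str) -> float | None:
    # Look for the throughput summary line printed by trtexec, e.g. "Throughput: 123.4 qps".
    match = re.search(r"Throughput:\s*([\d.]+)\s*qps", log_text)
    return float(match.group(1)) if match else None

with open("trtexec.log") as f:  # e.g. trtexec ... > trtexec.log 2>&1
    print(parse_qps(f.read()))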
Performance Numbers
We pruned the model with different values of the operation-reduce (-o) parameter, which defines the target reduction in the number of operations during model inference. The resulting models were trained from the initial weights from which the original model was trained.
The training parameters (learning rate, batch size, etc.) were identical in all experiments. Since the pruning process is iterative and removing filters from the model is performed in groups of 32 filters, the actual reduction in the number of operations may slightly differ from the target value.
In the figures below, the X-axis shows the theoretically calculated ratio of the number of operations in the original model to that in the pruned model; the bars labeled "1" correspond to the original model.
Let us first estimate the decline in detection quality after pruning. Figures 5 and 6 show the mAP@0.5 and mAP@0.5:0.95 metrics, respectively. Blue shows the metric value in absolute units, and red shows the difference between the original and pruned models in percent, i.e., how much the metric values have decreased.


Figure 7 shows the performance measurements. Blue is absolute FPS values; red is the increase in the number of FPS relative to the original model.

Figure 8 compares the changes in metrics and model performance. Blue and red show the decrease in the metrics: the larger the value, the worse the detections. Performance (yellow) is the opposite: the higher, the better.

Analyzing all the results, we can conclude that the best trade-off is achieved with a 1.61x reduction in operations: quality metrics decreased by 5 percent, while inference speed increased by 22 percent relative to the original model.
In the videos below, you can visually evaluate the difference in detection quality between the models. Each scene shows three videos: the first presents the results of the original model, the second the model with a 1.61x reduction, and the third the model with a 2.27x reduction.
Conclusions
The demonstrated pruning procedure accelerates the model without critically sacrificing quality. Our experiments showed that, depending on the problem and the target quality, an optimal balance between quality and acceleration can be chosen.
The pruned model is best trained from the initial weights of the original model. Ready-made libraries like Torch-Pruning, which we used, provide handy tools to run pruning without extra hassle.
Authorship Notes
The experiments were planned and executed by our colleague Bogoslovskii. The editorial work was done by me.
If you like our materials, check out our new Pythonic computer vision and video analytics framework, Savant, where we integrate our innovations and best practices related to CV.