
Ultralytics YOLOv5 Architecture

YOLOv5 (v6.0/6.1) is an object detection model developed by Ultralytics. This article covers the YOLOv5 architecture, its data augmentation strategies, training methodologies, and loss computation. Understanding these details can help you apply object detection more effectively in fields such as surveillance, autonomous vehicles, and image recognition.

1. Model Structure

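YOLOv5's architecture consists of three main parts:

  1. Backbone: the main body of the network. YOLOv5 uses a New CSP-Darknet53 structure, a modification of the Darknet architecture used in earlier versions.
  2. Neck: connects the backbone and the head. YOLOv5 uses SPPF and New CSP-PAN structures.
  3. Head: generates the final output. YOLOv5 uses the YOLOv3 head.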

The structure of the model is depicted in the image below. The model structure details can be found in yolov5l.yaml.

[Figure: YOLOv5 model structure]

YOLOv5 introduces some notable improvements compared to its predecessors:

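  1. The Focus structure found in earlier versions is replaced with an equivalent 6x6 Conv2d layer (stride 2, padding 2). This change improves efficiency (#4825).
  2. The SPP structure is replaced with SPPF, which chains three 5x5 max-pooling layers instead of running 5x5, 9x9, and 13x13 max-pooling layers in parallel. Two stacked 5x5 stride-1 poolings cover the same window as one 9x9 pooling, and three cover a 13x13 window, so the output is identical while processing speed more than doubles.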
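The claim that a 6x6 Conv2d can stand in for Focus can be checked numerically. The sketch below builds a 6x6, stride-2, padding-2 convolution whose weights are rearranged from a Focus-style layer (space-to-depth slicing followed by a 3x3 convolution with padding 1) and compares the two outputs. It assumes a bias-free 3x3 convolution and ignores the BatchNorm and activation that YOLOv5 wraps around its convolutions; the helper names `space_to_depth`, `focus_to_conv6`, and `max_abs_difference` are hypothetical.

```python
import torch
import torch.nn as nn


def space_to_depth(x: torch.Tensor) -> torch.Tensor:
    """Focus-style slicing: (B, C, H, W) -> (B, 4C, H/2, W/2)."""
    return torch.cat(
        [x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], dim=1
    )


def focus_to_conv6(focus_conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Build a 6x6/stride-2/padding-2 conv equal to space_to_depth followed by focus_conv.

    Assumes focus_conv is a bias-free 3x3 convolution with stride 1 and padding 1.
    """
    conv6 = nn.Conv2d(in_channels, focus_conv.out_channels, kernel_size=6, stride=2, padding=2, bias=False)
    wf = focus_conv.weight.detach()  # (out, 4 * in_channels, 3, 3)
    w6 = torch.zeros_like(conv6.weight)  # (out, in_channels, 6, 6)
    # Slice g samples input pixel (2p + dy, 2q + dx), so 3x3 tap (a, b) of slice g
    # lands on tap (2a + dy, 2b + dx) of the 6x6 kernel.
    for g, (dy, dx) in enumerate([(0, 0), (1, 0), (0, 1), (1, 1)]):  # same order as space_to_depth
        w6[:, :, dy::2, dx::2] = wf[:, g * in_channels : (g + 1) * in_channels]
    with torch.no_grad():
        conv6.weight.copy_(w6)
    return conv6


def max_abs_difference(seed: int = 0) -> float:
    """Run both paths on a random 64x64 image and return the largest elementwise difference."""
    torch.manual_seed(seed)
    x = torch.rand(1, 3, 64, 64)
    focus_conv = nn.Conv2d(12, 16, kernel_size=3, stride=1, padding=1, bias=False)
    conv6 = focus_to_conv6(focus_conv, in_channels=3)
    with torch.no_grad():
        y_focus = focus_conv(space_to_depth(x))
        y_conv6 = conv6(x)
    assert y_focus.shape == y_conv6.shape == (1, 16, 32, 32)
    return (y_focus - y_conv6).abs().max().item()


if __name__ == "__main__":
    print(f"max |Focus - Conv6x6| = {max_abs_difference():.2e}")
```

The difference should be at the level of float32 rounding error. Every 6x6 kernel entry corresponds to exactly one 3x3 tap of one slice, so the two layers describe the same family of functions; YOLOv5 simply trains the 6x6 kernel directly, and this conversion only demonstrates the equivalence.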

To test the speed of SPP and SPPF, the following code can be used:


```python
import time

import torch
import torch.nn as nn


class SPP(nn.Module):
    def __init__(self):
        """Initializes an SPP module with three different sizes of max pooling layers."""
        super().__init__()
        self.maxpool1 = nn.MaxPool2d(5, 1, padding=2)
        self.maxpool2 = nn.MaxPool2d(9, 1, padding=4)
        self.maxpool3 = nn.MaxPool2d(13, 1, padding=6)

    def forward(self, x):
        """Applies three max pooling layers on input `x` and concatenates results along channel dimension."""
        o1 = self.maxpool1(x)
        o2 = self.maxpool2(x)
        o3 = self.maxpool3(x)
        return torch.cat([x, o1, o2, o3], dim=1)


class SPPF(nn.Module):
    def __init__(self):
        """Initializes an SPPF module with a specific configuration of MaxPool2d layer."""
        super().__init__()
        self.maxpool = nn.MaxPool2d(5, 1, padding=2)

    def forward(self, x):
        """Applies sequential max pooling and concatenates results with input tensor."""
        o1 = self.maxpool(x)  # 5x5 window
        o2 = self.maxpool(o1)  # two stacked 5x5 pools == one 9x9 pool
        o3 = self.maxpool(o2)  # three stacked 5x5 pools == one 13x13 pool
        return torch.cat([x, o1, o2, o3], dim=1)


def main():
    """Compares outputs and performance of SPP and SPPF on a random tensor (8, 32, 16, 16)."""
    input_tensor = torch.rand(8, 32, 16, 16)
    spp = SPP()
    sppf = SPPF()
    output1 = spp(input_tensor)
    output2 = sppf(input_tensor)

    print(torch.equal(output1, output2))

    t_start = time.time()
    for _ in range(100):
        spp(input_tensor)
    print(f"SPP time: {time.time() - t_start}")

    t_start = time.time()
    for _ in range(100):
        sppf(input_tensor)
    print(f"SPPF time: {time.time() - t_start}")


if __name__ == "__main__":
    main()
```

Result (timings depend on hardware):

True
SPP time: 0.5373051166534424
SPPF time: 0.20780706405639648

2. Data Augmentation Techniques

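YOLOv5 employs various data augmentation techniques to improve the model's ability to generalize and to reduce overfitting. These techniques include:

  1. Mosaic Augmentation: combines four training images into one, so the model sees objects at varied scales and in varied contexts.
  2. Copy-Paste Augmentation: pastes objects from one image into another.
  3. Random Affine Transformations: applies random rotation, scaling, translation, and shearing.
  4. MixUp Augmentation: blends two images and their labels into a composite sample.
  5. Albumentations: applies additional transforms from the Albumentations image augmentation library.
  6. HSV Augmentation: randomly changes the hue, saturation, and value of images.
  7. Random Horizontal Flip: flips images horizontally at random.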
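Mosaic augmentation, one of the augmentations YOLOv5 uses, tiles four training images around a random center on a 2s x 2s canvas, where s is the training image size. The pure-Python sketch below computes only the placement geometry and the matching label offsets. It is modeled, as an assumption, on the layout logic of YOLOv5's dataloader; it omits the pixel copying and the random affine step that YOLOv5 applies afterwards to crop the canvas back to s x s. `mosaic_layout` and `shift_boxes` are hypothetical names.

```python
import random


def mosaic_layout(s, image_sizes, center=None):
    """Place four images around a mosaic center on a 2s x 2s canvas.

    image_sizes: four (h, w) tuples in the order top-left, top-right, bottom-left, bottom-right.
    Returns one (canvas_box, source_box, (padw, padh)) tuple per image; boxes are (x1, y1, x2, y2).
    """
    if len(image_sizes) != 4:
        raise ValueError("mosaic needs exactly four images")
    if center is None:
        # Assumed: center drawn from the middle region [0.5s, 1.5s] of the canvas.
        center = (int(random.uniform(0.5 * s, 1.5 * s)), int(random.uniform(0.5 * s, 1.5 * s)))
    xc, yc = center
    layout = []
    for i, (h, w) in enumerate(image_sizes):
        if i == 0:  # top-left: the image's bottom-right corner touches the center
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
        elif i == 1:  # top-right: bottom-left corner touches the center
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, 2 * s), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom-left: top-right corner touches the center
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(2 * s, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        else:  # bottom-right: top-left corner touches the center
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, 2 * s), min(2 * s, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
        # Offset that maps source-image pixel coordinates onto the canvas.
        pad = (x1a - x1b, y1a - y1b)
        layout.append(((x1a, y1a, x2a, y2a), (x1b, y1b, x2b, y2b), pad))
    return layout


def shift_boxes(boxes, pad):
    """Translate (x1, y1, x2, y2) pixel boxes from a source image into canvas coordinates."""
    padw, padh = pad
    return [(x1 + padw, y1 + padh, x2 + padw, y2 + padh) for x1, y1, x2, y2 in boxes]
```

Each canvas region and its source crop have the same size, so pixels can be copied directly and every label box only needs the (padw, padh) shift. Boxes shifted partly off the canvas would be clipped in a later step.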

3. Training Strategies

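YOLOv5 applies several training strategies to enhance the model's performance. They include:

  1. Multi-scale Training: input images are randomly rescaled within a range of 0.5 to 1.5 times their original size during training.
  2. AutoAnchor: optimizes the prior anchor boxes to match the statistics of the ground truth boxes in your custom data.
  3. Warmup and Cosine LR Scheduler: adjusts the learning rate over the course of training.
  4. Exponential Moving Average (EMA): keeps a moving average of the model parameters to stabilize training and reduce generalization error.
  5. Mixed Precision Training: performs operations in half precision to reduce memory use and increase speed.
  6. Hyperparameter Evolution: tunes hyperparameters automatically with a genetic algorithm.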
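Warmup with a cosine learning-rate schedule and an exponential moving average (EMA) of weights are two training strategies YOLOv5 can use. The sketch below shows a cosine multiplier that decays from 1 at epoch 0 to a final fraction `lrf`, a simplified linear warmup, and an EMA whose decay ramps up from 0 toward 0.9999 with a time constant of 2000 updates. The functional forms and the EMA constants are assumptions modeled on YOLOv5's defaults; the `lr0`, `lrf`, and `warmup_iters` values are illustrative. YOLOv5's real warmup also ramps momentum and treats bias parameters separately, which is omitted here. All function names are hypothetical.

```python
import math


def cosine_lr_factor(epoch: float, epochs: int, lrf: float) -> float:
    """Cosine decay of the LR multiplier: 1.0 at epoch 0, lrf at `epochs`."""
    return ((1 - math.cos(epoch * math.pi / epochs)) / 2) * (lrf - 1.0) + 1.0


def warmup_factor(iteration: int, warmup_iters: int) -> float:
    """Simplified linear warmup: ramps the multiplier from 0 to 1 over warmup_iters iterations."""
    return min(1.0, iteration / warmup_iters) if warmup_iters > 0 else 1.0


def lr_at(iteration, iters_per_epoch, epochs, lr0=0.01, lrf=0.01, warmup_iters=1000):
    """Learning rate for a given global iteration (illustrative values)."""
    epoch = iteration // iters_per_epoch
    return lr0 * cosine_lr_factor(epoch, epochs, lrf) * warmup_factor(iteration, warmup_iters)


def ema_decay(updates: int, decay: float = 0.9999, tau: float = 2000.0) -> float:
    """EMA decay that starts at 0, so early (random) weights are not over-weighted."""
    return decay * (1 - math.exp(-updates / tau))


def ema_update(ema_params, params, updates):
    """One EMA step over flat lists of parameter values."""
    d = ema_decay(updates)
    return [d * e + (1 - d) * p for e, p in zip(ema_params, params)]
```

Ramping the EMA decay means the averaged weights track the live weights closely at the start of training and only become heavily smoothed once training has stabilized.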

4. Additional Features

4.1 Compute Losses

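The loss in YOLOv5 is computed as a combination of three individual loss components:

  1. Classes Loss (BCE Loss): binary cross-entropy loss that measures the error in the classification task.
  2. Objectness Loss (BCE Loss): binary cross-entropy loss that measures the error in detecting whether an object is present in a particular grid cell.
  3. Location Loss (CIoU Loss): Complete IoU loss that measures the error in localizing the object within the grid cell.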
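YOLOv5 uses Complete IoU (CIoU) for the location loss, taking the loss as 1 - CIoU. CIoU subtracts two penalties from plain IoU: the squared distance between box centers divided by the squared diagonal of the smallest enclosing box, and an aspect-ratio term alpha * v. The pure-Python sketch below follows that published definition for one pair of (x1, y1, x2, y2) boxes. It is an illustration, not YOLOv5's vectorized implementation, and `ciou` and `location_loss` are hypothetical names.

```python
import math


def ciou(box1, box2, eps=1e-7):
    """Complete IoU between two (x1, y1, x2, y2) boxes: IoU - rho^2 / c^2 - alpha * v."""
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    aw, ah = ax2 - ax1, ay2 - ay1
    bw, bh = bx2 - bx1, by2 - by1

    # Plain IoU.
    inter = max(0.0, min(ax2, bx2) - max(ax1, bx1)) * max(0.0, min(ay2, by2) - max(ay1, by1))
    union = aw * ah + bw * bh - inter
    iou = inter / (union + eps)

    # Center-distance penalty: squared center distance over squared enclosing-box diagonal.
    cw = max(ax2, bx2) - min(ax1, bx1)
    ch = max(ay2, by2) - min(ay1, by1)
    c2 = cw**2 + ch**2 + eps
    rho2 = ((bx1 + bx2 - ax1 - ax2) ** 2 + (by1 + by2 - ay1 - ay2) ** 2) / 4

    # Aspect-ratio consistency penalty.
    v = (4 / math.pi**2) * (math.atan(bw / (bh + eps)) - math.atan(aw / (ah + eps))) ** 2
    alpha = v / (v - iou + 1 + eps)
    return iou - (rho2 / c2 + alpha * v)


def location_loss(pred_box, target_box):
    """Per-box location loss, taken as 1 - CIoU."""
    return 1.0 - ciou(pred_box, target_box)
```

Identical boxes give a CIoU of about 1 (loss about 0). Two same-size boxes side by side without overlap, such as (0, 0, 10, 10) and (20, 0, 30, 10), give a CIoU of -0.4, so the loss still provides a signal that pulls the boxes together where plain IoU would be stuck at 0.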

The overall loss function is depicted by:

$$\text{Loss} = \lambda_1 L_{cls} + \lambda_2 L_{obj} + \lambda_3 L_{loc}$$

4.2 Balance Losses

The objectness losses of the three prediction layers (P3, P4, P5) are weighted differently. The balance weights are [4.0, 1.0, 0.4] respectively. This approach ensures that the predictions at different scales contribute appropriately to the total loss.

$$L_{obj} = 4.0 \cdot L_{obj}^{small} + 1.0 \cdot L_{obj}^{medium} + 0.4 \cdot L_{obj}^{large}$$
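The per-layer balance weights [4.0, 1.0, 0.4] for P3, P4, and P5 and the weighted sum of the three loss components can be combined in a few lines. The sketch below applies the balance weights to per-layer objectness losses and then forms lambda_1 * L_cls + lambda_2 * L_obj + lambda_3 * L_loc. The lambda defaults (0.5, 1.0, 0.05) are assumptions modeled on the `cls`, `obj`, and `box` gains in YOLOv5's default hyperparameter files; the input values are placeholders for the BCE and CIoU results. Function names are hypothetical.

```python
BALANCE = (4.0, 1.0, 0.4)  # objectness weights for P3, P4, P5, as stated in the text


def objectness_loss(per_level):
    """Balance-weighted sum of per-layer objectness losses, ordered (P3, P4, P5)."""
    if len(per_level) != len(BALANCE):
        raise ValueError("expected one objectness loss per prediction layer (P3, P4, P5)")
    return sum(w * loss for w, loss in zip(BALANCE, per_level))


def total_loss(l_cls, obj_per_level, l_loc, lambda_cls=0.5, lambda_obj=1.0, lambda_loc=0.05):
    """Loss = lambda_cls * L_cls + lambda_obj * L_obj + lambda_loc * L_loc (lambdas are assumed)."""
    return lambda_cls * l_cls + lambda_obj * objectness_loss(obj_per_level) + lambda_loc * l_loc
```

P3 has the finest grid and the most cells, so it sees the most negative (background) cells; weighting its objectness term more heavily keeps small-object detection from being under-represented in the total loss.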

4.3 Eliminate Grid Sensitivity

The YOLOv5 architecture makes some important changes to the box prediction strategy compared to earlier versions of YOLO. In YOLOv2 and YOLOv3, the box coordinates were directly predicted using the activation of the last layer.

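$$b_x = \sigma(t_x) + c_x, \qquad b_y = \sigma(t_y) + c_y$$

$$b_w = p_w \, e^{t_w}, \qquad b_h = p_h \, e^{t_h}$$

where $t_x, t_y, t_w, t_h$ are the raw network outputs, $\sigma$ is the sigmoid function, $(c_x, c_y)$ is the top-left corner of the grid cell, and $(p_w, p_h)$ are the anchor (prior) width and height.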

[Figure: YOLOv5 grid computation]

However, in YOLOv5, the formula for predicting the box coordinates has been updated to reduce grid sensitivity and prevent the model from predicting unbounded box dimensions.

The revised formulas for calculating the predicted bounding box are as follows:

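$$b_x = (2\sigma(t_x) - 0.5) + c_x, \qquad b_y = (2\sigma(t_y) - 0.5) + c_y$$

$$b_w = p_w \, (2\sigma(t_w))^2, \qquad b_h = p_h \, (2\sigma(t_h))^2$$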

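Compare the center point offset before and after scaling. The offset range changes from (0, 1) to (-0.5, 1.5), so an offset of exactly 0 or 1 (a center lying on a cell border) can now be reached with finite network outputs; with a plain sigmoid, those values are only approached as the output tends to infinity.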

[Figure: YOLOv5 grid scaling]
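A quick numerical check of the center-offset change: with YOLOv5's 2 * sigmoid(t) - 0.5, an offset of exactly 1.0 needs only sigmoid(t) = 0.75, i.e. t = ln 3 (about 1.10), and an offset of 0 needs t = -ln 3, while the YOLOv2/v3 sigmoid never reaches either value for finite t. The pure-Python sketch below evaluates both forms; `decode_center` converts an offset plus cell index to pixels by multiplying by the layer stride, which is an assumption about how grid units map to image pixels. All function names are hypothetical.

```python
import math


def sigmoid(t: float) -> float:
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    z = math.exp(t)
    return z / (1.0 + z)


def center_offset_v3(t: float) -> float:
    """YOLOv2/v3 offset within the cell: sigmoid(t), strictly inside (0, 1)."""
    return sigmoid(t)


def center_offset_v5(t: float) -> float:
    """YOLOv5 offset: 2 * sigmoid(t) - 0.5, inside (-0.5, 1.5)."""
    return 2.0 * sigmoid(t) - 0.5


def decode_center(t: float, cell: int, stride: float) -> float:
    """Center coordinate in pixels for grid cell index `cell`, using the YOLOv5 form."""
    return (center_offset_v5(t) + cell) * stride


if __name__ == "__main__":
    for t in (-math.log(3), 0.0, math.log(3), 10.0):
        print(f"t={t:+.3f}  v3={center_offset_v3(t):.5f}  v5={center_offset_v5(t):.5f}")
```

Because the cell borders sit well inside the new range rather than at its asymptotes, the network no longer needs extreme pre-activation values to place a center on a border.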

Compare the width and height scaling ratios (relative to the anchor) before and after the adjustment. The original YOLO/Darknet box equations have a serious flaw: width and height are completely unbounded because they are computed as out = exp(in). This is dangerous, as it can lead to runaway gradients, instabilities, NaN losses, and ultimately a complete failure of training. The revised form $p \, (2\sigma(t))^2$ limits each ratio to the range (0, 4). Refer to this issue for more details.

[Figure: YOLOv5 unbounded scaling]
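To see the difference between the two width/height equations numerically, the sketch below evaluates p * exp(t) and p * (2 * sigmoid(t))^2 for a fixed anchor size p. The YOLOv5 form equals the anchor size at t = 0 and saturates at 4p however large t becomes, whereas the exponential form grows without limit. Function names are hypothetical; anchor sizes are in arbitrary units.

```python
import math


def sigmoid(t: float) -> float:
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    z = math.exp(t)
    return z / (1.0 + z)


def size_v3(t: float, anchor: float) -> float:
    """YOLOv2/v3 width or height: anchor * exp(t), unbounded above."""
    return anchor * math.exp(t)


def size_v5(t: float, anchor: float) -> float:
    """YOLOv5 width or height: anchor * (2 * sigmoid(t))**2, bounded to (0, 4 * anchor)."""
    return anchor * (2.0 * sigmoid(t)) ** 2


if __name__ == "__main__":
    for t in (0.0, 2.0, 10.0, 50.0):
        print(f"t={t:5.1f}  v3={size_v3(t, 10.0):.3e}  v5={size_v5(t, 10.0):.3f}")
```

The 4x ceiling coincides with the default anchor-matching threshold of 4 (YOLOv5's `anchor_t` hyperparameter), so an anchor can be scaled to cover any ground-truth size it is allowed to match.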

4.4 Build Targets

The build target process in YOLOv5 is critical for training efficiency and model accuracy. It involves assigning ground truth boxes to the appropriate grid cells in the output map and matching them with the appropriate anchor boxes.

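This process follows these steps:

  1. Calculate the ratio of the ground truth box dimensions ($w_{gt}, h_{gt}$) to the dimensions of each anchor template ($w_{at}, h_{at}$):

$$r_w = \frac{w_{gt}}{w_{at}}, \qquad r_h = \frac{h_{gt}}{h_{at}}$$

$$r_w^{max} = \max\left(r_w, \frac{1}{r_w}\right), \qquad r_h^{max} = \max\left(r_h, \frac{1}{r_h}\right)$$

$$r^{max} = \max\left(r_w^{max}, r_h^{max}\right)$$

  2. If $r^{max}$ is below the threshold (4 by default), match the ground truth box with that anchor:

$$r^{max} < 4 \;\Rightarrow\; \text{match}$$

[Figure: YOLOv5 IoU computation]

  3. Assign each successfully matched anchor template to the corresponding grid cell. Because the center point offset range is (-0.5, 1.5), the ground truth box can also be assigned to neighboring cells, and therefore to more anchors.

[Figure: YOLOv5 grid overlap]

[Figure: YOLOv5 anchor selection]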
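The width/height ratio test used to match a ground truth box to anchors can be sketched in pure Python. It computes r_max for each anchor and keeps the anchors whose r_max is below the threshold, 4.0 by default (YOLOv5's `anchor_t` hyperparameter). The anchor list is assumed to be YOLOv5's default P3 anchors in pixels; because only ratios are compared, any common unit works. Function and constant names are hypothetical.

```python
P3_ANCHORS = [(10, 13), (16, 30), (33, 23)]  # assumed default P3/8 anchors (w, h) in pixels


def anchor_ratio(gt_wh, anchor_wh):
    """r_max = max(max(r_w, 1/r_w), max(r_h, 1/r_h)) for one ground-truth box and one anchor."""
    rw = gt_wh[0] / anchor_wh[0]
    rh = gt_wh[1] / anchor_wh[1]
    return max(max(rw, 1 / rw), max(rh, 1 / rh))


def match_anchors(gt_wh, anchors, anchor_t=4.0):
    """Indices of anchors whose shape is within a factor of anchor_t of the ground-truth box."""
    return [i for i, anchor in enumerate(anchors) if anchor_ratio(gt_wh, anchor) < anchor_t]


if __name__ == "__main__":
    for gt in [(20, 20), (5, 50), (200, 200)]:
        print(gt, match_anchors(gt, P3_ANCHORS))
```

A 5x50 box matches the first two P3 anchors but not (33, 23), because 33 / 5 = 6.6 exceeds the threshold. A 200x200 box matches none of the P3 anchors; since matching is done per prediction layer, it would be handled by the anchors of the coarser P4 and P5 layers instead. Unlike IoU-based matching, this test ignores box position and only compares shapes.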

This way, the build targets process ensures that each ground truth object is properly assigned and matched during the training process, allowing YOLOv5 to learn the task of object detection more effectively.
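Because YOLOv5's center offsets span (-0.5, 1.5), a ground truth box can also be handled by neighboring cells whose predictions can still reach its center. The sketch below, assumed to mirror the neighbor-selection rule in YOLOv5's target builder, keeps the cell containing the center and adds the horizontally and vertically adjacent cells nearest to the center, using a 0.5 threshold on the fractional coordinates, so each ground truth lands in up to three cells. Coordinates are in grid units and cells are returned as (x, y) indices; `assign_cells` is a hypothetical name.

```python
def assign_cells(gx, gy, grid_w, grid_h, g=0.5):
    """Grid cells (x, y) responsible for a ground-truth center at (gx, gy) in grid units."""
    cells = [(int(gx), int(gy))]  # the cell that contains the center
    gxi, gyi = grid_w - gx, grid_h - gy  # center measured from the right and bottom edges
    if gx % 1 < g and gx > 1:  # center in left half of its cell -> add left neighbor
        cells.append((int(gx - g), int(gy)))
    if gy % 1 < g and gy > 1:  # center in top half -> add top neighbor
        cells.append((int(gx), int(gy - g)))
    if gxi % 1 < g and gxi > 1:  # center in right half -> add right neighbor
        cells.append((int(gx + g), int(gy)))
    if gyi % 1 < g and gyi > 1:  # center in bottom half -> add bottom neighbor
        cells.append((int(gx), int(gy + g)))
    return cells


if __name__ == "__main__":
    print(assign_cells(5.3, 7.8, 20, 20))
```

For a center at (5.3, 7.8) on a 20x20 grid, the center lies in the left and bottom halves of cell (5, 7), so cells (4, 7) and (5, 8) are added; from either neighbor the required offset (1.3 or -0.2 in cell units) stays inside (-0.5, 1.5). The `> 1` conditions skip neighbors at the grid border.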

Conclusion

In conclusion, YOLOv5 represents a significant step forward in the development of real-time object detection models. By incorporating various new features, enhancements, and training strategies, it surpasses previous versions of the YOLO family in performance and efficiency.

The primary enhancements in YOLOv5 include architectural changes such as the 6x6 Conv2d stem and the SPPF block, an extensive range of data augmentation techniques, refined training strategies, and important adjustments to loss computation and target building. Together, these changes improve the accuracy and efficiency of object detection while retaining the speed that is the trademark of YOLO models.

📅 Created 1 year ago ✏️ Updated 3 months ago

Contributors: glenn-jocher, RizwanMunawar, ambitious-octopus, sergiuwaxmann