Distillation (recommended 🚀) - LightlyTrain documentation

Knowledge distillation involves transferring knowledge from a large, compute-intensive teacher model to a smaller, efficient student model by encouraging similarity between the student and teacher representations. It addresses the challenge of bridging the gap between state-of-the-art large-scale vision models and smaller, more computationally efficient models suitable for practical applications.

Note

Starting from LightlyTrain 0.7.0, method="distillation" uses a new, improved v2 implementation that achieves higher accuracy and trains up to 3x faster. The previous version is still available via method="distillationv1" for backward compatibility.

Use Distillation in LightlyTrain

Python

import lightly_train

if __name__ == "__main__":
    lightly_train.train(
        out="out/my_experiment",
        data="my_data_dir",
        model="torchvision/resnet18",
        method="distillation",
    )

Command Line

lightly-train train out=out/my_experiment data=my_data_dir model="torchvision/resnet18" method="distillation"

What’s under the Hood

Our distillation method directly applies a mean squared error (MSE) loss between the features of the student and teacher networks when processing the same image. We use a ViT-B/14 backbone from DINOv2 as the teacher model. Inspired by "Knowledge Distillation: A Good Teacher is Patient and Consistent", we apply strong, identical augmentations to both the teacher and student inputs so that both networks are trained against a consistent objective.
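To make the objective concrete, here is a minimal, dependency-free sketch of an MSE loss between student and teacher features for the same augmented image. It is an illustration only, not LightlyTrain's implementation: the function name and the feature values are made up, and it assumes the student features have already been brought to the teacher's dimensionality (for example by a projection layer, which the sketch omits).

```python
# Dependency-free sketch of a feature-level MSE distillation loss.
# Illustrative only: names and values are assumptions, not LightlyTrain internals.


def mse_loss(student_features, teacher_features):
    """Mean squared error between two equally sized feature vectors."""
    if len(student_features) != len(teacher_features):
        raise ValueError("student and teacher features must have the same length")
    squared_errors = [
        (s - t) ** 2 for s, t in zip(student_features, teacher_features)
    ]
    return sum(squared_errors) / len(squared_errors)


# Made-up features produced by both networks for the same augmented image.
teacher = [0.5, -1.0, 2.0, 0.0]
student = [0.0, -1.0, 1.0, 1.0]

loss = mse_loss(student, teacher)
print(loss)  # 0.5625
```

The loss is zero only when the student reproduces the teacher's features exactly, so minimizing it pulls the student representation toward the teacher's.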

Lightly Recommendations

Default Method Arguments

The following are the default method arguments for distillation. To learn how you can override these settings, see Method Arguments.

{
    "n_projection_layers": 1,
    "n_teacher_blocks": 2,
    "projection_hidden_dim": 2048,
    "teacher": "dinov2_vit/vitb14_pretrain"
}
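As a rough sketch of what an override might look like, the snippet below builds the keyword arguments for a training call with one default changed. It assumes that lightly_train.train accepts a method_args dictionary keyed by the names listed above; check the Method Arguments page for the exact interface. The override value and the run_training helper are hypothetical and chosen only for illustration.

```python
# Sketch: override selected distillation defaults.
# Assumption: lightly_train.train accepts a `method_args` dict. Verify this
# against the Method Arguments page before relying on it.

DEFAULT_METHOD_ARGS = {
    "n_projection_layers": 1,
    "n_teacher_blocks": 2,
    "projection_hidden_dim": 2048,
    "teacher": "dinov2_vit/vitb14_pretrain",
}

# Illustrative override value, not a recommendation.
overrides = {"n_teacher_blocks": 1}

# Catch typos early: every overridden key should be a known default.
unknown = set(overrides) - set(DEFAULT_METHOD_ARGS)
if unknown:
    raise KeyError(f"unknown method arguments: {sorted(unknown)}")

train_kwargs = {
    "out": "out/my_experiment",
    "data": "my_data_dir",
    "model": "torchvision/resnet18",
    "method": "distillation",
    "method_args": overrides,
}


def run_training(kwargs):
    """Hypothetical helper: start training with the assembled arguments."""
    import lightly_train  # imported lazily so the dict can be built without the package

    lightly_train.train(**kwargs)
```

To start a run, call run_training(train_kwargs) from inside an if __name__ == "__main__": guard, as in the Python example earlier on this page.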

Default Image Transform Arguments

The following are the default transform arguments for distillation. To learn how you can override these settings, see Configuring Image Transforms.

{
    "color_jitter": {
        "brightness": 0.8,
        "contrast": 0.8,
        "hue": 0.2,
        "prob": 0.8,
        "saturation": 0.4,
        "strength": 0.5
    },
    "gaussian_blur": {
        "blur_limit": 0,
        "prob": 1.0,
        "sigmas": [0.0, 0.1]
    },
    "image_size": [224, 224],
    "normalize": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225]
    },
    "random_flip": {
        "horizontal_prob": 0.5,
        "vertical_prob": 0.0
    },
    "random_gray_scale": 0.2,
    "random_resize": {
        "max_scale": 1.0,
        "min_scale": 0.14
    },
    "random_rotation": null,
    "solarize": null
}
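The normalize entry is the easiest of these settings to check by hand. The sketch below applies the usual per-channel normalization, (value - mean) / std, to a single RGB pixel. It assumes pixel values are scaled to the range [0, 1] before normalization, and the helper name is hypothetical; it illustrates the arithmetic rather than LightlyTrain's transform code.

```python
# Sketch of per-channel normalization with the default mean and std.
# Assumption: pixel values are in [0, 1] before normalization.

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Normalize one RGB pixel channel by channel: (value - mean) / std."""
    return [(value - mean) / std for value, mean, std in zip(rgb, MEAN, STD)]


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))  # [0.0, 0.0, 0.0]

# A white pixel ends up above zero in every channel.
white = normalize_pixel([1.0, 1.0, 1.0])
```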