Testers - PyTorch Metric Learning

Testers take your model and dataset, and compute nearest-neighbor based accuracy metrics. Note that the testers require the faiss package, which you can install with conda (for example, `conda install -c conda-forge faiss-cpu`).

In general, testers are used as follows:

```python
from pytorch_metric_learning import testers

tester = testers.SomeTestingFunction(*args, **kwargs)
dataset_dict = {"train": train_dataset, "val": val_dataset}
all_accuracies = tester.test(dataset_dict, epoch, model)

# Or if your model is composed of a trunk + embedder
all_accuracies = tester.test(dataset_dict, epoch, trunk, embedder)
```

You can perform custom actions by writing an end-of-testing hook (see the documentation for BaseTester), and you can access the test results directly via the all_accuracies attribute:

```python
def end_of_testing_hook(tester):
    print(tester.all_accuracies)
```
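To wire the hook up, pass it to the tester's constructor via the end_of_testing_hook argument (see BaseTester below). A minimal sketch, assuming a GlobalEmbeddingSpaceTester:

```python
from pytorch_metric_learning import testers

# The hook runs automatically at the end of every tester.test(...) call.
tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook=end_of_testing_hook)
```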

This will print out a dictionary of accuracy metrics, per dataset split. You'll see something like this:

{"train": {"AMI_level0": 0.53, ...}, "val": {"AMI_level0": 0.44, ...}}

Each accuracy metric name is suffixed with level0, which refers to the 0th label hierarchy level (see the documentation for BaseTester). This is only relevant if you're dealing with multi-label datasets.

For an explanation of the default accuracy metrics, see the AccuracyCalculator documentation.
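If you want different metrics than the defaults, you can pass your own calculator through the accuracy_calculator argument of any tester. A minimal sketch, assuming the "precision_at_1" and "AMI" metric names described in the AccuracyCalculator documentation:

```python
from pytorch_metric_learning import testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# Compute only precision@1 and AMI, using k=1 nearest neighbors.
calculator = AccuracyCalculator(include=("precision_at_1", "AMI"), k=1)
tester = testers.GlobalEmbeddingSpaceTester(accuracy_calculator=calculator)
```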

Testing splits

By default, every dataset in dataset_dict will be evaluated using itself as both the query and the reference (on which to find nearest neighbors). More flexibility is available via the optional splits_to_eval argument of tester.test(). splits_to_eval is a list of (query_split, [list_of_reference_splits]) tuples.

For example, let's say your dataset_dict has two keys: "dataset_a" and "train".
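A sketch of what splits_to_eval could look like for those two keys (the specific combinations are illustrative, and tester and model stand in for your own objects):

```python
splits_to_eval = [
    ("dataset_a", ["train"]),           # query dataset_a against train
    ("train", ["dataset_a"]),           # query train against dataset_a
    ("train", ["train", "dataset_a"]),  # query train against the union of both splits
]
all_accuracies = tester.test(dataset_dict, 0, model, splits_to_eval=splits_to_eval)
```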

BaseTester

All testers extend this class and therefore inherit its __init__ arguments.

```python
testers.BaseTester(
    normalize_embeddings=True,
    use_trunk_output=False,
    batch_size=32,
    dataloader_num_workers=2,
    pca=None,
    data_device=None,
    dtype=None,
    data_and_label_getter=None,
    label_hierarchy_level=0,
    end_of_testing_hook=None,
    dataset_labels=None,
    set_min_label_to_zero=False,
    accuracy_calculator=None,
    visualizer=None,
    visualizer_hook=None,
)
```

Parameters:

Functions:

test

Call this to test your model on a dataset dict. It returns a dictionary of accuracies.

```python
all_accuracies = tester.test(
    dataset_dict,         # dictionary mapping strings to datasets
    epoch,                # used for logging
    trunk_model,          # your model
    embedder_model=None,  # by default this will be a no-op
    splits_to_eval=None,
    collate_fn=None,      # custom collate_fn for the dataloader
)
```

get_all_embeddings

Returns all the embeddings and labels for the input dataset and model.

```python
embeddings, labels = tester.get_all_embeddings(
    dataset,              # any PyTorch dataset
    trunk_model,          # your model
    embedder_model=None,  # by default this will be a no-op
    collate_fn=None,      # custom collate_fn for the dataloader
    eval=True,            # set models to eval mode
    return_as_numpy=False,
)
```
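For example, a minimal sketch of embedding an entire dataset (val_dataset and model are assumptions standing in for your own objects):

```python
# Returns torch tensors by default (return_as_numpy=False).
embeddings, labels = tester.get_all_embeddings(val_dataset, model)
print(embeddings.shape)  # (len(val_dataset), embedding_dim)
```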

GlobalEmbeddingSpaceTester

Computes nearest neighbors by looking at all points in the embedding space (rather than a subset). This is probably the tester you are looking for. To see it in action, check out one of the example notebooks.

```python
testers.GlobalEmbeddingSpaceTester(*args, **kwargs)
```
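Here is a minimal end-to-end sketch, in which train_dataset, val_dataset, and model are assumptions standing in for your own objects:

```python
from pytorch_metric_learning import testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

tester = testers.GlobalEmbeddingSpaceTester(accuracy_calculator=AccuracyCalculator(k=1))
dataset_dict = {"train": train_dataset, "val": val_dataset}
# Positional arguments: dataset_dict, epoch, trunk_model.
all_accuracies = tester.test(dataset_dict, 0, model)
print(all_accuracies["val"])
```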

WithSameParentLabelTester

This assumes there is a label hierarchy. For each sample, the search space is narrowed by only looking at sibling samples, i.e. samples with the same parent label. For example, consider a dataset with 4 fine-grained classes {cat, dog, car, truck}, and 2 coarse-grained classes {animal, vehicle}. The nearest neighbor search for cats and dogs will consist of animals, and the nearest-neighbor search for cars and trucks will consist of vehicles.

```python
testers.WithSameParentLabelTester(*args, **kwargs)
```
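To make the hierarchy concrete, here is a hedged sketch of what hierarchical labels for the cat/dog/car/truck example could look like, assuming the convention that labels are 2D with one column per hierarchy level (column 0 being the finest level):

```python
import numpy as np

# Fine labels: 0=cat, 1=dog, 2=car, 3=truck
# Parent labels: 0=animal, 1=vehicle
labels = np.array([
    [0, 0],  # cat   -> animal
    [1, 0],  # dog   -> animal
    [2, 1],  # car   -> vehicle
    [3, 1],  # truck -> vehicle
])
# With labels like these, each sample's nearest-neighbor search is
# restricted to rows that share its parent label.
```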

GlobalTwoStreamEmbeddingSpaceTester

This is the corresponding tester for TwoStreamMetricLoss. The supplied dataset must return (anchor, positive, label).

```python
testers.GlobalTwoStreamEmbeddingSpaceTester(*args, **kwargs)
```
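For reference, a minimal sketch of a dataset satisfying the (anchor, positive, label) contract; the tensors are assumptions standing in for your own data:

```python
from torch.utils.data import Dataset

class TwoStreamDataset(Dataset):
    def __init__(self, anchors, positives, labels):
        self.anchors = anchors      # tensor of anchor inputs
        self.positives = positives  # tensor of positive inputs
        self.labels = labels        # tensor of labels

    def __len__(self):
        return len(self.anchors)

    def __getitem__(self, idx):
        # This tester requires exactly this 3-tuple.
        return self.anchors[idx], self.positives[idx], self.labels[idx]
```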

Requirements:

This tester only supports the default value of splits_to_eval: each split is used as both query and reference.