Inference Models - PyTorch Metric Learning

The `utils.inference` module contains classes that make it convenient to find matching pairs within a batch or from a set of pairs. Take a look at this notebook to see example usage.

InferenceModel

```python
from pytorch_metric_learning.utils.inference import InferenceModel

InferenceModel(trunk,
               embedder=None,
               match_finder=None,
               normalize_embeddings=True,
               knn_func=None,
               data_device=None,
               dtype=None)
```
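Conceptually, the `trunk`, `embedder`, and `normalize_embeddings` arguments describe how an input becomes an embedding. The sketch below illustrates one plausible flow with plain Python callables: the trunk runs first, the optional embedder is applied to the trunk's output, and the result is scaled to unit length when normalization is on. `embed` is a hypothetical helper, and treating `embedder=None` as the identity is an assumption, not a description of the library's internals.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]


def embed(
    x: Vector,
    trunk: Callable[[Vector], Vector],
    embedder: Optional[Callable[[Vector], Vector]] = None,
    normalize_embeddings: bool = True,
) -> Vector:
    """Hypothetical helper: trunk -> optional embedder -> optional L2 normalization."""
    out = trunk(x)
    if embedder is not None:  # assumption: a missing embedder acts as the identity
        out = embedder(out)
    if normalize_embeddings:
        norm = math.sqrt(sum(v * v for v in out))
        if norm > 0:  # leave an all-zero vector unchanged instead of dividing by zero
            out = [v / norm for v in out]
    return out


def toy_trunk(v: Vector) -> Vector:
    # Stand-in "model": doubles every feature.
    return [2.0 * t for t in v]


def toy_embedder(v: Vector) -> Vector:
    # Stand-in projection head: keeps the first two features.
    return v[:2]


print(embed([3.0, 4.0, 5.0], toy_trunk, toy_embedder))
print(embed([3.0, 4.0, 5.0], toy_trunk, toy_embedder, normalize_embeddings=False))
```

With normalization on, the toy pipeline maps `[3.0, 4.0, 5.0]` to `[6.0, 8.0]` and then to a unit-length vector.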

Parameters:

Methods:

```python
from pytorch_metric_learning.utils.inference import InferenceModel

# initialize with a model
# (model, dataset, dataset2, query, x and y are placeholders for your own objects)
im = InferenceModel(model)

# pass in a dataset to serve as the search space for k-nn
im.train_knn(dataset)

# add another dataset to the index
im.add_to_knn(dataset2)

# get the 10 nearest neighbors of a query
distances, indices = im.get_nearest_neighbors(query, k=10)

# determine if inputs are close to each other
is_match = im.is_match(x, y)

# determine "is_match" pairwise for all elements in a batch
match_matrix = im.get_matches(x)

# save and load the knn function (which is a faiss index by default)
im.save_knn_func("filename.index")
im.load_knn_func("filename.index")
```
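The `train_knn` and `add_to_knn` calls build a search space from one dataset and then extend it with another, so `get_nearest_neighbors` searches both. Below is a minimal brute-force sketch of that behavior. It assumes squared Euclidean distance and assumes that newly added items receive indices after the existing ones, as they do in a flat faiss index. `TinyKNNIndex` and its `train`, `add`, and `search` methods are hypothetical stand-ins, not the library's implementation.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


class TinyKNNIndex:
    """Hypothetical brute-force index mirroring train_knn / add_to_knn / get_nearest_neighbors."""

    def __init__(self) -> None:
        self.vectors: List[Vector] = []

    def train(self, embeddings: Sequence[Vector]) -> None:
        # Like train_knn: start a fresh search space from this dataset.
        self.vectors = list(embeddings)

    def add(self, embeddings: Sequence[Vector]) -> None:
        # Like add_to_knn: append, so new items get indices after the existing ones.
        self.vectors.extend(embeddings)

    def search(self, query: Vector, k: int) -> Tuple[List[float], List[int]]:
        # Score every stored vector by squared Euclidean distance and keep the k smallest.
        scored = [
            (sum((q - r) ** 2 for q, r in zip(query, ref)), i)
            for i, ref in enumerate(self.vectors)
        ]
        scored.sort()
        top = scored[:k]
        return [d for d, _ in top], [i for _, i in top]


index = TinyKNNIndex()
index.train([[0.0, 0.0], [10.0, 10.0]])  # indices 0 and 1
index.add([[1.0, 1.0]])                  # index 2
distances, indices = index.search([0.9, 0.9], k=2)
print(indices)  # [2, 0]
```

Under this sketch's assumption, any returned index at or above the size of the first dataset refers to an item from the second one.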

MatchFinder

```python
from pytorch_metric_learning.utils.inference import MatchFinder

MatchFinder(distance=None, threshold=None)
```

Parameters:
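A `MatchFinder` combines a distance (or similarity) measure with a threshold, and the result decides which pairs count as matches. The direction of the comparison depends on the measure: a similarity must be high enough, while a distance must be small enough. The sketch below builds a pairwise match matrix under that assumption. `cosine_similarity`, `match_matrix`, the `higher_is_closer` flag, and the 0.9 threshold are all illustrative choices, not the library's code or defaults.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def match_matrix(
    embeddings: Sequence[Vector],
    measure: Callable[[Vector, Vector], float],
    threshold: float,
    higher_is_closer: bool = True,
) -> List[List[bool]]:
    """Hypothetical helper: entry [i][j] is True when embeddings i and j count as a match."""

    def is_match(a: Vector, b: Vector) -> bool:
        value = measure(a, b)
        # A similarity must reach the threshold; a distance must stay at or below it.
        return value >= threshold if higher_is_closer else value <= threshold

    return [[is_match(a, b) for b in embeddings] for a in embeddings]


batch = [[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]]
for row in match_matrix(batch, cosine_similarity, threshold=0.9):
    print(row)
```

Passing a true distance such as `math.dist` together with `higher_is_closer=False` flips the comparison, so the same threshold logic works for either kind of measure.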

FaissKNN

Uses the faiss library to compute k-nearest neighbors.

```python
from pytorch_metric_learning.utils.inference import FaissKNN

FaissKNN(reset_before=True, reset_after=True, index_init_fn=None, gpus=None)
```

Parameters:

Example:

```python
import faiss
from pytorch_metric_learning.utils.inference import FaissKNN

# use faiss.IndexFlatIP on 3 gpus
knn_func = FaissKNN(index_init_fn=faiss.IndexFlatIP, gpus=[0, 1, 2])

# query = query embeddings
# k = the k in k-nearest-neighbors
# reference = the embeddings to search
# last argument is whether or not query and reference share datapoints
distances, indices = knn_func(query, k, reference, False)
```
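The last argument matters when the query embeddings are also part of the reference set. Without special handling, every query would return itself as its own nearest neighbor. The brute-force sketch below uses inner-product scores to mirror `faiss.IndexFlatIP`. It assumes that, when the flag is set, `query[i]` and `reference[i]` are the same datapoint and that this self-pair is simply skipped. `inner_product_knn` is a hypothetical helper, not `FaissKNN` itself.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def inner_product_knn(
    query: Sequence[Vector],
    k: int,
    reference: Sequence[Vector],
    ref_includes_query: bool,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical brute-force k-NN by inner product (larger score = closer).

    Assumption: when ref_includes_query is True, query[i] is the same datapoint
    as reference[i], so that pair is skipped to avoid trivial self-matches.
    """
    all_scores, all_indices = [], []
    for qi, q in enumerate(query):
        scored = []
        for ri, r in enumerate(reference):
            if ref_includes_query and ri == qi:
                continue  # skip the query's own copy in the reference set
            scored.append((sum(a * b for a, b in zip(q, r)), ri))
        # Highest inner product first; ties broken by the smaller index.
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        top = scored[:k]
        all_scores.append([s for s, _ in top])
        all_indices.append([i for _, i in top])
    return all_scores, all_indices


embeddings = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]]  # unit-length vectors
_, with_self = inner_product_knn(embeddings, 1, embeddings, False)
_, without_self = inner_product_knn(embeddings, 1, embeddings, True)
print(with_self)     # [[0], [1], [2]]: each point finds itself
print(without_self)  # [[1], [0], [1]]: nearest other point
```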

FaissKMeans

Uses the faiss library to do k-means clustering.

```python
from pytorch_metric_learning.utils.inference import FaissKMeans

FaissKMeans(**kwargs)
```

Parameters:

Example:

```python
from pytorch_metric_learning.utils.inference import FaissKMeans

kmeans_func = FaissKMeans(niter=100, verbose=True, gpu=True)

# cluster into 10 groups
cluster_assignments = kmeans_func(embeddings, 10)
```
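`FaissKMeans` delegates the clustering itself to faiss. To build intuition for what `cluster_assignments` holds, here is a plain Lloyd's-algorithm sketch in pure Python that returns one cluster id per embedding. This is a swapped-in textbook k-means, not faiss's implementation. Seeding the centroids with the first `k` points and running a fixed number of iterations are simplifying assumptions; faiss's own initialization and options may differ.

```python
from typing import List, Sequence

Vector = Sequence[float]


def kmeans_assignments(embeddings: Sequence[Vector], k: int, niter: int = 20) -> List[int]:
    """Plain Lloyd's k-means; returns a cluster id in [0, k) for each embedding."""
    if not 0 < k <= len(embeddings):
        raise ValueError("k must be between 1 and the number of embeddings")
    dim = len(embeddings[0])
    # Simplifying assumption: seed the centroids with the first k points.
    centroids = [list(p) for p in embeddings[:k]]
    assignments = [0] * len(embeddings)
    for _ in range(niter):
        # Assignment step: each point joins its closest centroid (squared L2).
        assignments = [
            min(range(k), key=lambda c: sum((x - y) ** 2 for x, y in zip(p, centroids[c])))
            for p in embeddings
        ]
        # Update step: move each centroid to the mean of its members.
        for c in range(k):
            members = [p for p, a in zip(embeddings, assignments) if a == c]
            if members:  # keep an empty cluster's centroid where it was
                centroids[c] = [sum(p[d] for p in members) / len(members) for d in range(dim)]
    return assignments


points = [[0.0, 0.0], [10.0, 10.0], [0.5, 0.0], [9.5, 10.0], [0.0, 0.5], [10.0, 9.5]]
print(kmeans_assignments(points, 2))  # [0, 1, 0, 1, 0, 1]
```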

CustomKNN

Uses a distance function to determine the similarity between datapoints, and then computes k-nearest neighbors.

```python
from pytorch_metric_learning.utils.inference import CustomKNN

CustomKNN(distance, batch_size=None)
```

Parameters:

Example:

```python
from pytorch_metric_learning.distances import SNRDistance
from pytorch_metric_learning.utils.inference import CustomKNN

# query, k and reference have the same meaning as in the FaissKNN example
knn_func = CustomKNN(SNRDistance())
distances, indices = knn_func(query, k, reference, False)
```
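`CustomKNN` swaps the faiss index for an arbitrary distance, and `batch_size` is assumed here to limit how many queries are compared against the references in one step. The sketch below illustrates that pattern with a plain Python distance callable and query chunking. `custom_knn` is hypothetical, and it assumes a true distance (smaller value means closer), which is what `math.dist` provides in the usage line.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Vector = Sequence[float]


def custom_knn(
    query: Sequence[Vector],
    k: int,
    reference: Sequence[Vector],
    distance: Callable[[Vector, Vector], float],
    batch_size: Optional[int] = None,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical k-NN with a pluggable distance (smaller value = closer).

    Queries are processed in chunks of batch_size, so at most
    batch_size * len(reference) distances are held at once.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError("batch_size must be positive")
    step = batch_size or max(len(query), 1)
    all_dists: List[List[float]] = []
    all_indices: List[List[int]] = []
    for start in range(0, len(query), step):
        chunk = query[start:start + step]
        # Distance matrix for this chunk only: len(chunk) rows x len(reference) columns.
        block = [[distance(q, r) for r in reference] for q in chunk]
        for row in block:
            # Stable sort keeps the lower index first when distances tie.
            nearest = sorted(range(len(row)), key=row.__getitem__)[:k]
            all_dists.append([row[i] for i in nearest])
            all_indices.append(nearest)
    return all_dists, all_indices


query = [[0.0, 0.0], [5.0, 5.0], [9.0, 9.0]]
reference = [[0.0, 1.0], [6.0, 5.0], [10.0, 10.0], [1.0, 1.0]]
_, indices = custom_knn(query, 2, reference, math.dist, batch_size=2)
print(indices)  # [[0, 3], [1, 3], [2, 1]]
```

Chunking changes peak memory, not the result: the batched and unbatched calls return the same neighbors.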