detectron2.data — detectron2 0.6 documentation

detectron2.data.DatasetCatalog(dict)

A global dictionary that stores information about the datasets and how to obtain them.

It contains a mapping from strings (which are names that identify a dataset, e.g. “coco_2014_train”) to a function which parses the dataset and returns the samples in the format of list[dict].

The returned dicts should be in Detectron2 Dataset format (see DATASETS.md for details) if used with the data loader functionalities in data/build.py, data/detection_transform.py.

The purpose of having this catalog is to make it easy to choose different datasets, by just using the strings in the config.

DatasetCatalog.register(name, func)

Parameters

name (str) – the name that identifies a dataset, e.g. “coco_2014_train”.

func (callable) – a callable which takes no arguments and returns a list of dicts. It must return the same results if called multiple times.

DatasetCatalog.get(name)

Call the registered function and return its results.

Parameters

name (str) – the name that identifies a dataset, e.g. “coco_2014_train”.

Returns

list[dict] – dataset annotations.
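For illustration, a minimal sketch of registering and retrieving a dataset; the name “my_dataset” and the dict contents are hypothetical:

from detectron2.data import DatasetCatalog

def get_my_dicts():
    # Must return the same list[dict] in Detectron2 Dataset format on every call.
    return [{"file_name": "image_0.jpg", "height": 480, "width": 640, "annotations": []}]

DatasetCatalog.register("my_dataset", get_my_dicts)
dataset_dicts = DatasetCatalog.get("my_dataset")  # calls get_my_dicts()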

detectron2.data.MetadataCatalog(dict)

MetadataCatalog is a global dictionary that provides access to Metadata of a given dataset.

The metadata associated with a certain name is a singleton: once created, the metadata will stay alive and will be returned by future calls to get(name).

It’s like global variables, so don’t abuse it. It’s meant for storing knowledge that’s constant and shared across the execution of the program, e.g.: the class names in COCO.

MetadataCatalog.get(name)

Parameters

name (str) – name of a dataset (e.g. coco_2014_train).

Returns

Metadata – the Metadata instance associated with this name; an empty one is created if none is available.

detectron2.data.build_detection_test_loader(dataset: Union[List[Any], torch.utils.data.Dataset], *, mapper: Callable[[Dict[str, Any]], Any], sampler: Optional[torch.utils.data.Sampler] = None, batch_size: int = 1, num_workers: int = 0, collate_fn: Optional[Callable[[List[Any]], Any]] = None) → torch.utils.data.DataLoader[source]

Similar to build_detection_train_loader, with default batch size = 1, and sampler = InferenceSampler. This sampler coordinates all workers to produce the exact set of all samples.

Parameters

Returns

DataLoader – a torch DataLoader that loads the given detection dataset, with test-time transformation and batching.

Examples:

data_loader = build_detection_test_loader(
    DatasetCatalog.get("my_test"),
    mapper=DatasetMapper(...))

or, instantiate with a CfgNode:

data_loader = build_detection_test_loader(cfg, "my_test")

detectron2.data.build_detection_train_loader(dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0, collate_fn=None, prefetch_factor=None, persistent_workers=False, pin_memory=False)[source]

Build a dataloader for object detection with some default features.

Parameters

Returns

torch.utils.data.DataLoader – a dataloader. Each output from it is a list[mapped_element] of length total_batch_size / num_workers, where mapped_element is produced by the mapper.
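For example, a minimal sketch of building a training loader from an already-registered dataset; the name “my_dataset” and the CfgNode cfg are assumptions:

from detectron2.data import DatasetCatalog, DatasetMapper, build_detection_train_loader

data_loader = build_detection_train_loader(
    DatasetCatalog.get("my_dataset"),          # hypothetical registered name
    mapper=DatasetMapper(cfg, is_train=True),  # cfg is an assumed CfgNode
    total_batch_size=16,
    num_workers=4,
)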

detectron2.data.get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposal_files=None, check_consistency=True)[source]

Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

Parameters

Returns

list[dict] – a list of dicts following the standard dataset dict format.

detectron2.data.load_proposals_into_dataset(dataset_dicts, proposal_file)[source]

Load precomputed object proposals into the dataset.

The proposal file should be a pickled dict with the following keys:

Parameters

Returns

list[dict] – the same format as dataset_dicts, but with an added “proposals” field.

detectron2.data.print_instances_class_histogram(dataset_dicts, class_names)[source]

Parameters

class detectron2.data.Metadata[source]

Bases: types.SimpleNamespace

A class that supports simple attribute setter/getter. It is intended for storing metadata of a dataset and making it accessible globally.

Examples:

somewhere when you load the data:

MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]

somewhere when you print statistics or visualize:

classes = MetadataCatalog.get("mydataset").thing_classes

name: str = 'N/A'

as_dict()[source]

Returns all the metadata as a dict. Note that modifications to the returned dict will not reflect on the Metadata object.

set(**kwargs)[source]

Set multiple metadata with kwargs.

get(key, default=None)[source]

Access an attribute and return its value if it exists; otherwise return default.
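A short sketch of these accessors; the dataset name “my_dataset” is hypothetical:

from detectron2.data import MetadataCatalog

meta = MetadataCatalog.get("my_dataset")   # created empty if not yet present
meta.set(thing_classes=["person", "dog"], evaluator_type="coco")
meta.get("thing_classes")                  # ["person", "dog"]
meta.get("stuff_classes", default=None)    # None: attribute was never set
snapshot = meta.as_dict()                  # mutating snapshot won't affect meta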

class detectron2.data.DatasetFromList(*args, **kwds)[source]

Bases: torch.utils.data.Dataset

Wrap a list into a torch Dataset. It produces the elements of the list as data.
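A minimal sketch, with hypothetical list contents:

from detectron2.data import DatasetFromList

lst = [{"file_name": "a.jpg"}, {"file_name": "b.jpg"}]
dataset = DatasetFromList(lst, copy=False, serialize=True)
len(dataset)   # 2
dataset[0]     # {"file_name": "a.jpg"}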

__init__(lst: list, copy: bool = True, serialize: Union[bool, Callable] = True)[source]

Parameters

class detectron2.data.MapDataset(dataset, map_func)[source]

Bases: torch.utils.data.Dataset

Map a function over the elements in a dataset.
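A minimal sketch with a hypothetical map function:

from detectron2.data import DatasetFromList, MapDataset

dataset = DatasetFromList(list(range(5)), serialize=False)
squared = MapDataset(dataset, lambda x: x * x)  # applies map_func per element
squared[2]  # 4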

__init__(dataset, map_func)[source]

Parameters

class detectron2.data.ToIterableDataset(*args, **kwds)[source]

Bases: torch.utils.data.IterableDataset

Convert an old indices-based (also called map-style) dataset to an iterable-style dataset.
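A minimal sketch of the conversion, combined with TrainingSampler from this page; the toy dataset is hypothetical:

import torch.utils.data as torchdata
from detectron2.data import DatasetFromList, ToIterableDataset
from detectron2.data.samplers import TrainingSampler

dataset = DatasetFromList(list(range(100)), serialize=False)  # map-style dataset
sampler = TrainingSampler(len(dataset), shuffle=True)         # infinite index stream
iterable = ToIterableDataset(dataset, sampler)                # shards sampler across workers
loader = torchdata.DataLoader(iterable, batch_size=None, num_workers=2)
# Iterating `loader` never terminates, because TrainingSampler is infinite.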

__init__(dataset: torch.utils.data.Dataset, sampler: torch.utils.data.Sampler, shard_sampler: bool = True, shard_chunk_size: int = 1)[source]

Parameters

class detectron2.data.DatasetMapper(*args, **kwargs)[source]

Bases: object

A callable which takes a dataset dict in Detectron2 Dataset format, and maps it into a format used by the model.

This is the default callable used to map your dataset dict into training data. You may need to follow it to implement your own for customized logic, such as a different way to read or transform images. See Dataloader for details.

The callable currently does the following:

  1. Reads the image from “file_name”
  2. Applies cropping/geometric transforms to the image and annotations
  3. Prepares data and annotations into Tensor and Instances
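For illustration, a sketch of constructing a mapper with explicit augmentations and plugging it into a train loader; cfg is an assumed CfgNode and the augmentation choices are arbitrary:

import detectron2.data.transforms as T
from detectron2.data import DatasetMapper, build_detection_train_loader

mapper = DatasetMapper(
    is_train=True,
    augmentations=[T.ResizeShortestEdge(short_edge_length=800, max_size=1333),
                   T.RandomFlip()],
    image_format="BGR",
)
data_loader = build_detection_train_loader(cfg, mapper=mapper)  # cfg assumed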

__init__(is_train: bool, *, augmentations: List[Union[detectron2.data.transforms.Augmentation, detectron2.data.transforms.Transform]], image_format: str, use_instance_mask: bool = False, use_keypoint: bool = False, instance_mask_format: str = 'polygon', keypoint_hflip_indices: Optional[numpy.ndarray] = None, precomputed_proposal_topk: Optional[int] = None, recompute_boxes: bool = False)[source]

NOTE: this interface is experimental.

Parameters

classmethod from_config(cfg, is_train: bool = True)[source]

__call__(dataset_dict)[source]

Parameters

dataset_dict (dict) – Metadata of one image, in Detectron2 Dataset format.

Returns

dict – a format that builtin models in detectron2 accept

detectron2.data.detection_utils module

Common data processing utilities that are used in a typical object detection data pipeline.

exception detectron2.data.detection_utils.SizeMismatchError[source]

Bases: ValueError

Raised when the loaded image has a different width/height than the annotation.

detectron2.data.detection_utils.convert_image_to_rgb(image, format)[source]

Convert an image from given format to RGB.

Parameters

Returns

(np.ndarray) – (H,W,3) RGB image in 0-255 range, can be either float or uint8

detectron2.data.detection_utils.check_image_size(dataset_dict, image)[source]

Raise an error if the image does not match the size specified in the dict.

detectron2.data.detection_utils.transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0)[source]

Apply transformations to the proposals in dataset_dict, if any.

Parameters

The input dict is modified in-place, with the above-mentioned keys removed. A new key “proposals” will be added. Its value is an Instances object which contains the transformed proposals in its fields “proposal_boxes” and “objectness_logits”.

detectron2.data.detection_utils.transform_instance_annotations(annotation, transforms, image_size, *, keypoint_hflip_indices=None)[source]

Apply transforms to box, segmentation and keypoints annotations of a single instance.

It will use transforms.apply_box for the box, and transforms.apply_coords for segmentation polygons & keypoints. If you need anything more specially designed for each data structure, you’ll need to implement your own version of this function or the transforms.

Parameters

Returns

dict – the same input dict with fields “bbox”, “segmentation”, “keypoints” transformed according to transforms. The “bbox_mode” field will be set to XYXY_ABS.
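A sketch of how this function is typically used inside a custom mapper, together with annotations_to_instances below; dataset_dict is one dataset dict and the augmentation choices are arbitrary:

import detectron2.data.transforms as T
from detectron2.data import detection_utils as utils

image = utils.read_image(dataset_dict["file_name"], format="BGR")
aug_input = T.AugInput(image)
transforms = T.AugmentationList([T.RandomFlip()])(aug_input)
image = aug_input.image  # augmented image
annos = [
    utils.transform_instance_annotations(a, transforms, image.shape[:2])
    for a in dataset_dict.pop("annotations")
]
instances = utils.annotations_to_instances(annos, image.shape[:2])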

detectron2.data.detection_utils.annotations_to_instances(annos, image_size, mask_format='polygon')[source]

Create an Instances object used by the models, from instance annotations in the dataset dict.

Parameters

Returns

Instances – It will contain fields “gt_boxes”, “gt_classes”, “gt_masks”, “gt_keypoints”, if they can be obtained from annos. This is the format that builtin models expect.

detectron2.data.detection_utils.annotations_to_instances_rotated(annos, image_size)[source]

Create an Instances object used by the models, from instance annotations in the dataset dict. Compared to annotations_to_instances, this function is for rotated boxes only.

Parameters

Returns

Instances – Containing fields “gt_boxes”, “gt_classes”, if they can be obtained from annos. This is the format that builtin models expect.

detectron2.data.detection_utils.build_augmentation(cfg, is_train)[source]

Create a list of default Augmentation from config. Currently it includes resizing and flipping.

Returns

list[Augmentation]

detectron2.data.detection_utils.create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) → List[int][source]

Parameters

dataset_names – list of dataset names

Returns

list[int] – a list of size=#keypoints, storing the horizontally-flipped keypoint indices.

detectron2.data.detection_utils.filter_empty_instances(instances, by_box=True, by_mask=True, box_threshold=1e-05, return_mask=False)[source]

Filter out empty instances in an Instances object.

Parameters

Returns

Instances – the filtered instances.

tensor[bool] (optional) – boolean mask of filtered instances, returned when return_mask is True.

detectron2.data.detection_utils.read_image(file_name, format=None)[source]

Read an image into the given format. Will apply rotation and flipping if the image has such EXIF information.

Parameters

Returns

image (np.ndarray) – an HWC image in the given format, which is 0-255, uint8 for supported image modes in PIL or “BGR”; float (0-1 for Y) for YUV-BT.601.

detectron2.data.datasets module

detectron2.data.datasets.load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None)[source]

Load a json file with COCO’s instances annotation format. Currently supports instance detection, instance segmentation, and person keypoints annotations.

Parameters

Returns

list[dict] – a list of dicts in Detectron2 standard dataset dicts format (see Using Custom Datasets) when dataset_name is not None. If dataset_name is None, the returned category_ids may be incontiguous and may not conform to the Detectron2 standard format.

Notes

  1. This function does not read the image files. The results do not have the “image” field.

detectron2.data.datasets.load_sem_seg(gt_root, image_root, gt_ext='png', image_ext='jpg')[source]

Load semantic segmentation datasets. All files under “gt_root” with “gt_ext” extension are treated as ground truth annotations and all files under “image_root” with “image_ext” extension as input images. Ground truth and input images are matched using file paths relative to “gt_root” and “image_root” respectively without taking into account file extensions. This works for COCO as well as some other datasets.

Parameters

Returns

list[dict] – a list of dicts in detectron2 standard format without instance-level annotation.

Notes

  1. This function does not read the image and ground truth files. The results do not have the “image” and “sem_seg” fields.
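For example, a hypothetical registration of a semantic segmentation dataset through DatasetCatalog (paths, names, and classes are placeholders):

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg

DatasetCatalog.register(
    "my_sem_seg", lambda: load_sem_seg("/path/to/gt", "/path/to/images")
)
MetadataCatalog.get("my_sem_seg").set(stuff_classes=["road", "sky"], ignore_label=255)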

detectron2.data.datasets.register_coco_instances(name, metadata, json_file, image_root)[source]

Register a dataset in COCO’s json annotation format for instance detection, instance segmentation and keypoint detection (i.e., Type 1 and 2 in http://cocodataset.org/#format-data; instances*.json and person_keypoints*.json in the dataset).

This is an example of how to register a new dataset. You can do something similar to this function, to register new datasets.

Parameters
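
For example, a hypothetical registration (paths are placeholders; metadata may be an empty dict):

from detectron2.data.datasets import register_coco_instances

register_coco_instances(
    "my_coco_train", {},
    json_file="/path/to/annotations/train.json",
    image_root="/path/to/train_images",
)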

detectron2.data.datasets.convert_to_coco_json(dataset_name, output_file, allow_cached=True)[source]

Converts the dataset into COCO format and saves it to a json file. dataset_name must be registered in DatasetCatalog and in detectron2’s standard format.

Parameters

detectron2.data.datasets.register_coco_panoptic(name, metadata, image_root, panoptic_root, panoptic_json, instances_json=None)[source]

Register a “standard” version of the COCO panoptic segmentation dataset named name. The dictionaries in this registered dataset follow detectron2’s standard format. Hence it’s called “standard”.

Parameters

detectron2.data.datasets.register_coco_panoptic_separated(name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json)[source]

Register a “separated” version of COCO panoptic segmentation dataset named name. The annotations in this registered dataset will contain both instance annotations and semantic annotations, each with its own contiguous ids. Hence it’s called “separated”.

It follows the setting used by the PanopticFPN paper:

  1. The instance annotations directly come from polygons in the COCO instances annotation task, rather than from the masks in the COCO panoptic annotations.
    The two formats have small differences: polygons in the instance annotations may have overlaps, while the mask annotations are produced by labeling the overlapped polygons with depth ordering.
  2. The semantic annotations are converted from panoptic annotations, where all “things” are assigned a semantic id of 0. All semantic categories will therefore have ids in contiguous range [1, #stuff_categories].

This function will also register a pure semantic segmentation dataset named name + '_stuffonly'.

Parameters

detectron2.data.datasets.load_lvis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None)[source]

Load a json file in LVIS’s annotation format.

Parameters

Returns

list[dict] – a list of dicts in Detectron2 standard format (see Using Custom Datasets).

Notes

  1. This function does not read the image files. The results do not have the “image” field.

detectron2.data.datasets.register_lvis_instances(name, metadata, json_file, image_root)[source]

Register a dataset in LVIS’s json annotation format for instance detection and segmentation.

Parameters

detectron2.data.datasets.get_lvis_instances_meta(dataset_name)[source]

Load LVIS metadata.

Parameters

dataset_name (str) – LVIS dataset name without the split name (e.g., “lvis_v0.5”).

Returns

dict – LVIS metadata with keys: thing_classes

detectron2.data.datasets.load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, …]])[source]

Load Pascal VOC detection annotations to Detectron2 format.

Parameters

detectron2.data.datasets.register_pascal_voc(name, dirname, split, year, class_names=('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'))[source]

detectron2.data.samplers module

class detectron2.data.samplers.TrainingSampler(*args, **kwds)[source]

Bases: torch.utils.data.Sampler

In training, we only care about the “infinite stream” of training data. So this sampler produces an infinite stream of indices and all workers cooperate to correctly shuffle the indices and sample different indices.

The sampler in each worker effectively produces indices[worker_id::num_workers], where indices is an infinite stream of indices consisting of shuffle(range(size)) + shuffle(range(size)) + … (if shuffle is True) or range(size) + range(size) + … (if shuffle is False).

Note that this sampler does not shard based on pytorch DataLoader worker id. A sampler passed to pytorch DataLoader is used only with map-style datasets and will not be executed inside workers. But if this sampler is used in a way that it gets executed inside a dataloader worker, then extra work needs to be done to shard its outputs based on worker id. This is required so that workers don’t produce identical data. ToIterableDataset implements this logic. This note is true for all samplers in detectron2.

__init__(size: int, shuffle: bool = True, seed: Optional[int] = None)[source]

Parameters

class detectron2.data.samplers.RandomSubsetTrainingSampler(*args, **kwds)[source]

Bases: detectron2.data.samplers.distributed_sampler.TrainingSampler

Similar to TrainingSampler, but only samples a random subset of indices. This is useful when you want to estimate the accuracy-vs-data-size curves by training the model with different values of subset_ratio.

__init__(size: int, subset_ratio: float, shuffle: bool = True, seed_shuffle: Optional[int] = None, seed_subset: Optional[int] = None)[source]

Parameters

class detectron2.data.samplers.InferenceSampler(*args, **kwds)[source]

Bases: torch.utils.data.Sampler

Produce indices for inference across all workers. Inference needs to run on the exact set of samples; therefore, when the total number of samples is not divisible by the number of workers, this sampler produces a different number of samples on different workers.

__init__(size: int)[source]

Parameters

size (int) – the total number of data of the underlying dataset to sample from

class detectron2.data.samplers.RepeatFactorTrainingSampler(*args, **kwds)[source]

Bases: torch.utils.data.Sampler

Similar to TrainingSampler, but a sample may appear more often than others based on its “repeat factor”. This is suitable for training on class-imbalanced datasets like LVIS.

__init__(repeat_factors, *, shuffle=True, seed=None)[source]

Parameters

static repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh)[source]

Compute (fractional) per-image repeat factors based on category frequency. The repeat factor for an image is a function of the frequency of the rarest category labeled in that image. The “frequency of category c” in [0, 1] is defined as the fraction of images in the training set (without repeats) in which category c appears. See LVIS: A Dataset for Large Vocabulary Instance Segmentation (>= v2) Appendix B.2.

Parameters

Returns

torch.Tensor – the i-th element is the repeat factor for the dataset image at index i.
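For illustration, a sketch of computing repeat factors and constructing the sampler; the dataset name and threshold are placeholders:

from detectron2.data import get_detection_dataset_dicts
from detectron2.data.samplers import RepeatFactorTrainingSampler

dataset_dicts = get_detection_dataset_dicts(["my_dataset"])  # any registered name
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
    dataset_dicts, repeat_thresh=0.001
)
sampler = RepeatFactorTrainingSampler(repeat_factors)
# The sampler can then be passed to build_detection_train_loader(..., sampler=sampler).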