A Method to Support ONNX Conversion · Issue #192 · IDEA-Research/detrex
Taking DN-DETR as an example:
1. Add a new file dn_detr_onnx.py; you only need to change the forward function.
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from detrex.layers import MLP, GenerateDNQueries, box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from detrex.utils.misc import inverse_sigmoid
from detectron2.modeling import detector_postprocess
from detectron2.structures import Boxes, ImageList, Instances
class DNDETR_ONNX(nn.Module):
"""Implement DN-DETR in `DN-DETR: Dynamic Anchor Boxes are Better Queries for DETR
<https://arxiv.org/abs/2201.12329>`_
Args:
backbone (nn.Module): Backbone module for feature extraction.
in_features (List[str]): Selected backbone output features for transformer module.
in_channels (int): Dimension of the last feature in `in_features`.
position_embedding (nn.Module): Position encoding layer for generating position embeddings.
transformer (nn.Module): Transformer module used for further processing features and input queries.
embed_dim (int): Hidden dimension for transformer module.
num_classes (int): Number of total categories.
num_queries (int): Number of proposal dynamic anchor boxes in Transformer
criterion (nn.Module): Criterion for calculating the total losses.
aux_loss (bool): Whether to calculate auxiliary loss in criterion. Default: True.
pixel_mean (List[float]): Pixel mean value for image normalization.
Default: [123.675, 116.280, 103.530].
pixel_std (List[float]): Pixel std value for image normalization.
Default: [58.395, 57.120, 57.375].
freeze_anchor_box_centers (bool): If True, freeze the center param ``(x, y)`` for
the initialized dynamic anchor boxes in format ``(x, y, w, h)``
and only train ``(w, h)``. Default: True.
select_box_nums_for_evaluation (int): Select the top-k confidence predicted boxes for inference.
Default: 300.
denoising_groups (int): Number of groups for noised ground truths. Default: 5.
label_noise_prob (float): The probability of the label being noised. Default: 0.2.
box_noise_scale (float): Scaling factor for box noising. Default: 0.4.
with_indicator (bool): If True, add indicator in denoising queries part and matching queries part.
Default: True.
device (str): Training device. Default: "cuda".
"""
def __init__(
self,
backbone: nn.Module,
in_features: List[str],
in_channels: int,
position_embedding: nn.Module,
transformer: nn.Module,
embed_dim: int,
num_classes: int,
num_queries: int,
criterion: nn.Module,
aux_loss: bool = True,
pixel_mean: List[float] = [123.675, 116.280, 103.530],
pixel_std: List[float] = [58.395, 57.120, 57.375],
freeze_anchor_box_centers: bool = True,
select_box_nums_for_evaluation: int = 300,
denoising_groups: int = 5,
label_noise_prob: float = 0.2,
box_noise_scale: float = 0.4,
with_indicator: bool = True,
device="cuda",
):
super(DNDETR_ONNX, self).__init__()
# define backbone and position embedding module
self.backbone = backbone
self.in_features = in_features
self.position_embedding = position_embedding
# project the backbone output feature
# into the required dim for transformer block
self.input_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
# generate denoising label/box queries
self.denoising_generator = GenerateDNQueries(
num_queries=num_queries,
num_classes=num_classes + 1,
label_embed_dim=embed_dim,
denoising_groups=denoising_groups,
label_noise_prob=label_noise_prob,
box_noise_scale=box_noise_scale,
with_indicator=with_indicator,
)
self.denoising_groups = denoising_groups
self.label_noise_prob = label_noise_prob
self.box_noise_scale = box_noise_scale
        # define learnable anchor boxes and transformer module
self.transformer = transformer
self.anchor_box_embed = nn.Embedding(num_queries, 4)
self.num_queries = num_queries
        # whether to freeze the initialized anchor box centers during training
self.freeze_anchor_box_centers = freeze_anchor_box_centers
# define classification head and box head
self.class_embed = nn.Linear(embed_dim, num_classes)
self.bbox_embed = MLP(input_dim=embed_dim, hidden_dim=embed_dim, output_dim=4, num_layers=3)
self.num_classes = num_classes
# predict offsets to update anchor boxes after each decoder layer
# with shared box embedding head
# this is a hack implementation which will be modified in the future
self.transformer.decoder.bbox_embed = self.bbox_embed
        # whether to calculate auxiliary loss in criterion
self.aux_loss = aux_loss
self.criterion = criterion
# normalizer for input raw images
self.device = device
pixel_mean = torch.Tensor(pixel_mean).to(self.device).view(3, 1, 1)
pixel_std = torch.Tensor(pixel_std).to(self.device).view(3, 1, 1)
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
# The total nums of selected boxes for evaluation
self.select_box_nums_for_evaluation = select_box_nums_for_evaluation
self.init_weights()
def init_weights(self):
"""Initialize weights for DN-DETR"""
if self.freeze_anchor_box_centers:
self.anchor_box_embed.weight.data[:, :2].uniform_(0, 1)
self.anchor_box_embed.weight.data[:, :2] = inverse_sigmoid(
self.anchor_box_embed.weight.data[:, :2]
)
self.anchor_box_embed.weight.data[:, :2].requires_grad = False
# init prior_prob setting for focal loss
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
self.class_embed.bias.data = torch.ones(self.num_classes) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
    def forward(self, batched_inputs):
        """Forward function of `DN-DETR`, modified for ONNX export.

        Unlike the original DN-DETR forward, which expects a list of instance
        dicts, this version takes a single unnormalized image tensor so that
        the graph can be traced by ``torch.onnx.export``.

        Args:
            batched_inputs (torch.Tensor): The unnormalized image tensor of
                shape ``(1, 3, H, W)``.

        Returns:
            tuple: A tuple of ``(out_cls, out_box)``:

            - out_cls (torch.Tensor): Sigmoid classification scores for all
              queries with shape ``(num_queries, num_classes)``.
            - out_box (torch.Tensor): Predicted boxes for all queries in
              ``(x1, y1, x2, y2)`` format, normalized in [0, 1] relative to
              the input size, with shape ``(num_queries, 4)``.
        """
images = self.normalizer(batched_inputs)
# todo: remove this part, as mask is not needed for batch=1.
batch_size, _, H, W = images.shape
img_masks = images.new_zeros(batch_size, H, W)
features = self.backbone(images)[self.in_features[-1]]
features = self.input_proj(features)
img_masks = F.interpolate(img_masks[None], size=features.shape[-2:]).to(torch.bool)[0]
# img_masks = F.interpolate(img_masks[None], scale_factor=(1/32, 1/32)).to(torch.bool)[0]
pos_embed = self.position_embedding(img_masks)
targets = None
        # for vanilla DN-DETR, label queries in the matching part are encoded as "no object" (the last class)
        # in the label encoder.
matching_label_query = self.denoising_generator.label_encoder(
torch.tensor(self.num_classes).to(self.device)
).repeat(self.num_queries, 1)
indicator_for_matching_part = torch.zeros([self.num_queries, 1]).to(self.device)
        matching_label_query = torch.cat(
            [matching_label_query, indicator_for_matching_part], 1
        ).repeat(batch_size, 1, 1)  # (num_q, embed_dim-1) + (num_q, 1) -> (num_q, embed_dim) -> (bs, num_q, embed_dim)
        matching_box_query = self.anchor_box_embed.weight.repeat(batch_size, 1, 1)  # (bs, num_q, 4)
if targets is None:
input_label_query = matching_label_query.transpose(0, 1) # (num_queries, bs, embed_dim)
input_box_query = matching_box_query.transpose(0, 1) # (num_queries, bs, 4)
attn_mask = None
denoising_groups = self.denoising_groups
max_gt_num_per_image = 0
hidden_states, reference_boxes = self.transformer(
features,
img_masks,
input_box_query,
pos_embed,
target=input_label_query,
attn_mask=[attn_mask, None], # None mask for cross attention
)
        # Calculate output coordinates and classes.
        reference_boxes = inverse_sigmoid(reference_boxes[-1])  # (bs, num_q, 4)
        anchor_box_offsets = self.bbox_embed(hidden_states[-1])  # (bs, num_q, embed_dim) -> (bs, num_q, 4)
        outputs_coord = (reference_boxes + anchor_box_offsets).sigmoid()
        outputs_class = self.class_embed(hidden_states[-1])  # (bs, num_q, embed_dim) -> (bs, num_q, num_classes)
        # No denoising post-processing is needed, since only the matching part remains.
        # bs = 1 for ONNX converting, so take batch index 0 for simplification.
        box_cls = outputs_class[0]  # (1, num_q, num_classes) -> (num_q, num_classes)
        box_pred = outputs_coord[0]  # (1, num_q, 4) -> (num_q, 4)
        out_cls = box_cls.sigmoid()
        out_box = box_cxcywh_to_xyxy(box_pred)  # convert to (x1, y1, x2, y2) format, (num_q, 4)
        return out_cls, out_box
###### The following content is omitted #######
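The modified forward returns raw per-query scores and normalized (x1, y1, x2, y2) boxes, so the usual detrex inference post-processing has to be reproduced outside the exported graph. Below is a minimal sketch of that step, written for illustration and not part of the original patch; the postprocess name, the topk=300 default mirroring select_box_nums_for_evaluation, and the flattened top-k selection mimicking DNDETR's evaluation are my assumptions.

import torch

def postprocess(out_cls, out_box, img_h, img_w, topk=300):
    # out_cls: (num_queries, num_classes) sigmoid scores
    # out_box: (num_queries, 4) boxes in normalized (x1, y1, x2, y2) format
    num_classes = out_cls.shape[1]
    # flatten scores over queries x classes and take the top-k overall,
    # mirroring the top-k selection detrex uses during evaluation
    scores, flat_idx = out_cls.flatten().topk(topk)
    labels = flat_idx % num_classes            # recover class indices
    boxes = out_box[flat_idx // num_classes]   # gather the matching boxes
    # scale normalized coordinates back to absolute pixels
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=boxes.dtype)
    return scores, labels, boxes * scale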
2. Add a new file dn_detr_r50_onnx.py and change DNDETR to DNDETR_ONNX.
import torch.nn as nn
from detrex.layers import PositionEmbeddingSine
from detrex.modeling import HungarianMatcher
from detectron2.modeling.backbone import ResNet, BasicStem
from detectron2.config import LazyCall as L
from projects.dn_detr.modeling import (
DNDETR_ONNX,
DNDetrTransformerEncoder,
DNDetrTransformerDecoder,
DNDetrTransformer,
DNCriterion,
)
model = L(DNDETR_ONNX)(
backbone=L(ResNet)(
stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
stages=L(ResNet.make_default_stages)(
depth=50,
stride_in_1x1=False,
norm="FrozenBN",
),
out_features=["res5"],
......
###### The following content is omitted #######
3. Add a new file dn_detr_r50_50ep_onnx.py.
from detrex.config import get_config
from .models.dn_detr_r50_onnx import model
###### The following content is omitted #######
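Before running the export, it can help to sanity-check that the new config really builds the ONNX variant. A quick check, assuming you run it from the detrex repo root (this snippet is mine, not part of the original steps):

from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("projects/dn_detr/configs/dn_detr_r50_50ep_onnx.py")
model = instantiate(cfg.model)  # builds the model from the lazy config
print(type(model).__name__)  # expected: DNDETR_ONNX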
4. Add a new file torch2onnx.py in the tools folder.
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Training script using the new "LazyConfig" python config files.
This script reads a given python config file and runs the training or evaluation.
It can be used to train any models or dataset as long as they can be
instantiated by the recursive construction defined in the given config file.
Besides lazy construction of models, dataloaders, etc., this script expects a
few common configuration parameters currently defined in "configs/common/train.py".
To add more complicated training logic, you can easily add other configs
in the config file and implement a new train_net.py to handle them.
"""
import logging
import os
import sys
import time
import torch
import onnx
import numpy as np
import io
from torch.nn.parallel import DataParallel, DistributedDataParallel
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
SimpleTrainer,
default_argument_parser,
default_setup,
default_writers,
hooks,
launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
logger = logging.getLogger("detrex")
class Trainer(SimpleTrainer):
"""
    We've combined SimpleTrainer and AMPTrainer together.
"""
def __init__(
self,
model,
dataloader,
optimizer,
amp=False,
clip_grad_params=None,
grad_scaler=None,
):
super().__init__(model=model, data_loader=dataloader, optimizer=optimizer)
unsupported = "AMPTrainer does not support single-process multi-device training!"
if isinstance(model, DistributedDataParallel):
assert not (model.device_ids and len(model.device_ids) > 1), unsupported
assert not isinstance(model, DataParallel), unsupported
if amp:
if grad_scaler is None:
from torch.cuda.amp import GradScaler
grad_scaler = GradScaler()
self.grad_scaler = grad_scaler
# set True to use amp training
self.amp = amp
# gradient clip hyper-params
self.clip_grad_params = clip_grad_params
def run_step(self):
"""
Implement the standard training logic described above.
"""
assert self.model.training, "[Trainer] model was changed to eval mode!"
assert torch.cuda.is_available(), "[Trainer] CUDA is required for AMP training!"
from torch.cuda.amp import autocast
start = time.perf_counter()
"""
If you want to do something with the data, you can wrap the dataloader.
"""
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
"""
If you want to do something with the losses, you can wrap the model.
"""
        with autocast(enabled=self.amp):
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())
"""
If you need to accumulate gradients or do something similar, you can
wrap the optimizer with your custom `zero_grad()` method.
"""
self.optimizer.zero_grad()
if self.amp:
self.grad_scaler.scale(losses).backward()
if self.clip_grad_params is not None:
self.grad_scaler.unscale_(self.optimizer)
self.clip_grads(self.model.parameters())
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
losses.backward()
if self.clip_grad_params is not None:
self.clip_grads(self.model.parameters())
self.optimizer.step()
self._write_metrics(loss_dict, data_time)
def clip_grads(self, params):
params = list(filter(lambda p: p.requires_grad and p.grad is not None, params))
if len(params) > 0:
return torch.nn.utils.clip_grad_norm_(
parameters=params,
**self.clip_grad_params,
)
def do_test(cfg, model):
if "evaluator" in cfg.dataloader:
ret = inference_on_dataset(
model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
)
print_csv_format(ret)
return ret
def do_train(args, cfg):
"""
Args:
cfg: an object with the following attributes:
model: instantiate to a module
dataloader.{train,test}: instantiate to dataloaders
dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
lr_multiplier: instantiate to a fvcore scheduler
train: other misc config defined in `configs/common/train.py`, including:
output_dir (str)
init_checkpoint (str)
amp.enabled (bool)
max_iter (int)
eval_period, log_period (int)
device (str)
checkpointer (dict)
ddp (dict)
"""
model = instantiate(cfg.model)
logger = logging.getLogger("detectron2")
logger.info("Model:\n{}".format(model))
model.to(cfg.train.device)
cfg.optimizer.params.model = model
optim = instantiate(cfg.optimizer)
train_loader = instantiate(cfg.dataloader.train)
model = create_ddp_model(model, **cfg.train.ddp)
trainer = Trainer(
model=model,
dataloader=train_loader,
optimizer=optim,
amp=cfg.train.amp.enabled,
clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None,
)
checkpointer = DetectionCheckpointer(
model,
cfg.train.output_dir,
trainer=trainer,
)
trainer.register_hooks(
[
hooks.IterationTimer(),
hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
if comm.is_main_process()
else None,
hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
hooks.PeriodicWriter(
default_writers(cfg.train.output_dir, cfg.train.max_iter),
period=cfg.train.log_period,
)
if comm.is_main_process()
else None,
]
)
checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
if args.resume and checkpointer.has_checkpoint():
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration
start_iter = trainer.iter + 1
else:
start_iter = 0
trainer.train(start_iter, cfg.train.max_iter)
def convert_onnx(net, output, opset=11, simplify=False):
    assert isinstance(net, torch.nn.Module)
    net.eval()
    # dummy input for tracing; the export below uses a static 1x3x640x640 shape
    new_img = torch.rand(1, 3, 640, 640)
# dynamic input shape
# torch.onnx.export(net, new_img, output, input_names=["inputs"],
# output_names=["pred_logits", "pred_boxes"],
# do_constant_folding=True, opset_version=opset,
# # dynamic_axes={"inputs": {2:'H',3:'W'}, "pred_logits": {2:'H',3:'W'}, "pred_boxes": {2:'H',3:'W'}},)
# dynamic_axes={"inputs": {2:'H',3:'W'},},)
torch.onnx.export(net, new_img, output, input_names=["inputs"],
output_names=["pred_logits", "pred_boxes"], opset_version=opset)
def main(args):
cfg = LazyConfig.load(args.config_file)
cfg = LazyConfig.apply_overrides(cfg, args.opts)
default_setup(cfg, args)
if args.eval_only:
model = instantiate(cfg.model)
# model.to(cfg.train.device)
# model = create_ddp_model(model)
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
# print(do_test(cfg, model))
# convert_onnx(model, './test_dab_detr.onnx', opset=12)
# convert_onnx(model, './test_detr_r18_ds.onnx', opset=12)
        save_onnx_path = os.path.join(os.path.dirname(cfg.train.init_checkpoint), 'model_final.onnx')
print(save_onnx_path)
convert_onnx(model, save_onnx_path, opset=12)
else:
do_train(args, cfg)
if __name__ == "__main__":
args = default_argument_parser().parse_args()
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)
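Note the commented-out export call in convert_onnx: if you need variable input resolutions instead of the fixed 640x640, you can pass dynamic_axes to mark H and W as free dimensions. A sketch of that variant (untested here; whether the traced DN-DETR graph really supports dynamic shapes depends on the ops it hits, so treat it as a starting point rather than the verified path):

torch.onnx.export(
    net, new_img, output,
    input_names=["inputs"],
    output_names=["pred_logits", "pred_boxes"],
    do_constant_folding=True,
    opset_version=opset,
    dynamic_axes={"inputs": {2: "H", 3: "W"}},  # H, W become free dims
)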
5. Run with the command:
python tools/torch2onnx.py --config-file projects/dn_detr/configs/dn_detr_r50_50ep_onnx.py --eval-only train.init_checkpoint=<YOUR_MODEL_PATH>
It worked for me: the torch results and the ONNX results match.
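For anyone who wants to reproduce the parity check, this is roughly how I'd compare the two backends with onnxruntime. The tolerances, the file name, and the assumption that model is the instantiated DNDETR_ONNX with weights loaded (on CPU, as in the export script) are mine:

import numpy as np
import onnxruntime as ort
import torch

model.eval()
dummy = torch.rand(1, 3, 640, 640)
with torch.no_grad():
    torch_cls, torch_box = model(dummy)

sess = ort.InferenceSession("model_final.onnx", providers=["CPUExecutionProvider"])
onnx_cls, onnx_box = sess.run(None, {"inputs": dummy.numpy()})

# the same input should produce matching outputs up to float tolerance
np.testing.assert_allclose(torch_cls.cpu().numpy(), onnx_cls, rtol=1e-3, atol=1e-5)
np.testing.assert_allclose(torch_box.cpu().numpy(), onnx_box, rtol=1e-3, atol=1e-5)
print("torch and onnx outputs match")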