BaseModel - mmengine 0.10.7 documentation
class mmengine.model.BaseModel(data_preprocessor=None, init_cfg=None)
Base class for all algorithmic models.
BaseModel implements the basic functions of an algorithmic model, such as weight initialization, batch input preprocessing (see BaseDataPreprocessor for more information), loss parsing, and model parameter updates.
Subclasses that inherit from BaseModel only need to implement the forward method, which contains the logic to calculate losses and predictions; the model can then be trained in the runner.
Examples
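import torch
from torch import nn
from mmengine.model import BaseModel
from mmengine.registry import MODELS


@MODELS.register_module()
class ToyModel(BaseModel):

    def __init__(self):
        super().__init__()
        # Assumes 3x32x32 inputs (implied by fc1's 16 * 5 * 5):
        # 32 -> conv1 -> 28 -> pool1 -> 14 -> conv2 -> 10 -> pool2 -> 5
        self.backbone = nn.Sequential()
        self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
        self.backbone.add_module('pool1', nn.MaxPool2d(2, 2))
        self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
        self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
        # Flatten the 16x5x5 feature map before the fully connected layers.
        self.backbone.add_module('flatten', nn.Flatten())
        self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
        self.backbone.add_module('fc2', nn.Linear(120, 84))
        self.backbone.add_module('fc3', nn.Linear(84, 10))
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if mode == 'tensor':
            # Raw logits for custom use.
            return self.backbone(inputs)
        elif mode == 'predict':
            feats = self.backbone(inputs)
            predictions = torch.argmax(feats, 1)
            return predictions
        elif mode == 'loss':
            # data_samples is a list of label tensors; stack them into a batch.
            feats = self.backbone(inputs)
            loss = self.criterion(feats, torch.stack(data_samples))
            return dict(loss=loss)
        raise ValueError(f'Invalid mode "{mode}"')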
Parameters:
- data_preprocessor (dict, optional) – The preprocessing config of BaseDataPreprocessor.
- init_cfg (dict, optional) – The weight initialization config for BaseModule.
data_preprocessor
Used for preprocessing data sampled by the dataloader into the format accepted by forward().
Type:
BaseDataPreprocessor
init_cfg
Initialization config dict.
Type:
dict, optional
cpu(*args, **kwargs)
Overrides this method to additionally call BaseDataPreprocessor.cpu().
Returns:
The model itself.
Return type:
nn.Module
cuda(device=None)
Overrides this method to additionally call BaseDataPreprocessor.cuda().
Returns:
The model itself.
Return type:
nn.Module
Parameters:
device (int | str | device | None) – If specified, the device to move the model to.
abstract forward(inputs, data_samples=None, mode='tensor')
Returns losses or predictions for the training, validation, testing, and simple inference processes.
forward is an abstract method of BaseModel; subclasses must implement it.
It accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step, and BaseModel.test_step.
During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
Parameters:
- inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
- data_samples (list, optional) – Data samples collated by data_preprocessor.
- mode (str) – Must be one of loss, predict, and tensor:
  - loss: Called by train_step; returns a loss dict used for logging.
  - predict: Called by val_step and test_step; returns a list of results used for computing metrics.
  - tensor: Called by custom code to get Tensor-type results.
Returns:
- If mode == loss, return a dict of loss tensors used for backward and logging.
- If mode == predict, return a list of inference results.
- If mode == tensor, return a tensor, a tuple of tensors, or a dict of tensors for custom use.
Return type:
dict or list
mlu(device=None)
Overrides this method to additionally call BaseDataPreprocessor.mlu().
Returns:
The model itself.
Return type:
nn.Module
Parameters:
device (int | str | device | None) – If specified, the device to move the model to.
musa(device=None)
Overrides this method to additionally call BaseDataPreprocessor.musa().
Returns:
The model itself.
Return type:
nn.Module
Parameters:
device (int | str | device | None) – If specified, the device to move the model to.
npu(device=None)
Overrides this method to additionally call BaseDataPreprocessor.npu().
Returns:
The model itself.
Return type:
nn.Module
Parameters:
device (int | str | device | None) – If specified, the device to move the model to.
Note
This generation of NPU (Ascend910) does not support using multiple cards in a single process, so the device index here must be consistent with the default device.
parse_losses(losses)
Parses the raw outputs (losses) of the network.
Parameters:
losses (dict) – Raw output of the network, which usually contains losses and other necessary information.
Returns:
A tuple of two elements. The first is the loss tensor passed to optim_wrapper, which may be a weighted sum of all losses; the second is log_vars, which will be sent to the logger.
Return type:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]
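The loss aggregation described above can be sketched in plain PyTorch. This is a simplified illustration, not mmengine's source: parse_losses_sketch is a hypothetical name, and it assumes the convention that a tensor value is reduced with its mean, each tensor in a list is averaged and the results summed, and only keys containing 'loss' contribute to the total passed to optim_wrapper. Every entry, including non-loss metrics, ends up in log_vars.

```python
from collections import OrderedDict

import torch


def parse_losses_sketch(losses):
    """Hypothetical helper: reduce raw network outputs to (total_loss, log_vars).

    Assumed convention (illustrative only):
    - a tensor value is reduced with .mean();
    - a list of tensors is reduced by averaging each tensor and summing;
    - only keys containing 'loss' are summed into the total.
    """
    log_vars = []
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            log_vars.append((name, value.mean()))
        elif isinstance(value, list) and all(
                isinstance(v, torch.Tensor) for v in value):
            log_vars.append((name, sum(v.mean() for v in value)))
        else:
            raise TypeError(f'{name} must be a tensor or a list of tensors')

    # Total loss used for backward: every entry whose key contains 'loss'.
    total = sum(value for name, value in log_vars if 'loss' in name)
    # Put the total first so it is logged alongside the individual terms.
    log_vars = OrderedDict([('loss', total)] + log_vars)
    return total, log_vars


losses = {
    'loss_cls': torch.tensor([1.0, 3.0]),                        # mean 2.0
    'loss_aux': [torch.tensor([2.0]), torch.tensor([4.0, 6.0])],  # 2.0 + 5.0
    'acc': torch.tensor(0.5),                                    # logged only
}
total, log_vars = parse_losses_sketch(losses)
print(float(total))    # 9.0
print(list(log_vars))  # ['loss', 'loss_cls', 'loss_aux', 'acc']
```

Keeping metrics such as accuracy in the same dict lets the network report them for logging without affecting the gradient, because only the 'loss' keys are summed.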
test_step(data)
BaseModel implements test_step the same way as val_step.
Parameters:
data (dict or tuple or list) – Data sampled from dataset.
Returns:
The predictions of given data.
Return type:
list
to(*args, **kwargs)
Overrides this method to additionally call BaseDataPreprocessor.to().
Returns:
The model itself.
Return type:
nn.Module
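The reason for these device overrides can be sketched in plain PyTorch: nn.Module.to() moves the parameters and buffers of every submodule through _apply, but it never calls a submodule's own to(). If the preprocessor keeps its target device in a plain attribute (an assumption here; DeviceTrackingPreprocessor, PlainModel, SyncingModel, and the device attribute are all hypothetical names), that attribute only stays in sync when the parent forwards the call explicitly.

```python
import torch
from torch import nn


class DeviceTrackingPreprocessor(nn.Module):
    """Hypothetical preprocessor that records where inputs should be moved."""

    def __init__(self):
        super().__init__()
        self.device = None  # plain attribute: not a parameter or buffer

    def to(self, device):
        self.device = torch.device(device)
        return super().to(device)


class PlainModel(nn.Module):
    """Relies on nn.Module.to() alone."""

    def __init__(self):
        super().__init__()
        self.data_preprocessor = DeviceTrackingPreprocessor()
        self.head = nn.Linear(2, 2)


class SyncingModel(PlainModel):
    """Also forwards to() to the preprocessor, as described above."""

    def to(self, device):
        self.data_preprocessor.to(device)
        return super().to(device)


plain = PlainModel().to('cpu')
syncing = SyncingModel().to('cpu')
print(plain.data_preprocessor.device)    # None: the child's to() never ran
print(syncing.data_preprocessor.device)  # cpu
```

The same reasoning applies to cpu(), cuda(), and the other device methods: each one walks the module tree via _apply, so a submodule's overridden method only runs if the parent calls it.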
train_step(data, optim_wrapper)
Implements the default model training process including preprocessing, model forward propagation, loss calculation, optimization, and back-propagation.
During non-distributed training, if subclasses do not override train_step(), EpochBasedTrainLoop or IterBasedTrainLoop will call this method to update model parameters. The default parameter update process is as follows:
- Calls self.data_preprocessor(data, training=True) to collect batch_inputs and the corresponding data_samples (labels).
- Calls self(batch_inputs, data_samples, mode='loss') to get the raw losses.
- Calls self.parse_losses to get the parsed_losses tensor used for backward and a dict of loss tensors used to log messages.
- Calls optim_wrapper.update_params(parsed_losses) to update the model.
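The four steps above can be sketched with plain PyTorch. This is an illustration under stated assumptions, not mmengine's implementation: SimpleOptimWrapper is a hypothetical stand-in whose update_params runs backward, step, and zero_grad; TinyModel's data_preprocessor is a plain method that stacks per-sample tensors; and loss parsing is reduced to summing the keys that contain 'loss'.

```python
import torch
from torch import nn


class SimpleOptimWrapper:
    """Hypothetical stand-in for OptimWrapper.update_params()."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def update_params(self, loss):
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()


class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 3)
        self.criterion = nn.CrossEntropyLoss()

    def data_preprocessor(self, data, training=False):
        # Hypothetical preprocessing: stack per-sample tensors into a batch.
        return dict(inputs=torch.stack(data['inputs']),
                    data_samples=torch.stack(data['data_samples']))

    def forward(self, inputs, data_samples=None, mode='tensor'):
        feats = self.linear(inputs)
        if mode == 'loss':
            return dict(loss_cls=self.criterion(feats, data_samples))
        return feats

    def parse_losses(self, losses):
        # Simplified: total = sum of entries whose key contains 'loss'.
        total = sum(v for k, v in losses.items() if 'loss' in k)
        log_vars = {'loss': total.detach()}
        log_vars.update({k: v.detach() for k, v in losses.items()})
        return total, log_vars

    def train_step(self, data, optim_wrapper):
        data = self.data_preprocessor(data, training=True)  # 1. preprocess
        losses = self(**data, mode='loss')                   # 2. raw losses
        parsed_loss, log_vars = self.parse_losses(losses)    # 3. parse
        optim_wrapper.update_params(parsed_loss)             # 4. update
        return log_vars


torch.manual_seed(0)
model = TinyModel()
optim_wrapper = SimpleOptimWrapper(torch.optim.SGD(model.parameters(), lr=0.1))
data = dict(inputs=[torch.randn(4) for _ in range(8)],
            data_samples=[torch.tensor(i % 3) for i in range(8)])
before = model.linear.weight.detach().clone()
log_vars = model.train_step(data, optim_wrapper)
print(sorted(log_vars))  # ['loss', 'loss_cls']
```

Keeping the backward pass and optimizer step behind a single update_params call is what lets the training loop stay unaware of details such as gradient accumulation or mixed precision, which a real optimizer wrapper can handle internally.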
Parameters:
- data (dict or tuple or list) – Data sampled from dataset.
- optim_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
Returns:
A dict of tensors for logging.
Return type:
Dict[str, torch.Tensor]
val_step(data)
Gets the predictions of the given data.
Calls self.data_preprocessor(data, False) and self(inputs, data_samples, mode='predict') in order, and returns the predictions, which will be passed to the evaluator.
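A minimal sketch of the val_step flow, under assumptions: the preprocessor is a pass-through, predictions are returned as a Python list of class indices, and the hypothetical helper _run_forward_sketch unpacks a dict as keyword arguments and a tuple or list as positional arguments, which is one plausible reading of the "dict or tuple or list" data types documented here. test_step simply delegates to val_step, mirroring the statement that the two behave the same.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 3)

    def data_preprocessor(self, data, training=False):
        # Hypothetical pass-through; a real preprocessor would also collate,
        # normalize, and move tensors to the model's device.
        return data

    def forward(self, inputs, data_samples=None, mode='tensor'):
        feats = self.linear(inputs)
        if mode == 'predict':
            return feats.argmax(dim=1).tolist()  # list of class indices
        return feats

    def _run_forward_sketch(self, data, mode):
        # Assumed convention: dict -> keyword args, list/tuple -> positional.
        if isinstance(data, dict):
            return self(**data, mode=mode)
        if isinstance(data, (list, tuple)):
            return self(*data, mode=mode)
        raise TypeError(f'Unsupported data type: {type(data)}')

    def val_step(self, data):
        data = self.data_preprocessor(data, False)
        return self._run_forward_sketch(data, mode='predict')

    def test_step(self, data):
        # Same behavior as val_step.
        return self.val_step(data)


torch.manual_seed(0)
model = TinyClassifier().eval()
x = torch.randn(5, 4)
with torch.no_grad():  # gradients are not needed for evaluation
    from_dict = model.val_step(dict(inputs=x))
    from_tuple = model.val_step((x,))
    from_test = model.test_step([x, None])
print(from_dict == from_tuple == from_test)  # True
```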
Parameters:
data (dict or tuple or list) – Data sampled from dataset.
Returns:
The predictions of given data.
Return type:
list