ray.train.trainer.BaseTrainer — Ray 2.45.0
class ray.train.trainer.BaseTrainer(*args, **kwargs)
Bases: ABC
Defines interface for distributed training on Ray.
Note: The base BaseTrainer class cannot be instantiated directly. Only one of its subclasses can be used.
Note to developers: If a new trainer is added, please update air/_internal/usage.py.
How does a trainer work?
- First, initialize the Trainer. The initialization runs locally, so heavyweight setup should not be done in __init__.
- Then, when you call trainer.fit(), the Trainer is serialized and copied to a remote Ray actor. The following methods are then called in sequence on the remote actor:
  - trainer.setup(): Any heavyweight Trainer setup should be specified here.
  - trainer.training_loop(): Executes the main training logic.
- Calling trainer.fit() will return a ray.train.Result object where you can access metrics from your training run, as well as any checkpoints that may have been saved.
How do I create a new Trainer?
Subclass ray.train.trainer.BaseTrainer, override the training_loop method, and optionally override setup.
import torch

from ray.train.trainer import BaseTrainer
from ray import train, tune


class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method,
        # e.g. the datasets passed to the Trainer constructor.
        dataset = self.datasets["train"]

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            torch_ds = dataset.iter_torch_batches(
                dtypes=torch.float, batch_size=2
            )
            for batch in torch_ds:
                X = torch.unsqueeze(batch["x"], 1)
                y = torch.unsqueeze(batch["y"], 1)
                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)
                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches
            # Use Tune functions to report intermediate results.
            train.report({"loss": loss, "epoch": epoch_idx})
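The training_loop above only reports metrics. To have checkpoints show up in the returned Result, a checkpoint can be attached to the report call. A minimal sketch, assuming the public ray.train.Checkpoint API and that checkpoint reporting from a custom training_loop behaves as in other Ray Train loops; the file name model.pt is illustrative. The train.report call at the end of the epoch loop could be replaced with:

import os
import tempfile

from ray.train import Checkpoint

# Inside training_loop, after computing the epoch loss:
with tempfile.TemporaryDirectory() as tmpdir:
    # Save the model weights to a temporary directory (illustrative file name).
    torch.save(self.model.state_dict(), os.path.join(tmpdir, "model.pt"))
    # Report the metrics together with a checkpoint built from that directory.
    train.report(
        {"loss": loss, "epoch": epoch_idx},
        checkpoint=Checkpoint.from_directory(tmpdir),
    )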
Initialize the Trainer, and call Trainer.fit():
import ray

train_dataset = ray.data.from_items(
    [{"x": i, "y": i} for i in range(10)])
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()
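Once fit() returns, the Result object carries the reported metrics and any saved checkpoints. A brief sketch of how it might be inspected, assuming the standard ray.train.Result attributes:

# Last reported metrics, e.g. {"loss": ..., "epoch": 9}.
print(result.metrics)
# Latest reported checkpoint, or None if no checkpoints were saved.
print(result.checkpoint)
# The error raised during training, if the run failed.
print(result.error)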
Parameters:
- scaling_config – Configuration for how to scale training.
- run_config – Configuration for the execution of the training run.
- datasets – Any Datasets to use for training. Use the key “train” to denote which dataset is the training dataset.
- metadata – Dict that should be made available via train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
- resume_from_checkpoint – A checkpoint to resume training from.
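As an illustration of how these parameters fit together, the MyPytorchTrainer from above could be constructed with explicit scaling, run, and metadata settings. This is a sketch, assuming the standard ray.train.ScalingConfig and ray.train.RunConfig classes; the experiment name and metadata contents are arbitrary:

from ray.train import RunConfig, ScalingConfig

my_trainer = MyPytorchTrainer(
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(name="my_pytorch_experiment"),
    # Retrievable later via train.get_context().get_metadata().
    metadata={"preprocessing_version": 1},
)
result = my_trainer.fit()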
DeveloperAPI: This API may change across minor Ray releases.
Methods
| Method | Description |
|---|---|
| as_trainable | Converts self to a tune.Trainable class. |
| can_restore | Checks whether a given directory contains a restorable Train experiment. |
| fit | Runs training. |
| preprocess_datasets | Deprecated. |
| restore | Restores a Train experiment from a previously interrupted/failed run. |
| setup | Called during fit() to perform initial setup on the Trainer. |
| training_loop | Loop called by fit() to run training and report results to Tune. |
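For example, can_restore and restore could be combined to continue an interrupted run. A hedged sketch, assuming the experiment was written to the path shown (the path is illustrative) and that the datasets, which are not part of the saved experiment state, are passed in again:

experiment_path = "/tmp/ray_results/my_pytorch_experiment"  # illustrative path

if MyPytorchTrainer.can_restore(experiment_path):
    # Re-create the trainer from the saved experiment state and resume training.
    restored_trainer = MyPytorchTrainer.restore(
        experiment_path, datasets={"train": train_dataset}
    )
    result = restored_trainer.fit()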