ray.train.trainer.BaseTrainer - Ray 2.45.0

class ray.train.trainer.BaseTrainer(*args, **kwargs)

Bases: ABC

Defines interface for distributed training on Ray.

Note: The base BaseTrainer class cannot be instantiated directly. Only one of its subclasses can be used.
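The restriction above is consistent with the class listing ABC as its base. The following is a minimal sketch in plain Python, not Ray's code: AbstractTrainerSketch and NoOpTrainer are hypothetical stand-ins, and the sketch assumes the restriction comes from an abstract training_loop method handled by Python's standard abc module.

```python
from abc import ABC, abstractmethod


class AbstractTrainerSketch(ABC):
    """Hypothetical stand-in for an abstract trainer base class (not Ray's code)."""

    @abstractmethod
    def training_loop(self) -> None:
        """Subclasses must provide the training logic."""


class NoOpTrainer(AbstractTrainerSketch):
    """Hypothetical concrete subclass that overrides the abstract method."""

    def training_loop(self) -> None:
        pass  # a real subclass would train and report metrics here


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if Python raises TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

Under these assumptions, can_instantiate(AbstractTrainerSketch) is False because the abstract method is not overridden, while can_instantiate(NoOpTrainer) is True.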

Note to developers: If a new trainer is added, please update air/_internal/usage.py.

How does a trainer work?

How do I create a new Trainer?

Subclass ray.train.trainer.BaseTrainer, override the training_loop method, and optionally override setup.

import torch

from ray import train
from ray.train.trainer import BaseTrainer


class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        # Called during fit() to perform initial setup.
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.datasets["train"] is the dataset passed to the Trainer.
        dataset = self.datasets["train"]

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            # Create a fresh batch iterator for each epoch.
            torch_ds = dataset.iter_torch_batches(
                dtypes=torch.float, batch_size=2
            )
            for batch in torch_ds:
                X = torch.unsqueeze(batch["x"], 1)
                y = torch.unsqueeze(batch["y"], 1)
                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            # Average the loss over the batches in this epoch.
            loss /= num_batches

            # Report intermediate results.
            train.report({"loss": loss, "epoch": epoch_idx})

Initialize the Trainer and call Trainer.fit():

import ray

train_dataset = ray.data.from_items(
    [{"x": i, "y": i} for i in range(10)]
)
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()

Parameters:

DeveloperAPI: This API may change across minor Ray releases.

Methods

as_trainable          Converts self to a tune.Trainable class.
can_restore           Checks whether a given directory contains a restorable Train experiment.
fit                   Runs training.
preprocess_datasets   Deprecated.
restore               Restores a Train experiment from a previously interrupted/failed run.
setup                 Called during fit() to perform initial setup on the Trainer.
training_loop         Loop called by fit() to run training and report results to Tune.
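The method summaries say that setup is called during fit() and that training_loop is the loop fit() runs. The sketch below illustrates that relationship in plain Python. It is not Ray's implementation: LifecycleTrainerSketch, CountingTrainer, and the report helper are hypothetical, everything runs in-process, and the ordering (setup first, then training_loop) is an assumption based on setup being described as initial setup.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class LifecycleTrainerSketch(ABC):
    """Hypothetical stand-in that only illustrates the call order (not Ray's code)."""

    def __init__(self) -> None:
        # Metrics recorded by report(), in the order they were reported.
        self.reported: List[Dict[str, Any]] = []

    def setup(self) -> None:
        """Optional hook; the default implementation does nothing."""

    @abstractmethod
    def training_loop(self) -> None:
        """Required hook; subclasses put their training logic here."""

    def report(self, metrics: Dict[str, Any]) -> None:
        # Hypothetical helper standing in for reporting intermediate results.
        self.reported.append(dict(metrics))

    def fit(self) -> Dict[str, Any]:
        # Assumed order: initial setup first, then the training loop.
        self.setup()
        self.training_loop()
        # Return the last reported metrics, or an empty dict if none exist.
        return self.reported[-1] if self.reported else {}


class CountingTrainer(LifecycleTrainerSketch):
    def setup(self) -> None:
        self.weight = 0.0

    def training_loop(self) -> None:
        for epoch_idx in range(3):
            self.weight += 1.0
            self.report({"weight": self.weight, "epoch": epoch_idx})


trainer = CountingTrainer()
final_metrics = trainer.fit()
```

Here final_metrics holds the metrics from the last of the three epochs, and trainer.reported keeps one entry per epoch. Because training_loop relies on state created in setup, the sketch also shows why setup has to run first.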