GitHub - outerbounds/metaflow-torchrun
This repository implements a plugin that runs parallel Metaflow tasks as nodes in a torchrun job, which can be submitted to AWS Batch or a Kubernetes cluster.
pip install metaflow-torchrun
from metaflow import FlowSpec, step, current, kubernetes, torchrun

...

class MyGPT(FlowSpec):

    @step
    def start(self):
        # Fan out into N_NODES parallel tasks, one per torchrun node.
        self.next(self.torch_multinode, num_parallel=N_NODES)

    @kubernetes(cpu=N_CPU, gpu=N_GPU, memory=MEMORY)
    @torchrun
    @step
    def torch_multinode(self):
        ...
        current.torch.run(
            entrypoint="main.py",  # The original script is used unchanged.
            entrypoint_args={"main-arg-1": "123", "main-arg-2": "777"},
            nproc_per_node=1,  # A torchrun argument exposed to the user.
        )
        ...
    ...
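To make the relationship between the `current.torch.run` arguments and the launched script more concrete, here is a minimal sketch of how such a call could translate into a torchrun command line, and how an unmodified `main.py` might read the result with `argparse`. The helper `build_torchrun_command` is hypothetical and is not part of metaflow-torchrun. The sketch assumes that each `entrypoint_args` key is forwarded to the script as a `--key value` pair, and `nnodes=2` is an illustrative value standing in for `N_NODES`.

```python
import argparse
import shlex
from typing import Dict, List


def build_torchrun_command(
    entrypoint: str,
    entrypoint_args: Dict[str, str],
    nproc_per_node: int,
    nnodes: int,
) -> List[str]:
    """Hypothetical helper (not the plugin's code): assemble a torchrun command."""
    cmd = [
        "torchrun",
        "--nnodes", str(nnodes),
        "--nproc_per_node", str(nproc_per_node),
        entrypoint,
    ]
    # Assumption: every dict entry becomes a "--key value" pair for the script.
    for key, value in entrypoint_args.items():
        cmd.extend([f"--{key}", str(value)])
    return cmd


cmd = build_torchrun_command(
    entrypoint="main.py",
    entrypoint_args={"main-arg-1": "123", "main-arg-2": "777"},
    nproc_per_node=1,
    nnodes=2,  # illustrative stand-in for N_NODES
)
print(shlex.join(cmd))

# What an unmodified main.py could do with the forwarded arguments.
script_args = cmd[cmd.index("main.py") + 1:]
parser = argparse.ArgumentParser()
parser.add_argument("--main-arg-1")
parser.add_argument("--main-arg-2")
args = parser.parse_args(script_args)
print(args.main_arg_1, args.main_arg_2)
```

Because the script only sees ordinary command-line flags under this assumption, it would need no Metaflow-specific code, which is consistent with the note that the original script is used unchanged.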
metaflow-torchrun is distributed under the Apache License.