# float8_experimental
This is an early version of a library for accelerating training with float8 in native PyTorch, following the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf. The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling. `torch.compile` is supported out of the box. With `torch.compile` on, initial results show throughput speedups of up to 1.2x on small-scale (8 GPU) LLaMa pretraining jobs.
⚠️ See the feature tracker for upcoming features. Key features such as weight cast recomputation in backward and large scale distributed support are not ready yet.
⚠️ Backwards compatibility is not guaranteed at this point. The codebase is in active development and will change rapidly.
## Installation
⚠️ For now, use the latest PyTorch nightly for best results with `torch.compile`.
```bash
pip install .

# Optionally install in editable mode
pip install -e .

# Optionally install dev tooling
pip install -e ".[dev]"
```
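To quickly verify the install, an import check such as the one below should succeed. This is just a convenience check, not part of the documented workflow; it only uses names that appear in the examples further down.

```python
# quick sanity check that the package is importable after installation
import float8_experimental
from float8_experimental.float8_linear import Float8Linear

print("float8_experimental imported from", float8_experimental.__file__)
```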
## User API
We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
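In both strategies, each tensor gets a single scale derived from its absolute maximum (amax): dynamic scaling computes the amax from the current tensor on every iteration, while delayed scaling reuses a rolling history of amaxes from previous iterations. The snippet below is a conceptual sketch of that difference, not the library's actual implementation; the function names and the choice of `max` over the history are illustrative.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for the e4m3 format

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # dynamic scaling: amax is computed from the current tensor every iteration
    amax = x.abs().max()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # delayed scaling: amax comes from a rolling history of past iterations,
    # so no extra reduction over the current tensor is needed in the hot path
    amax = amax_history.max()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)
```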
### float8 linear with dynamic scaling
```python
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
)
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# create model
m = Model(...)

# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`
swap_linear_with_float8_linear(m, Float8DynamicLinear)

# optional: use FSDP
m = FSDP(m, use_orig_params=True)

# optional: enable torch.compile for improved performance
m = torch.compile(m)

# train/finetune (not shown)
```
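As an optional sanity check (illustrative only, not part of the library's API), you can confirm that the swap replaced every plain `torch.nn.Linear` in the model `m` from the snippet above:

```python
# illustrative check: count swapped vs. untouched linear modules;
# compare exact types, since the float8 module may subclass nn.Linear
num_float8 = sum(isinstance(mod, Float8DynamicLinear) for mod in m.modules())
num_plain = sum(type(mod) is torch.nn.Linear for mod in m.modules())
print(f"{num_float8} Float8DynamicLinear modules, {num_plain} plain nn.Linear modules")
```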
### float8 linear with delayed scaling
```python
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_linear import Float8Linear

# create model
m = Model(...)

# convert all `torch.nn.Linear` modules to `Float8Linear`
swap_linear_with_float8_linear(m, Float8Linear)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
from float8_experimental import config
config.enable_amax_init = False  # only needed for autocast + compile + FSDP + float8 delayed
config.enable_pre_and_post_forward = False  # only needed for autocast + compile + FSDP + float8 delayed
m = FSDP(m, use_orig_params=True)

# optional: enable torch.compile for improved performance
m = torch.compile(m)

# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()

    # specific to float8 with delayed scaling: separate step to sync scales/amaxes
    # in the future, this may move to a context manager
    sync_float8_amax_and_scale_history(m)

    optimizer.step()
```
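The toy loop above leaves `Model`, `x`, `optimizer`, and `N_ITER` undefined. One way to fill them in, purely for illustration (float8 matmuls additionally require recent CUDA hardware such as H100, so treat this as a sketch rather than a guaranteed-runnable recipe):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    # minimal stand-in for a real model; any module containing nn.Linear works
    def __init__(self, dim: int = 4096):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

N_ITER = 10
# e.g. m = Model(dim=4096).cuda().to(torch.bfloat16), then swap/compile as shown above
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=1e-4)
```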
## 🧭 Code Organization
* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/float8_dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction (see the sketch below), and `ScaledMMConfig`, which defines the semantics of the matmul in the forward and backward passes
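The core idea behind `Float8Tensor` is a tensor wrapper that stores the raw float8 data together with its scale and the original high-precision dtype, so autograd sees the original dtype while the matmul consumes the float8 payload. Below is a heavily simplified sketch of that idea; the class, attribute, and method names are illustrative, not the library's actual API (see `float8_experimental/float8_tensor.py` for the real implementation, which is a `torch.Tensor` subclass with dispatch integration).

```python
import torch

class SimpleFloat8Tensor:
    # illustrative only: shows the pieces a scaled float8 tensor needs to carry
    def __init__(self, data_fp8: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        self._data = data_fp8          # raw float8 payload (e.g. torch.float8_e4m3fn)
        self._scale = scale            # per-tensor scale used to (de)quantize
        self._orig_dtype = orig_dtype  # dtype that autograd and the user see

    @classmethod
    def from_float(cls, x: torch.Tensor, scale: torch.Tensor) -> "SimpleFloat8Tensor":
        data_fp8 = (x * scale).to(torch.float8_e4m3fn)
        return cls(data_fp8, scale, x.dtype)

    def to_original_precision(self) -> torch.Tensor:
        return self._data.to(self._orig_dtype) / self._scale
```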
## Testing
```bash
# run single-GPU unit tests
pytest test/test_base.py

# run a single-GPU integration test on SAM
pytest test/test_sam.py

# run single-GPU compile tests
pytest test/test_compile.py

# run a two-GPU integration test on FSDP
./test/test_fsdp.sh

# run integration tests for TP/SP (outdated)
./test/test_tp.sh

# run all of these tests
./test/test_everything.sh
```
## Benchmarking
```bash
# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
./benchmarks/bench_matmul.py

# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
# make sure to turn on torch.compile to get the best performance
./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile
```
## License
PyTorch has a BSD 3-Clause License, as found in the LICENSE file.