JAX NeuronX Environment Variables (AWS Neuron Documentation)

This document is relevant for: Inf2, Trn1, Trn2, Trn3

JAX NeuronX Environment Variables

Environment variables let you modify JAX NeuronX behavior without changing the user script. It is recommended to set them in code or immediately before invoking the Python process, for example NEURON_RT_VISIBLE_CORES=8 python3 <script>, to avoid inadvertently changing the behavior of other scripts. The following environment variables apply to JAX NeuronX; those marked [Neuron Runtime] are Neuron Runtime settings:

NEURON_CC_FLAGS

XLA_FLAGS

NEURON_FORCE_PJRT_PLUGIN_REGISTRATION

NEURON_RUN_TRIVIAL_COMPUTATION_ON_CPU

NEURON_PJRT_PROCESSES_NUM_DEVICES

NEURON_PJRT_PROCESS_INDEX

NEURON_RT_STOCHASTIC_ROUNDING_EN [Neuron Runtime]

NEURON_RT_STOCHASTIC_ROUNDING_SEED [Neuron Runtime]

NEURON_RT_VISIBLE_CORES [Neuron Runtime]
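NEURON_CC_FLAGS and XLA_FLAGS each hold a string of command-line style flags, so assigning a new value from code can silently discard flags that were already exported in the shell. The following is a minimal sketch of appending a flag instead of overwriting, assuming both variables hold whitespace-separated flags with no quoted values. The helper name append_flag is hypothetical and is not part of JAX or the Neuron SDK; the example flags (--xla_dump_to, an XLA dump option, and --model-type=transformer, a neuronx-cc option) are illustrative only.

```python
import os
from typing import MutableMapping


def append_flag(env: MutableMapping[str, str], var: str, flag: str) -> str:
    """Append `flag` to the whitespace-separated flag string in env[var].

    Flags already present are kept, and an identical flag is not added a
    second time (exact string match), so repeated calls are harmless.
    Assumes the variable holds simple, unquoted, whitespace-separated flags.
    """
    flags = env.get(var, "").split()
    if flag not in flags:
        flags.append(flag)
    env[var] = " ".join(flags)
    return env[var]


if __name__ == "__main__":
    # Configure the environment before importing jax. This is the
    # conservative choice, assuming the values are read when the Neuron
    # PJRT plugin and compiler first initialize rather than re-read later.
    append_flag(os.environ, "XLA_FLAGS", "--xla_dump_to=/tmp/xla_dump")
    append_flag(os.environ, "NEURON_CC_FLAGS", "--model-type=transformer")
    print(os.environ["XLA_FLAGS"])
    print(os.environ["NEURON_CC_FLAGS"])
    # import jax  # import only after the environment is configured
```

The helper takes any mutable mapping rather than always writing to os.environ, which keeps it easy to test with a plain dict.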

Additional Neuron runtime environment variables are described in NeuronX Runtime Configuration.
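The opening paragraph recommends scoping a variable to a single process, as in NEURON_RT_VISIBLE_CORES=8 python3 <script>. When the launch happens from Python instead of a shell (for example, a driver that starts worker scripts), the equivalent is to pass a modified copy of the environment to the child process. A minimal sketch using only the standard library; run_with_env is a hypothetical helper name, and the child command is a stand-in that only echoes the variable back.

```python
import os
import subprocess
import sys


def run_with_env(cmd, **overrides):
    """Run `cmd` in a child process with extra environment variables.

    The overrides are applied to a copy of os.environ, so only the child
    sees them; the current process environment is left unchanged.
    """
    env = dict(os.environ)
    env.update({name: str(value) for name, value in overrides.items()})
    return subprocess.run(cmd, env=env, check=True, capture_output=True, text=True)


if __name__ == "__main__":
    # Stand-in child command that prints the variable it received. Replace it
    # with [sys.executable, "your_script.py"] to launch a real JAX NeuronX script.
    result = run_with_env(
        [sys.executable, "-c",
         "import os; print(os.environ.get('NEURON_RT_VISIBLE_CORES'))"],
        NEURON_RT_VISIBLE_CORES=8,
    )
    print("child saw:", result.stdout.strip())
    print("parent has:", os.environ.get("NEURON_RT_VISIBLE_CORES"))
```

Because the override is written into a copy of os.environ, the parent process and any other scripts it starts keep their original settings.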
