torch.profiler — PyTorch 2.7 documentation

Overview

PyTorch Profiler is a tool for collecting performance metrics during training and inference. Its context manager API can be used to understand which model operators are the most expensive, examine their input shapes and stack traces, study device kernel activity, and visualize the execution trace.

Note

An earlier version of the API in the torch.autograd module is considered legacy and will be deprecated.

API Reference

class torch.profiler._KinetoProfile(*, activities=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, custom_trace_id_callback=None)

Low-level profiler that wraps the autograd profiler.

Parameters

Note

This API is experimental and subject to change in the future.

Enabling shape and stack tracing results in additional overhead. When record_shapes=True is specified, the profiler temporarily holds references to the tensors; this may further prevent certain optimizations that depend on the reference count and may introduce extra tensor copies.

add_metadata(key, value)

Adds user-defined metadata, with a string key and a string value, to the trace file.

add_metadata_json(key, value)

Adds user-defined metadata, with a string key and a valid JSON value, to the trace file.
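The sketch below exercises both metadata methods in a CPU-only run. It assumes a build with Kineto (check torch.autograd.kineto_available()); the keys run_tag and run_config are hypothetical names, and reading them back as top-level keys of the exported JSON reflects how current Kineto builds write user metadata rather than a documented guarantee.

```python
import json
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile

trace_path = os.path.join(tempfile.mkdtemp(), "trace_with_metadata.json")

with profile(activities=[ProfilerActivity.CPU]) as prof:
    # Metadata must be added while the profiler is active.
    # add_metadata takes a plain string and quotes it as a JSON string.
    prof.add_metadata("run_tag", "baseline")
    # add_metadata_json expects the value to already be valid JSON text.
    prof.add_metadata_json("run_config", json.dumps({"batch_size": 32, "amp": False}))
    torch.randn(64, 64) @ torch.randn(64, 64)

prof.export_chrome_trace(trace_path)

with open(trace_path) as f:
    trace = json.load(f)

# Assumption: Kineto writes user metadata as top-level keys of the trace JSON.
print(trace.get("run_tag"), trace.get("run_config"))
```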

events()

Returns the list of unaggregated profiler events, to be used in the trace callback or after profiling is finished.
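A minimal sketch of consuming events() after profiling finishes, assuming a CPU-only run. Each returned event is one recorded call, so counting names by hand gives per-call totals without the averaging performed by key_averages().

```python
from collections import Counter

import torch
from torch.profiler import ProfilerActivity, profile

a = torch.randn(32, 32)
b = torch.randn(32, 32)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(3):
        torch.mm(a, b)

# events() returns the raw FunctionEvent objects, one per recorded call.
events = prof.events()
call_counts = Counter(evt.name for evt in events)
print(call_counts["aten::mm"])
```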

export_chrome_trace(path)

Exports the collected trace in Chrome JSON format. If Kineto is enabled, only the last cycle in the schedule is exported.
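A sketch of exporting and reading back a Chrome trace, assuming CPU-only profiling and a temporary output directory. The file follows the Chrome trace event format, so the recorded operators appear as entries under traceEvents.

```python
import json
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.relu(torch.randn(128, 128))

# A ".json" path yields plain JSON that chrome://tracing or Perfetto can open.
path = os.path.join(tempfile.mkdtemp(), "relu_trace.json")
prof.export_chrome_trace(path)

with open(path) as f:
    trace = json.load(f)

# Timeline entries live under "traceEvents"; each one carries the op name.
op_names = {evt.get("name") for evt in trace["traceEvents"]}
print("aten::relu" in op_names)
```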

export_memory_timeline(path, device=None)

Exports memory event information from the profiler-collected tree for a given device and produces a timeline plot. There are three exportable file types, each selected by the suffix of path.

Output: Memory timeline written as gzipped JSON, JSON, or HTML.

export_stacks(path, metric='self_cpu_time_total')

Saves stack traces to a file.

Parameters

key_averages(group_by_input_shape=False, group_by_stack_n=0, group_by_overload_name=False)

Averages events, grouping them by operator name and (optionally) input shapes, stack and overload name.

Note

To use the shape/stack functionality, make sure to set record_shapes/with_stack when creating the profiler context manager.
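The sketch below shows why record_shapes matters for key_averages(): the same operator called with two different input shapes is split into separate rows when group_by_input_shape=True. The tensor sizes are arbitrary and CPU-only profiling is assumed.

```python
import torch
from torch.profiler import ProfilerActivity, profile

small = torch.randn(4, 8), torch.randn(8, 2)
large = torch.randn(64, 128), torch.randn(128, 32)

# record_shapes=True is needed for group_by_input_shape to have any effect.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    torch.mm(*small)
    torch.mm(*small)
    torch.mm(*large)

# One row per (operator name, input shapes) combination.
averages = prof.key_averages(group_by_input_shape=True)
mm_rows = [row for row in averages if row.key == "aten::mm"]
for row in mm_rows:
    print(row.input_shapes, row.count)

print(averages.table(sort_by="cpu_time_total", row_limit=5))
```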

preset_metadata_json(key, value)

Presets user-defined metadata while the profiler is not yet started; the metadata is added to the trace file later. Metadata consists of a string key and a valid JSON value.

toggle_collection_dynamic(enable, activities)

Toggles collection of activities on or off at any point during collection. Currently supports toggling Torch ops (CPU) and CUDA activity supported in Kineto.

Parameters

activities (iterable) – list of activity groups to use in profiling; supported values: torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA

Examples:

import torch

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    code_to_profile_0()
    # turn off collection of all CUDA activity
    p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
    code_to_profile_1()
    # turn on collection of all CUDA activity
    p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
    code_to_profile_2()
print(p.key_averages().table(
    sort_by="self_cuda_time_total", row_limit=-1))

class torch.profiler.profile(*, activities=None, schedule=None, on_trace_ready=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, use_cuda=None, custom_trace_id_callback=None)

Profiler context manager.

Parameters

Note

Use schedule() to generate the callable schedule. Non-default schedules are useful when profiling long training jobs and allow the user to obtain multiple traces at different iterations of the training process. The default schedule simply records all the events continuously for the duration of the context manager.

Note

Use tensorboard_trace_handler() to generate result files for TensorBoard:

on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)

After profiling, result files can be found in the specified directory. Use the command:

tensorboard --logdir dir_name

to see the results in TensorBoard. For more information, see PyTorch Profiler TensorBoard Plugin.

Note

Enabling shape and stack tracing results in additional overhead. When record_shapes=True is specified, the profiler temporarily holds references to the tensors; this may further prevent certain optimizations that depend on the reference count and may introduce extra tensor copies.

Examples:

import torch

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    code_to_profile()
print(p.key_averages().table(
    sort_by="self_cuda_time_total", row_limit=-1))

Using the profiler’s schedule, on_trace_ready and step functions:

import torch

# A non-default profiler schedule allows the user to turn the profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
    print(prof.key_averages().table(
        sort_by="self_cuda_time_total", row_limit=-1))
    # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],

    # In this example with wait=1, warmup=1, active=2, repeat=1,
    # the profiler will skip the first step/iteration,
    # start warming up on the second, record
    # the third and the fourth iterations,
    # after which the trace will become available
    # and on_trace_ready (when set) is called;
    # with repeat=1 no further cycles run (with repeat=0 the
    # cycle would start again with the next step)

    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=2,
        repeat=1),
    on_trace_ready=trace_handler
    # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
    # used when outputting for tensorboard
) as p:
    for i in range(N):
        code_iteration_to_profile(i)
        # send a signal to the profiler that the next iteration has started
        p.step()
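A CPU-only sketch of the same mechanism that just counts how often on_trace_ready fires. With wait=1, warmup=1, active=2 each cycle spans four steps, so repeat=2 should produce two traces within the first eight steps and leave the remaining steps unrecorded. The handler name and the workload are illustrative.

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

ready_at_steps = []

def trace_handler(prof):
    # Called once each time an active phase finishes.
    ready_at_steps.append(prof.step_num)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2, repeat=2),
    on_trace_ready=trace_handler,
) as p:
    for _ in range(10):
        torch.randn(16, 16).sum()
        # Advance the schedule; transitions out of RECORD_AND_SAVE trigger the handler.
        p.step()

print(len(ready_at_steps))
```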

The following sample shows how to set up an Execution Trace Observer (execution_trace_observer):

import torch
from torch.profiler import ExecutionTraceObserver

with torch.profiler.profile(
    ...
    execution_trace_observer=(
        ExecutionTraceObserver().register_callback("./execution_trace.json")
    ),
) as p:
    for i in range(N):
        code_iteration_to_profile(i)
        p.step()

You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py. Note: any object that satisfies the _ITraceObserver interface can also be passed.

get_trace_id()

Returns the current trace ID.

set_custom_trace_id_callback(callback)

Sets a callback to be called when a new trace ID is generated.

step()

Signals the profiler that the next profiling step has started.

class torch.profiler.ProfilerAction(value)

Profiler actions that can be taken at the specified intervals.

class torch.profiler.ProfilerActivity

Members:

CPU

XPU

MTIA

CUDA

HPU

PrivateUse1

property name
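A sketch for picking activities at runtime, assuming torch.profiler.supported_activities() (which returns the set of ProfilerActivity values this build and machine support) is available. CUDA tracing is requested only when it is reported as supported, so the same script runs on CPU-only machines.

```python
import torch
from torch.profiler import ProfilerActivity, profile, supported_activities

# Ask the build/machine which activity groups it can trace.
available = supported_activities()
print(sorted(activity.name for activity in available))

# Always trace CPU ops; add CUDA only when it is supported here.
wanted = [ProfilerActivity.CPU]
if ProfilerActivity.CUDA in available:
    wanted.append(ProfilerActivity.CUDA)

with profile(activities=wanted) as prof:
    torch.ones(8).add_(1)
```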

torch.profiler.schedule(*, wait, warmup, active, repeat=0, skip_first=0, skip_first_wait=0)

Returns a callable that can be used as the profiler schedule argument. The profiler will skip the first skip_first steps, then wait for wait steps, then do the warmup for the next warmup steps, then do the active recording for the next active steps, and then repeat the cycle starting with wait steps. The optional number of cycles is specified with the repeat parameter; a value of zero means that the cycles continue until profiling is finished.
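Because the schedule is just a function from step index to ProfilerAction, it can be called directly to see which phase each step falls into. This sketch reuses the wait=1, warmup=1, active=2, repeat=1 configuration from the profile example; the last active step maps to RECORD_AND_SAVE, which marks the step after which the trace is saved.

```python
from torch.profiler import ProfilerAction, schedule

# wait=1, warmup=1, active=2 gives a 4-step cycle, run once (repeat=1).
sched = schedule(wait=1, warmup=1, active=2, repeat=1)

actions = [sched(step) for step in range(6)]
for step, action in enumerate(actions):
    print(step, action.name)
```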

The skip_first_wait parameter controls whether the first wait stage should be skipped. This can be useful if a user wants to wait longer than skip_first between cycles, but not before the first profile. For example, if skip_first is 10 and wait is 20, the first cycle will wait 10 + 20 = 30 steps before warmup if skip_first_wait is zero, but only 10 steps if skip_first_wait is non-zero. All subsequent cycles will then wait 20 steps between the last active step and the next warmup.

Return type

Callable

torch.profiler.tensorboard_trace_handler(dir_name, worker_name=None, use_gzip=False)

Outputs trace files to the directory dir_name, which can then be passed directly to TensorBoard as logdir. worker_name should be unique for each worker in a distributed scenario; it defaults to ‘[hostname]_[pid]’.
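A minimal sketch, assuming CPU-only profiling and a temporary directory: with the default schedule the handler runs once when the context manager exits, and worker_name becomes the prefix of the written file name.

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

log_dir = tempfile.mkdtemp()

with profile(
    activities=[ProfilerActivity.CPU],
    # worker_name is used as the file-name prefix; set it to tell workers apart.
    on_trace_ready=tensorboard_trace_handler(log_dir, worker_name="worker0"),
):
    torch.randn(256, 256).softmax(dim=-1)

# With the default schedule, the handler fires once on exit and writes one file.
trace_files = sorted(os.listdir(log_dir))
print(trace_files)
```

The directory can then be opened with tensorboard --logdir pointing at log_dir.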

Intel Instrumentation and Tracing Technology APIs

torch.profiler.itt.is_available()

Checks whether the ITT feature is available.

torch.profiler.itt.mark(msg)

Describes an instantaneous event that occurred at some point.

Parameters

msg (str) – ASCII message to associate with the event.

torch.profiler.itt.range_push(msg)

Pushes a range onto a stack of nested range spans. Returns the zero-based depth of the range that is started.

Parameters

msg (str) – ASCII message to associate with the range.

torch.profiler.itt.range_pop()

Pops a range off the stack of nested range spans. Returns the zero-based depth of the range that is ended.
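The ITT calls only work on builds with ITT support (on other builds they may raise), so this sketch checks is_available() first and always pairs range_push with range_pop in a finally block. annotated_matmul is a hypothetical helper; the ranges and markers are only visible when the program runs under an ITT-aware tool such as Intel VTune.

```python
import torch
import torch.profiler.itt as itt

def annotated_matmul(a, b, label="matmul"):
    """Hypothetical helper: run a @ b inside an ITT range when ITT is available."""
    if not itt.is_available():
        # No ITT support in this build: just compute the result.
        return a @ b
    itt.range_push(label)  # open a named range
    try:
        itt.mark(f"{label}: start")  # instantaneous marker inside the range
        return a @ b
    finally:
        itt.range_pop()  # always close the range, even if the op raises

x = torch.ones(2, 3)
y = torch.ones(3, 4)
out = annotated_matmul(x, y)
print(out.shape)
```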