Stream — PyTorch 2.7 documentation

class torch.Stream(device, *, priority)

An in-order queue that executes its tasks asynchronously in first-in, first-out (FIFO) order. A stream can control or synchronize the execution of other streams, or block the current host thread, to ensure correct task sequencing. It can be used as a context manager in a with statement, so that operators called inside the with block run on that stream.

For details on the exact semantics, which apply to all devices, see the in-depth description of the CUDA behavior at CUDA semantics.

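Parameters

device (torch.device, optional) – the desired device for the Stream. If not given, the current accelerator type will be used.

priority (int, optional) – priority of the stream; should be 0 or negative, where negative numbers indicate higher priority. By default, streams have priority 0.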

Returns

A torch.Stream object.

Return type

Stream

Example:

>>> with torch.Stream(device='cuda') as s_cuda:
...     a = torch.randn(10, 5, device='cuda')
...     b = torch.randn(5, 10, device='cuda')
...     c = torch.mm(a, b)
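Because work inside the with block is only enqueued, the host usually has to wait for the stream before reading results. The following is a minimal sketch of that pattern, not a definitive recipe. The helper name matmul_on_stream and the CPU fallback are assumptions added for illustration so the snippet also runs on machines without CUDA.

```python
import torch


def matmul_on_stream(n: int = 64) -> torch.Tensor:
    """Multiply two n x n matrices of ones, on a side stream when CUDA exists."""
    if not torch.cuda.is_available():
        # Assumption for portability: without CUDA, compute eagerly on the CPU.
        return torch.ones(n, n) @ torch.ones(n, n)

    with torch.Stream(device='cuda') as s:
        # Operators issued here are enqueued on s, in FIFO order.
        a = torch.ones(n, n, device='cuda')
        b = torch.ones(n, n, device='cuda')
        c = torch.mm(a, b)

    # Leaving the block does not wait for the kernels; block the host here
    # so that the copy below sees the finished result.
    s.synchronize()
    return c.cpu()


result = matmul_on_stream()
```

Every entry of result should equal n, since each output element sums n products of ones.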

query() → bool

Check if all the work submitted has been completed.

Returns

A boolean indicating if all kernels in this stream are completed.

Return type

bool

Example:

>>> s_cuda = torch.Stream(device='cuda')
>>> s_cuda.query()
True
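Since query() returns immediately, it can be used to poll a stream without blocking the host for the whole duration of the work. The sketch below illustrates one way to do that. The helper wait_with_polling, its timeout, and its sleep interval are hypothetical and not part of the PyTorch API; it only relies on the object having a query() method.

```python
import time

import torch


def wait_with_polling(stream, timeout_s: float = 5.0, interval_s: float = 0.001) -> bool:
    """Poll stream.query() until it returns True or the timeout expires.

    Returns True if the stream reported all work complete, False otherwise.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if stream.query():
            return True
        # The host is free to do other work here; this sketch just sleeps.
        time.sleep(interval_s)
    # One last check at the deadline.
    return stream.query()


if torch.cuda.is_available():
    s_cuda = torch.Stream(device='cuda')
    with s_cuda:
        x = torch.randn(1024, 1024, device='cuda')
        y = x @ x  # enqueued on s_cuda, may still be running after the block
    finished = wait_with_polling(s_cuda)
else:
    # Assumption: skip the device part on machines without CUDA.
    finished = None
```

When the host has nothing else to do while waiting, synchronize() is the simpler choice.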

record_event(event) → Event

Record an event by enqueuing it into the Stream, so that later work can synchronize with the current point in the FIFO queue.

Parameters

event (torch.Event, optional) – event to record. If not given, a new one will be allocated.

Returns

Recorded event.

Return type

Event

Example:

>>> s_cuda = torch.Stream(device='cuda')
>>> e_cuda = s_cuda.record_event()

synchronize() → None

Wait for all the kernels in this stream to complete.

Example:

>>> s_cuda = torch.Stream(device='cuda')
>>> s_cuda.synchronize()

wait_event(event) → None

Make all future work submitted to the stream wait for an event.

Parameters

event (torch.Event) – an event to wait for.

Example:

>>> s1_cuda = torch.Stream(device='cuda')
>>> s2_cuda = torch.Stream(device='cuda')
>>> e_cuda = s1_cuda.record_event()
>>> s2_cuda.wait_event(e_cuda)
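To show how record_event and wait_event combine when real work is involved, here is a small producer/consumer sketch. The function name produce_then_consume, the tensor sizes, and the CPU fallback are assumptions made for illustration. The event marks the point in the producer's queue after which the tensor is ready, and the consumer waits on the device side without blocking the host.

```python
import torch


def produce_then_consume(n: int = 256) -> torch.Tensor:
    """Fill a tensor on one stream, then scale it on another stream."""
    if not torch.cuda.is_available():
        # Assumption for portability: same arithmetic on the CPU.
        return torch.full((n,), 2.0) * 3.0

    producer = torch.Stream(device='cuda')
    consumer = torch.Stream(device='cuda')

    with producer:
        x = torch.full((n,), 2.0, device='cuda')

    # Mark the point in the producer's queue where x is fully written.
    ready = producer.record_event()

    # Work submitted to the consumer after this call waits for that point.
    consumer.wait_event(ready)
    with consumer:
        y = x * 3.0

    # Block the host until the consumer has finished before copying back.
    consumer.synchronize()
    return y.cpu()


out = produce_then_consume()
```

Without the wait_event call, the multiplication could start before the producer has finished writing x, because the two streams are otherwise unordered.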

wait_stream(stream) → None

Synchronize with another stream. All future work submitted to this stream will wait until all kernels already submitted to the given stream are completed.

Parameters

stream (torch.Stream) – the stream to wait on.

Example:

>>> s1_cuda = torch.Stream(device='cuda')
>>> s2_cuda = torch.Stream(device='cuda')
>>> s2_cuda.wait_stream(s1_cuda)
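A common use of wait_stream is a fork/join pattern: a side stream waits for the work already queued on a main stream, does its part, and the main stream then waits for the side stream before continuing. The sketch below is one possible arrangement under stated assumptions. The names fork_join, main, and side, as well as the CPU fallback, are hypothetical and chosen for illustration.

```python
import torch


def fork_join(n: int = 128) -> torch.Tensor:
    """Compute (1 + 1) * 2 elementwise, splitting the steps across two streams."""
    if not torch.cuda.is_available():
        # Assumption for portability: same arithmetic on the CPU.
        return (torch.ones(n) + 1) * 2

    main = torch.Stream(device='cuda')
    side = torch.Stream(device='cuda')

    with main:
        x = torch.ones(n, device='cuda')

    # Fork: future work on side waits for everything already queued on main.
    side.wait_stream(main)
    with side:
        y = x + 1

    # Join: future work on main waits for everything already queued on side.
    main.wait_stream(side)
    with main:
        z = y * 2

    # Block the host until main has drained before copying the result back.
    main.synchronize()
    return z.cpu()


z = fork_join()
```

Note that wait_stream only covers work already submitted to the other stream at the time of the call; work queued there afterwards is not waited on.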