Pipeline — NVIDIA DALI

In DALI, any data processing task has a central object called Pipeline. A Pipeline object is an instance of nvidia.dali.Pipeline or a derived class. The pipeline encapsulates the data processing graph and the execution engine.

You can define a DALI Pipeline in the following ways:

  1. By implementing a function that uses DALI operators inside and decorating it with the pipeline_def() decorator.
  2. By instantiating Pipeline object directly, building the graph and setting the pipeline outputs with Pipeline.set_outputs().
  3. By inheriting from Pipeline class and overriding Pipeline.define_graph() (this is the legacy way of defining DALI Pipelines).

Data Processing Graphs#

DALI pipeline is represented as a graph of operations. There are two kinds of nodes in the graph:

  * Operators - instantiated by calls to operator functions,
  * Data nodes (DataNode) - the inputs and outputs of operators.

The data nodes can be transformed by calling operator functions. They also support Python-style indexing and can be incorporated in mathematical expressions.

Example:

@pipeline_def  # create a pipeline with processing graph defined by the function below
def my_pipeline():
    """Create a pipeline which reads images and masks, decodes the images and returns them."""
    img_files, labels = fn.readers.file(file_root="image_dir", seed=1)
    mask_files, _ = fn.readers.file(file_root="mask_dir", seed=1)
    images = fn.decoders.image(img_files, device="mixed")
    masks = fn.decoders.image(mask_files, device="mixed")
    return images, masks, labels

pipe = my_pipeline(batch_size=4, num_threads=2, device_id=0)

The resulting graph is:

(Figure: the resulting graph, with two file readers feeding two image decoders - two_readers.svg)

Important

The pipeline definition function is executed only once, when the pipeline is built, and typically returns a dali.DataNode object or a tuple thereof. For convenience, it's possible to return other types, such as NumPy arrays, but those are treated as constants and evaluated only once.

Processing Graph Structure#

DALI pipelines are executed in stages. The stages correspond to the device parameter that can be specified for the operator, and are executed in the following order:

  1. 'cpu' - operators that accept CPU inputs and produce CPU outputs.
  2. 'mixed' - operators that accept CPU inputs and produce GPU outputs, for example nvidia.dali.fn.decoders.image().
  3. 'gpu' - operators that accept GPU inputs and produce GPU outputs.

Data produced by a CPU operator may be explicitly copied to the GPU by calling .gpu() on a DataNode (an output of a DALI operator).
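For illustration, a minimal sketch (the file_root path is a placeholder) of an explicit CPU-to-GPU transfer:

@pipeline_def
def to_gpu_pipeline():
    # the reader and the decoder both run in the 'cpu' stage
    jpegs, labels = fn.readers.file(file_root="image_dir")
    images = fn.decoders.image(jpegs)
    # .gpu() inserts an explicit copy; downstream consumers run in the 'gpu' stage
    return images.gpu(), labels.gpu()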

Data that has been produced by a later stage cannot be consumed by an operator executing in an earlier stage.

Most DALI operators accept additional keyword arguments used to parametrize their behavior. Those named keyword arguments (which are distinct from the positional inputs) can be:

  * Python constants, evaluated once when the graph is defined,
  * argument inputs - outputs of other (CPU) operators.

In the case of argument inputs, passing the output of one operator as a named keyword argument of another operator establishes a connection in the processing graph.

Those parameters are computed as a part of the DALI pipeline graph, every iteration and for every sample. Keep in mind that only CPU operators can be used as argument inputs.

Example:

@pipeline_def
def my_pipeline():
    img_files, labels = fn.readers.file(file_root='image_dir', device='cpu')
    # images is GPU data (result of a mixed operator)
    images = fn.decoders.image(img_files, device='mixed')
    # coin_flip must be on CPU so that flip_param can be used as an argument input
    flip_param = fn.random.coin_flip(device='cpu')
    # images is an input (GPU) and flip_param is an argument input (CPU)
    flipped = fn.flip(images, horizontal=flip_param, device='gpu')
    # labels is explicitly marked for transfer to the GPU; flipped is already GPU
    return flipped, labels.gpu()

pipe = my_pipeline(batch_size=4, num_threads=2, device_id=0)

Note

If the device parameter is not specified, it is selected automatically based on the placement of the inputs. If there is at least one GPU input, device='gpu' is assumed; otherwise 'cpu' is used.

The example above adds the device parameter explicitly for clarity, but it would work the same if only device='mixed' were specified for fn.decoders.image.

Current Pipeline#

Subgraphs that do not contribute to the pipeline output are automatically pruned. If an operator has side effects (e.g. the PythonFunction operator family), it cannot be invoked without setting the current pipeline. The current pipeline is set implicitly when the graph is defined inside a derived pipeline's Pipeline.define_graph() method. Otherwise, it can be set using a context manager (with statement):

pipe = dali.Pipeline(batch_size=N, num_threads=3, device_id=0)
with pipe:
    src = dali.ops.ExternalSource(my_source, num_outputs=2)
    a, b = src()
    pipe.set_outputs(a, b)

When creating a pipeline with pipeline_def(), the function which defines the pipeline is executed within the scope of the newly created pipeline. The following example is equivalent to the previous one:

@dali.pipeline_def(batch_size=N, num_threads=3, device_id=0)
def my_pipe(my_source):
    return dali.fn.external_source(my_source, num_outputs=2)

pipe = my_pipe(my_source)

Pipeline Decorator#

@nvidia.dali.pipeline_def(fn=None, *, enable_conditionals=False, **pipeline_kwargs)#

Decorator that converts a graph definition function into a DALI pipeline factory.

A graph definition function is a function that returns intended pipeline outputs. You can decorate this function with @pipeline_def:

@pipeline_def
def my_pipe(flip_vertical, flip_horizontal):
    ''' Creates a DALI pipeline, which returns flipped and original images '''
    data, _ = fn.readers.file(file_root=images_dir)
    img = fn.decoders.image(data, device="mixed")
    flipped = fn.flip(img, horizontal=flip_horizontal, vertical=flip_vertical)
    return flipped, img

The decorated function returns a DALI Pipeline object:

pipe = my_pipe(True, False)

pipe.run()  # error - the pipeline is not configured properly yet

A pipeline requires additional parameters such as batch size, number of worker threads, GPU device id and so on (see nvidia.dali.Pipeline() for a complete list of pipeline parameters). These parameters can be supplied as additional keyword arguments, passed to the decorated function:

pipe = my_pipe(True, False, batch_size=32, num_threads=1, device_id=0)

The pipeline is now properly configured and can be run. The outputs of the original function become the outputs of the Pipeline:

flipped, img = pipe.run()

When some of the pipeline parameters are fixed, they can be specified by name in the decorator:

@pipeline_def(batch_size=42, num_threads=3)
def my_pipe(flip_vertical, flip_horizontal):
    ...

Any Pipeline constructor parameter passed later when calling the decorated function will override the decorator-defined params:

@pipeline_def(batch_size=32, num_threads=3)
def my_pipe():
    data = fn.external_source(source=my_generator)
    return data

pipe = my_pipe(batch_size=128) # batch_size=128 overrides batch_size=32

Warning

The arguments of the function being decorated can shadow pipeline constructor arguments, in which case there's no way to alter their values.

Note

Using **kwargs (variadic keyword arguments) in the graph-defining function is not allowed. They may result in unwanted, silent hijacking of arguments of the same name by the Pipeline constructor. Code written this way would cease to work with future versions of DALI, as new parameters are added to the Pipeline constructor.

To access any pipeline arguments within the body of a @pipeline_def function, the function nvidia.dali.Pipeline.current() can be used:

@pipeline_def()
def my_pipe():
    pipe = Pipeline.current()
    batch_size = pipe.batch_size
    num_threads = pipe.num_threads
    ...

pipe = my_pipe(batch_size=42, num_threads=3)
...

Keyword Arguments:

enable_conditionals (bool, optional) – Enable support for conditional execution of DALI operators using if statements in the pipeline definition; False by default.

Conditional Execution#

DALI allows executing operators conditionally, for selected samples within the batch, using if statements. To enable this feature, define the pipeline with the @pipeline_def decorator and set enable_conditionals to True.

Every if statement that has a DataNode() as its condition will be recognized as a DALI conditional statement.

For example, this pipeline rotates each image with probability of 25% by a random angle between 10 and 30 degrees:

@pipeline_def(enable_conditionals=True)
def random_rotate():
    jpegs, _ = fn.readers.file(device="cpu", file_root=images_dir)
    images = fn.decoders.image(jpegs, device="mixed")
    do_rotate = fn.random.coin_flip(probability=0.25, dtype=DALIDataType.BOOL)
    if do_rotate:
        result = fn.rotate(images, angle=fn.random.uniform(range=(10, 30)), fill_value=0)
    else:
        result = images
    return result

The semantics of DALI conditionals can be understood as if the code processed one sample at a time.

The condition must be represented by scalar samples - that is, it must have a 0-d shape. It can be either boolean or any numerical type supported by DALI - in the latter case, non-zero values are considered True and zero values are considered False, in accordance with typical Python semantics.

Additionally, the logical expressions and, or, and not can be used on DataNode(). The first two are restricted to boolean inputs; not allows the same input types as an if statement condition. Logical expressions follow short-circuit evaluation rules.
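As an illustration, a hedged sketch (images_dir is a placeholder) that combines two boolean conditions with a logical expression:

@pipeline_def(enable_conditionals=True)
def conditional_logic():
    jpegs, labels = fn.readers.file(file_root=images_dir)
    images = fn.decoders.image(jpegs, device="mixed")
    do_rotate = fn.random.coin_flip(probability=0.2, dtype=DALIDataType.BOOL)
    do_flip = fn.random.coin_flip(probability=0.5, dtype=DALIDataType.BOOL)
    # 'and' and 'not' are evaluated per sample, with short-circuiting
    if do_rotate and not do_flip:
        images = fn.rotate(images, angle=fn.random.uniform(range=(10, 30)))
    return images, labels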

You can read more in the conditional tutorial.

Preventing AutoGraph conversion#

@nvidia.dali.pipeline.do_not_convert#

Decorator that suppresses the conversion of a function by AutoGraph.

In conditional mode, DALI uses a fork of TensorFlow's AutoGraph to transform the code, enabling us to rewrite and detect the if statements, so they can be used in processing the DALI pipeline.

The AutoGraph conversion is applied to any top-level function or method called within the pipeline definition (as well as the pipeline definition itself). When a function is converted, all functions defined within its syntactical scope are also converted. The rewriting, among other effects, makes these functions non-serializable.

To stop a function from being converted, its top-level encompassing function must be marked with this decorator. This may sometimes require refactoring the function to an outer scope.

The parallel mode of external source (parallel=True) requires that its source parameter is serializable. To prevent the rewriting of the source, the functions that are used to create it should be decorated with @do_not_convert.

Note

Only functions that do not process DataNode (so do not use DALI operators) should be marked with this decorator.

For example:

import numpy as np
from nvidia.dali import pipeline_def, fn

@pipeline_def(enable_conditionals=True)
def pipe():

    def source_factory(size):
        def source_fun(sample_info):
            return np.full(size, sample_info.iter_idx)
        return source_fun

    source = source_factory(size=(2, 1))
    return fn.external_source(source=source, parallel=True, batch=False)

Should be converted into:

import numpy as np
from nvidia.dali import pipeline_def, fn
from nvidia.dali.pipeline import do_not_convert

@do_not_convert
def source_factory(size):
    def source_fun(sample_info):
        return np.full(size, sample_info.iter_idx)
    return source_fun

@pipeline_def(enable_conditionals=True)
def pipe():
    source = source_factory(size=(2, 1))
    return fn.external_source(source=source, parallel=True, batch=False)

The source_factory must be factored out, otherwise it would be converted as a part of the pipeline definition. As we are interested in preventing the AutoGraph conversion of source_fun, we need to decorate its top-level encompassing function.

Note

If a function is declared outside of the pipeline definition and is passed as a parameter, but not directly invoked within the pipeline definition, it will not be converted. In such a case, a callback passed to the external source operator, the python function operator family, or the Numba function operator is not considered as being directly invoked in the pipeline definition. Such a callback is executed when the pipeline is run, that is, after the pipeline is defined and built.

For example:

import numpy as np
from nvidia.dali import pipeline_def, fn

def source_fun(sample_info):
    return np.full((2, 2), sample_info.iter_idx)

@pipeline_def(enable_conditionals=True)
def pipe():
    return fn.external_source(source=source_fun, batch=False)

The source_fun won't be converted, as it is defined outside of the pipeline definition and is only passed by name to the external source.

Pipeline class#

class nvidia.dali.Pipeline(batch_size=None, num_threads=None, device_id=None, seed=None, exec_pipelined=True, prefetch_queue_depth=2, exec_async=True, bytes_per_sample=0, set_affinity=False, max_streams=None, default_cuda_stream_priority=None, *, enable_memory_stats=False, enable_checkpointing=False, checkpoint=None, py_num_workers=1, py_start_method='fork', py_callback_pickler=None, output_dtype=None, output_ndim=None, output_layout=None, exec_dynamic=False, experimental_exec_dynamic=None)#

Pipeline class is the base of all DALI data pipelines. The pipeline encapsulates the data processing graph and the execution engine.

Parameters:

__enter__()#

Safely sets the pipeline as current. The current pipeline is required to call operators with side effects or without outputs. Examples of such operators are PythonFunction (potential side effects) or DumpImage (no output).

Any dangling operator can be marked as having side effects if it's marked with preserve=True, which can be useful for debugging - otherwise, an operator which does not contribute to the pipeline output is removed from the graph.

To manually set a new (and restore the previous) current pipeline, use push_current() and pop_current(), respectively.
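A minimal sketch of the manual variant (the with statement shown earlier is usually preferable):

pipe = dali.Pipeline(batch_size=8, num_threads=2, device_id=0)
dali.Pipeline.push_current(pipe)
try:
    # operators with side effects can be invoked while the pipeline is current
    ...
finally:
    dali.Pipeline.pop_current()  # restore the previous current pipeline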

__exit__(exception_type, exception_value, traceback)#

Safely restores previous pipeline.

add_sink(edge)#

Marks an edge as a data sink, preventing it from being pruned, even if it’s not connected to the pipeline output.

property batch_size#

Batch size.

build()#

Build the pipeline (optional step).

Instantiates the pipeline's backend objects and starts processing threads. If the pipeline uses the multi-process external_source, the worker processes are also started. In most cases, there's no need to call build manually. When multi-processing is used, it may be necessary to call build() or start_py_workers() before the main process makes any interaction with the GPU. If needed, build() can be used before running the pipeline to separate the graph building and the processing steps.

If the pipeline requires a GPU (it contains any "gpu" or "mixed" operators or has GPU outputs) and no device_id was specified at construction, the current CUDA device (according to cudaGetDevice) will be used.

The pipeline is built automatically when it is first used - for example, when it is run, when data is fed to it with feed_input(), or when reader metadata is queried.

checkpoint(filename=None)#

Returns the pipeline’s state as a serialized Protobuf string.

Additionally, if filename is specified, the serialized checkpoint will be written to the specified file. The file contents will be overwritten.

The same pipeline can later be rebuilt with the saved checkpoint passed as the checkpoint parameter, to resume execution from the saved iteration.

More details can be found in this documentation section.

Parameters:

filename (str) – The file that the serialized pipeline will be written to.
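For illustration, a hedged sketch of saving and resuming (my_pipeline and the file name are placeholders):

# run a pipeline with checkpointing enabled
pipe = my_pipeline(batch_size=4, num_threads=2, device_id=0,
                   enable_checkpointing=True)
for _ in range(100):
    pipe.run()
pipe.checkpoint(filename="pipe.cpt")  # also returns the serialized string

# later: rebuild the same pipeline, resuming from the saved iteration
with open("pipe.cpt", "rb") as f:
    restored = my_pipeline(batch_size=4, num_threads=2, device_id=0,
                           enable_checkpointing=True, checkpoint=f.read())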

property cpu_queue_size#

The number of iterations processed ahead by the CPU stage.

static current()#

Returns the instance of the current pipeline set by push_current().

property default_cuda_stream_priority#

Deprecated; always 0.

define_graph()#

This function is defined by the user to construct the graph of operations for their pipeline.

It returns a list of outputs created by calling DALI Operators.

classmethod deserialize(serialized_pipeline=None, filename=None, **kwargs)#

Deserialize and build pipeline.

Deserialize pipeline, previously serialized with serialize() method.

Returned pipeline is already built.

Alternatively, additional arguments can be passed, which will be used when instantiating the pipeline. Refer to the Pipeline constructor for the full list of arguments. By default, the pipeline is instantiated with the arguments from the serialized pipeline.

Note that the serialized_pipeline and filename parameters are mutually exclusive.

Parameters:

Return type:

Deserialized and built pipeline.
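A minimal round-trip sketch (the file name is a placeholder):

# serialize the pipeline definition to a file (and get the Protobuf string)
serialized = pipe.serialize(filename="pipe.proto")

# recreate it later, optionally overriding constructor arguments
pipe2 = Pipeline.deserialize(filename="pipe.proto", batch_size=16)
outputs = pipe2.run()  # the deserialized pipeline is already built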

deserialize_and_build(serialized_pipeline)#

Deserialize and build the pipeline given in serialized form.

Parameters:

serialized_pipeline (str) – Serialized pipeline.

property device_id#

Id of the GPU used by the pipeline or None, if not set.

If the pipeline requires a GPU but none was specified at construction, the current GPU (according to CUDA runtime) will be assigned once the pipeline is built.

empty()#

Returns True if no scheduled work remains to be consumed in the pipeline.

enable_api_check(enable)#

Enables or disables the API check at runtime.

property enable_memory_stats#

If True, memory usage statistics are gathered.

epoch_size(name=None)#

Epoch size of a pipeline.

If the name parameter is None, returns a dictionary of pairs (reader name, epoch size for that reader). If the name parameter is not None, returns the epoch size for that reader.

Parameters:

name (str, optional, default = None) – The reader which should be used to obtain the epoch size.
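For example, assuming a reader was given name="Reader" at definition time:

sizes = pipe.epoch_size()          # {"Reader": <epoch size>, ...}
size = pipe.epoch_size("Reader")   # epoch size of that single reader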

property exec_async#

If true, asynchronous execution is used.

property exec_dynamic#

If true, the dynamic executor is used.

property exec_pipelined#

If true, pipeline execution model is used.

property exec_separated#

If True, there are separate prefetch queues for CPU and GPU stages.

executor_statistics()#

Returns the pipeline executor's statistics metadata as a dictionary. Each key in the dictionary is an operator name. To enable it, use enable_memory_stats.

Available metadata keys for each operator:

Note

Executor statistics are not available when using exec_dynamic=True.

external_source_shm_statistics()#

Returns the parallel external source's statistics regarding shared memory consumption. The returned dictionary contains the following keys:

feed_input(data_node, data, layout=None, cuda_stream=None, use_copy_kernel=False)#

Pass a multidimensional array or DLPack (or a list thereof) to an eligible operator.

The operators that may be provided with data using this function are the input operators (i.e. everything in the fn.inputs module) and fn.external_source().

In the case of GPU input, the data must be modified on the same stream as the one used by feed_input. See the cuda_stream parameter for details.

In order to avoid stalls, the data should be provided ahead of time, prefetch_queue_depth times.

Parameters:
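For illustration, a hedged sketch feeding a named external_source node (the node name, shapes, and values are placeholders):

import numpy as np

@pipeline_def
def fed_pipeline():
    # a named placeholder node, fed manually before each run
    data = fn.external_source(name="input_data")
    return data

pipe = fed_pipeline(batch_size=2, num_threads=1, device_id=0,
                    prefetch_queue_depth=1)
batch = [np.zeros((2, 2), dtype=np.int32) for _ in range(2)]  # one array per sample
pipe.feed_input("input_data", batch)  # the node can be referenced by name
(out,) = pipe.run()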

property gpu_queue_size#

The number of iterations processed ahead by the GPU stage.

property is_restored_from_checkpoint#

If True, this pipeline was restored from checkpoint.

iter_setup()#

A deprecated method of providing the pipeline with external inputs.

This function can be overridden by a user-defined pipeline to perform any needed setup for each iteration. For example, one can use this function to feed the input data from NumPy arrays.

This method is deprecated and its use is discouraged. Newer execution models may be incompatible with this method of providing data to the pipeline. Use the source argument of external_source instead, where possible.

property max_batch_size#

Maximum batch size.

property max_streams#

Deprecated, unused; returns -1.

property num_outputs#

Number of pipeline outputs.

property num_threads#

Number of CPU threads used by this pipeline.

output_dtype()#

Data types expected at the outputs.

output_ndim()#

Number of dimensions expected at the outputs.

output_stream()#

Returns the internal CUDA stream on which the outputs are produced.

outputs(cuda_stream=None)#

Returns the outputs of the pipeline and releases the previously returned buffers.

If the pipeline is executed asynchronously, this function blocks until the results become available. It raises StopIteration if the dataset has reached its end - usually when iter_setup cannot produce any more data.

Parameters:

cuda_stream (optional, cudaStream_t or an object convertible to cudaStream_t, e.g. cupy.cuda.Stream or torch.cuda.Stream) – The stream to which the returned TensorLists are bound. Defaults to None, which means that the outputs are synchronized with the host. Works only with pipelines constructed with exec_dynamic=True.

Return type:

A list of TensorList objects for respective pipeline outputs

static pop_current()#

Restores previous pipeline as current. Complementary to push_current().

property prefetch_queue_depth#

Depth (or depths) of the prefetch queue, as specified in the __init__ arguments.

static push_current(pipeline)#

Sets the pipeline as current and stores the previous current pipeline on a stack. To restore the previous pipeline as current, use pop_current().

To make sure that the pipeline is properly restored in case of an exception, use a context manager (with my_pipeline:).

The current pipeline is required to call operators with side effects or without outputs. Examples of such operators are PythonFunction (potential side effects) or DumpImage (no output).

Any dangling operator can be marked as having side effects if it's marked with preserve=True, which can be useful for debugging - otherwise, an operator which does not contribute to the pipeline output is removed from the graph.

property py_num_workers#

The number of Python worker processes used by parallel `external_source`.

property py_start_method#

The method of launching Python worker processes used by parallel `external_source`.

reader_meta(name=None)#

Returns the metadata of the given reader as a dictionary. If no name is provided, it returns a dictionary with the data for all readers, as {reader_name : meta}.

Available metadata keys:

epoch_size: raw epoch size

epoch_size_padded: epoch size with padding at the end, so that it is divisible by the number of shards

number_of_shards: number of shards

shard_id: shard id of the given reader

pad_last_batch: if the given reader should pad the last batch

stick_to_shard: if the given reader should stick to its shard

Parameters:

name (str, optional, default = None) – The reader whose metadata should be obtained.

release_outputs()#

Release buffers returned by share_outputs calls.

It helps when the result of an output call has been consumed (copied) and the buffers can be marked as free before the next call to share_outputs. It gives the user better control over when the pipeline runs, when the resulting buffers are obtained, and when they are returned to the DALI pool once the results have been consumed. Needs to be used together with schedule_run() and share_outputs(). Should not be mixed with run() in the same pipeline.

Note

When using dynamic executor (exec_dynamic=True), the buffers are not invalidated.
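Taken together, the low-level API forms a loop like this hedged sketch (my_pipeline and consume are placeholders):

pipe = my_pipeline(batch_size=4, num_threads=2, device_id=0)
pipe.schedule_run()              # kick off the first iteration
for _ in range(10):
    outs = pipe.share_outputs()  # blocks until the results are ready
    consume(outs)                # hypothetical function that copies the data out
    pipe.release_outputs()       # hand the buffers back to DALI
    pipe.schedule_run()          # schedule the next iteration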

reset()#

Resets the pipeline iterator.

If the pipeline iterator has reached the end, resets its state to the beginning.

run(cuda_stream=None, /, **pipeline_inputs)#

Run the pipeline and return the result on the specified CUDA stream.

If the pipeline was created with the exec_pipelined option set to True, this function will also start prefetching the next iteration for faster execution. Should not be mixed with schedule_run(), share_outputs() or release_outputs() in the same pipeline.

The pipeline is built if no explicit call to build was made previously.

Parameters:

Return type:

A tuple of TensorList objects for respective pipeline outputs

save_graph_to_dot_file(filename, *, show_tensors=False, show_ids=None, use_colors=False)#

Saves the pipeline graph to a file.

Parameters:

schedule_run()#

Run the pipeline without returning the resulting buffers.

If the pipeline was created with the exec_pipelined option set to True, this function will also start prefetching the next iteration for faster execution. It gives the user better control over when the pipeline runs, when the resulting buffers are obtained, and when they are returned to the DALI buffer pool after the results have been consumed. Needs to be used together with release_outputs() and share_outputs(). Should not be mixed with run() in the same pipeline.

The pipeline is built if no explicit call to build was made previously.

property seed#

Random seed used in the pipeline or None, if seed is not fixed.

serialize(define_graph=None, filename=None)#

Serialize the pipeline definition to a Protobuf string.

Note

This function doesn't serialize the pipeline's internal state. Use checkpointing to achieve that.

Additionally, you can pass a file name, so that the serialized pipeline will be written to that file. The file contents will be overwritten.

Parameters:

property set_affinity#

If True, worker threads are bound to CPU cores.

set_outputs(*output_data_nodes)#

Set the outputs of the pipeline.

Use of this function is an alternative to overriding define_graph in a derived class.

Parameters:

*output_data_nodes (unpacked list of DataNode objects) – The outputs of the pipeline

share_outputs(cuda_stream=None)#

Returns the outputs of the pipeline.

The main difference from outputs() is that share_outputs doesn't release the returned buffers; release_outputs needs to be called for that. If the pipeline is executed asynchronously, this function blocks until the results become available. It gives the user better control over when the pipeline runs, when the resulting buffers are obtained, and when they can be returned to the DALI pool once the results have been consumed. Needs to be used together with release_outputs() and schedule_run(). Should not be mixed with run() in the same pipeline.

Parameters:

cuda_stream (optional, cudaStream_t or an object convertible to cudaStream_t, e.g. cupy.cuda.Stream or torch.cuda.Stream) – The stream to which the returned TensorLists are bound. Defaults to None, which means that the outputs are synchronized with the host. Works only with pipelines constructed with exec_dynamic=True.

Returns:

A list of TensorList objects for respective pipeline outputs

start_py_workers()#

Start Python workers (that will run ExternalSource callbacks). You need to call start_py_workers() before you call any functionality that creates or acquires a CUDA context when using fork to start Python workers (py_start_method="fork"). It is called automatically by the Pipeline.build() method when such separation is not necessary.

If you are going to build more than one pipeline that starts Python workers by forking the process, you need to call the start_py_workers() method on all those pipelines before calling any method that builds or runs any of them (see build() for details), as building acquires the CUDA context for the current process.

The same applies to using any other functionality that would create a CUDA context - for example, initializing a framework that uses CUDA or creating CUDA tensors with it. You need to call start_py_workers() before you call such functionality when using py_start_method="fork".

Forking a process that has a CUDA context is unsupported and may lead to unexpected errors.

If you use this method, you cannot specify the define_graph argument when calling build().
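A hedged sketch of the required ordering when forking (my_pipe is a placeholder):

# create all pipelines that fork Python workers before touching CUDA
pipes = [my_pipe(batch_size=8, num_threads=2, device_id=0,
                 py_start_method="fork") for _ in range(2)]
for p in pipes:
    p.start_py_workers()  # fork now, while the process has no CUDA context
for p in pipes:
    p.build()             # building acquires the CUDA context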

DataNode#

class nvidia.dali.pipeline.DataNode(name, device='cpu', source=None)#

This class is a symbolic representation of a TensorList and is used at graph definition stage. It does not carry actual data, but is used to define the connections between operators and to specify the pipeline outputs. See documentation for Pipeline for details.

DataNode objects can be passed to DALI operators as inputs (and some of the named keyword arguments), but they also provide arithmetic operations which implicitly create appropriate operators that perform the expressions.
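For illustration, a minimal sketch (the file_root path is a placeholder) of arithmetic on DataNodes:

@pipeline_def
def arithmetic_pipeline():
    jpegs, _ = fn.readers.file(file_root="image_dir")
    images = fn.decoders.image(jpegs, device="mixed")
    # each expression implicitly inserts an arithmetic operator into the graph
    normalized = images / 255.0
    inverted = 1.0 - normalized
    return normalized, inverted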

property(key, *, device='cpu')#

Returns a metadata property associated with a DataNode

Parameters:

shape(*, dtype=None, device='cpu')#

Returns the run-time shapes of this DataNode as a new DataNode

Parameters:

source_info(*, device='cpu')#

Returns the "source_info" property. Equivalent to self.property("source_info").

Experimental Pipeline Features#

Some additional experimental features can be enabled via the special variant of the pipeline decorator.

@nvidia.dali.pipeline.experimental.pipeline_def(fn=None, *, enable_conditionals=False, **pipeline_kwargs)#

Variant of @pipeline_def decorator that enables additional experimental features. It has the same API as its non-experimental variant with the addition of the keyword arguments listed below.

Keyword Arguments:

Pipeline Debug Mode (experimental)#

Pipeline can be run in debug mode by replacing the @nvidia.dali.pipeline_def decorator with its experimental variant @nvidia.dali.pipeline.experimental.pipeline_def and setting the debug parameter to True. It allows you to access and modify data inside the pipeline execution graph, as well as use non-DALI data types as inputs to the DALI operators.

In this mode, outputs of operators are of type DataNodeDebug, which is the equivalent of DataNode in the standard mode. You can perform the same operations on objects of type DataNodeDebug as on DataNode, including arithmetic operations.

Use .get() to access data associated with the DataNodeDebug object during current execution of Pipeline.run():

@nvidia.dali.pipeline.experimental.pipeline_def(debug=True)
def my_pipe():
    data, _ = fn.readers.file(file_root=images_dir)
    img = fn.decoders.image(data)
    print(np.array(img.get()[0]))
    ...

Use non-DALI data types (e.g. NumPy ndarray, PyTorch Tensor) directly with DALI operators:

@nvidia.dali.pipeline.experimental.pipeline_def(batch_size=8, debug=True)
def my_pipe():
    img = [np.random.rand(640, 480, 3) for _ in range(8)]
    output = fn.flip(img)
    ...

Notice#

Warning

Using debug mode will drastically worsen performance of your pipeline. Use it only for debugging purposes.

Note

This feature is experimental and its API might change without notice.