Examples: Intel® Extension for PyTorch* 2.7.0+cpu documentation

These examples will guide you through using the Intel® Extension for PyTorch* on Intel CPUs.

You can also refer to the Features section to get the examples and usage instructions related to particular features.

The source code for these examples, as well as the feature examples, can be found in the GitHub source tree under the examples directory.

Prerequisites: Before running these examples, please note the following:

Python

The ipex.optimize function of Intel® Extension for PyTorch* applies optimizations to the model for additional performance. For both computer vision and NLP workloads, we recommend applying ipex.optimize to the model object.

Float32

Eager Mode

Resnet50

Note: You need to install the torchvision Python package to run the following example.

import torch
import torchvision.models as models

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(128, 3, 224, 224)

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model)
######################################################

with torch.no_grad():
    model(data)

print("Execution finished")

BERT

Note: You need to install the transformers Python package to run the following example.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model)
######################################################

with torch.no_grad():
    model(data)

print("Execution finished")

TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50

Note: You need to install the torchvision Python package to run the following example.

import torch
import torchvision.models as models

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(128, 3, 224, 224)

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model)
######################################################

with torch.no_grad():
    d = torch.rand(128, 3, 224, 224)
    model = torch.jit.trace(model, d)
    model = torch.jit.freeze(model)

    model(data)

print("Execution finished")

BERT

Note: You need to install the transformers Python package to run the following example.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model)
######################################################

with torch.no_grad():
    d = torch.randint(vocab_size, size=[batch_size, seq_length])
    model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
    model = torch.jit.freeze(model)

    model(data)

print("Execution finished")

TorchDynamo Mode (Beta, NEW feature from 2.0.0)

Resnet50

Note: You need to install the torchvision Python package to run the following example.

import torch
import torchvision.models as models

model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.eval()
data = torch.rand(128, 3, 224, 224)

# Beta Feature
#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, weights_prepack=False)
model = torch.compile(model, backend="ipex")
######################################################

with torch.no_grad():
    model(data)

print("Execution finished")

BERT

Note: You need to install the transformers Python package to run the following example.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

# Beta Feature
#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, weights_prepack=False)
model = torch.compile(model, backend="ipex")
######################################################

with torch.no_grad():
    model(data)

print("Execution finished")

Note: In TorchDynamo mode, disable weight prepacking by setting weights_prepack=False in ipex.optimize(). Native PyTorch operators such as aten::convolution and aten::linear are already well supported and optimized in the ipex backend, so prepacking is not needed there.

BFloat16

The optimize function works for both the Float32 and BFloat16 data types. For BFloat16, set the dtype parameter to torch.bfloat16. We recommend using Auto Mixed Precision (AMP) with the BFloat16 data type.
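A quick way to see what the BFloat16 data type trades off: a bfloat16 value is the upper 16 bits of a float32 (1 sign bit, the same 8 exponent bits, 7 mantissa bits). It therefore keeps Float32's dynamic range but has only about 2 to 3 significant decimal digits. The plain-Python sketch below needs no PyTorch. It rounds a float32 to bfloat16 with round-to-nearest-even. The helper names are hypothetical and NaN handling is omitted. It illustrates the format only and is not PyTorch's conversion code.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the IEEE 754 binary32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16_bits(x: float) -> int:
    """Round x (as float32) to bfloat16 with round-to-nearest-even.

    bfloat16 is the upper 16 bits of a float32: 1 sign bit, 8 exponent bits,
    7 mantissa bits. NaN inputs are not handled in this sketch.
    """
    bits = float32_bits(x)
    # Add just under half of the discarded low 16 bits, plus 1 when the kept
    # least-significant bit is odd, so exact ties round to the even neighbour.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bfloat16_bits_to_float(b: int) -> float:
    """Widen a bfloat16 bit pattern back to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


def round_trip(x: float) -> float:
    """Value that x takes after being stored as bfloat16."""
    return bfloat16_bits_to_float(to_bfloat16_bits(x))


if __name__ == "__main__":
    for value in [1.0, 3.140625, 0.1, 1e38]:
        bf16 = to_bfloat16_bits(value)
        print(f"{value!r:>12} -> 0x{bf16:04x} -> {round_trip(value)!r}")
```

Values that need at most 8 significant bits, such as 3.140625, survive exactly. 0.1 becomes 0.10009765625. 1e38 stays finite, which FP16 could not represent.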

Eager Mode

Resnet50

Note: You need to install the torchvision Python package to run the following example.

import torch
import torchvision.models as models

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(128, 3, 224, 224)

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, dtype=torch.bfloat16)
######################################################

# Note: BF16 inference requires an amp.autocast() context
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)

print("Execution finished")

BERT

Note: You need to install the transformers Python package to run the following example.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, dtype=torch.bfloat16)
######################################################

# Note: BF16 inference requires an amp.autocast() context
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)

print("Execution finished")

TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50

Note: You need to install the torchvision Python package to run the following example.

import torch
import torchvision.models as models

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(128, 3, 224, 224)

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, dtype=torch.bfloat16)
######################################################

# Note: BF16 inference requires an amp.autocast() context
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, torch.rand(128, 3, 224, 224))
    model = torch.jit.freeze(model)

    model(data)

print("Execution finished")

BERT

Note: You need to install the transformers Python package to run the following example.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, dtype=torch.bfloat16)
######################################################

# Note: BF16 inference requires an amp.autocast() context
with torch.no_grad(), torch.cpu.amp.autocast():
    d = torch.randint(vocab_size, size=[batch_size, seq_length])
    model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
    model = torch.jit.freeze(model)

    model(data)

print("Execution finished")

TorchDynamo Mode (Beta, NEW feature from 2.0.0)

Resnet50

Note: You need to install the torchvision Python package to run the following example.

import torch
import torchvision.models as models

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(128, 3, 224, 224)

# Beta Feature
#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, dtype=torch.bfloat16, weights_prepack=False)
model = torch.compile(model, backend="ipex")
######################################################

# Note: BF16 inference requires an amp.autocast() context
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)

print("Execution finished")

BERT

Note: You need to install the transformers Python package to run the following example.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

# Beta Feature
#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.optimize(model, dtype=torch.bfloat16, weights_prepack=False)
model = torch.compile(model, backend="ipex")
######################################################

# Note: BF16 inference requires an amp.autocast() context
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)

print("Execution finished")

Fast Bert (Prototype)

Note: You need to install the transformers Python package to run the following example.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
# Seed before sampling so that the random input is reproducible
torch.manual_seed(43)
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ####################
import intel_extension_for_pytorch as ipex

model = ipex.fast_bert(model, dtype=torch.bfloat16)
######################################################

with torch.no_grad():
    model(data)

print("Execution finished")

INT8

Starting from Intel® Extension for PyTorch* 1.12.0, the quantization feature supports both static and dynamic modes.

Static Quantization

Calibration

Please follow the steps below to perform calibration for static quantization:

  1. Import intel_extension_for_pytorch as ipex.
  2. Import prepare and convert from intel_extension_for_pytorch.quantization.
  3. Create a quantization configuration mapping. Either use ipex.quantization.default_static_qconfig_mapping, or build a QConfigMapping from a torch.ao.quantization.QConfig. The observers it specifies collect statistics during calibration.
  4. Prepare the model for calibration.
  5. Perform calibration on a representative dataset.
  6. Call convert to apply the collected calibration data to the FP32 model and obtain an INT8 model.
  7. Trace and freeze the INT8 model with TorchScript, then save it to a .pt file.
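Conceptually, calibration (steps 4 and 5) only records statistics, and convert then turns them into quantization parameters. The sketch below assumes a min/max observer with per-tensor affine quint8 parameters, like the MinMaxObserver in the custom QConfig option of the example below. It reduces that observer to plain Python lists. MinMaxTracker, quantize, and dequantize are hypothetical helpers. Real observers, including whatever ipex.quantization.default_static_qconfig_mapping uses, may differ in details such as range reduction or histogram-based ranges.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple


@dataclass
class MinMaxTracker:
    """Hypothetical, simplified stand-in for a min/max observer.

    It records the running minimum and maximum of every value seen during
    calibration. Real observers work on tensors and have more options.
    """

    min_val: float = float("inf")
    max_val: float = float("-inf")

    def observe(self, batch: Iterable[float]) -> None:
        values = list(batch)
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def qparams(self, qmin: int = 0, qmax: int = 255) -> Tuple[float, int]:
        """Per-tensor affine (asymmetric) parameters for an unsigned 8-bit range."""
        # Widen the range to include 0.0 so that zero is exactly representable.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = max((hi - lo) / (qmax - qmin), 1e-8)
        zero_point = max(qmin, min(qmax, qmin - round(lo / scale)))
        return scale, zero_point


def quantize(x: float, scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> int:
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    return (q - zero_point) * scale


if __name__ == "__main__":
    # Hypothetical activation values seen in three calibration batches.
    calibration_batches = [[-0.5, 0.0, 1.2], [0.3, 2.0, -1.0], [0.7, 1.5, -0.25]]

    tracker = MinMaxTracker()
    for batch in calibration_batches:
        tracker.observe(batch)

    scale, zero_point = tracker.qparams()
    print(f"observed range [{tracker.min_val}, {tracker.max_val}]")
    print(f"scale={scale:.6f} zero_point={zero_point}")
    for x in [-1.0, 0.0, 0.5, 2.0, 3.0]:
        q = quantize(x, scale, zero_point)
        print(f"{x:5} -> {q:3} -> {dequantize(q, scale, zero_point):.4f}")
```

Values outside the calibrated range (3.0 here) saturate at the top code. This is why the calibration data should resemble real inputs.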

Note: You need to install the torchvision Python package to run the following example.

import torch

#################### code changes ####################
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

######################################################

##### Example Model #####
import torchvision.models as models

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(128, 3, 224, 224)
#########################

qconfig_mapping = ipex.quantization.default_static_qconfig_mapping
# Alternatively, define your own qconfig_mapping:
# from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig, QConfigMapping
# qconfig = QConfig(
#     activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
#     weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
# qconfig_mapping = QConfigMapping().set_global(qconfig)
prepared_model = prepare(model, qconfig_mapping, example_inputs=data, inplace=False)

##### Example Dataloader #####
import torchvision

DOWNLOAD = True
DATA = "datasets/cifar10/"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
calibration_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=128
)

with torch.no_grad():
    for batch_idx, (d, target) in enumerate(calibration_data_loader):
        print(f"calibrated on batch {batch_idx} out of {len(calibration_data_loader)}")
        prepared_model(d)
##############################

converted_model = convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(converted_model, data)
    traced_model = torch.jit.freeze(traced_model)

traced_model.save("static_quantized_model.pt")

print("Saved model to: static_quantized_model.pt")
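The custom QConfig option above quantizes weights with PerChannelMinMaxObserver and torch.per_channel_symmetric. That means one scale per output channel and a zero point of 0. The plain-Python sketch below shows why this matters when channel magnitudes differ: with a single per-tensor scale, the small channel collapses to zeros. The helper names are hypothetical. The sketch divides by 127, whereas PyTorch's observers may use a slightly different divisor, so exact values can differ.

```python
from typing import List

Matrix = List[List[float]]


def per_channel_symmetric_scales(weight: Matrix, qmax: int = 127) -> List[float]:
    """One scale per output channel (row); the zero point is always 0."""
    return [max(max(abs(v) for v in row) / qmax, 1e-12) for row in weight]


def quantize_rows(weight: Matrix, scales: List[float], qmin: int = -128, qmax: int = 127) -> List[List[int]]:
    """Symmetric int8 quantization, row by row: q = clamp(round(w / scale))."""
    return [
        [max(qmin, min(qmax, round(v / s))) for v in row]
        for row, s in zip(weight, scales)
    ]


def dequantize_rows(qweight: List[List[int]], scales: List[float]) -> Matrix:
    return [[q * s for q in row] for row, s in zip(qweight, scales)]


if __name__ == "__main__":
    # Two hypothetical output channels whose magnitudes differ by ~300x.
    weight = [
        [0.5, -1.27, 0.02],
        [0.001, -0.004, 0.003],
    ]
    channel_scales = per_channel_symmetric_scales(weight)
    # For comparison: a single per-tensor scale shared by both channels.
    tensor_scale = max(abs(v) for row in weight for v in row) / 127

    print("per-channel:", quantize_rows(weight, channel_scales))
    print("per-tensor: ", quantize_rows(weight, [tensor_scale] * len(weight)))
```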

Deployment

For deployment, the INT8 model is loaded from the local file and can be used directly for sample inference.

Follow the steps below:

  1. Import intel_extension_for_pytorch as ipex.
  2. Load the INT8 model from the saved file.
  3. Run inference.

import torch

#################### code changes ####################
# Not referenced by name (hence noqa F401), but importing it makes the
# IPEX operators used by the saved INT8 model available to torch.jit.load
import intel_extension_for_pytorch as ipex  # noqa F401

######################################################

model = torch.jit.load("static_quantized_model.pt")
model.eval()
model = torch.jit.freeze(model)
data = torch.rand(128, 3, 224, 224)

with torch.no_grad():
    model(data)

print("Execution finished")

Dynamic Quantization

Please follow the steps below to perform dynamic quantization:

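  1. Import intel_extension_for_pytorch as ipex.
  2. Import prepare and convert from intel_extension_for_pytorch.quantization.
  3. Create a quantization configuration mapping. Either use ipex.quantization.default_dynamic_qconfig_mapping, or build a QConfigMapping from a torch.ao.quantization.QConfig. No calibration dataset is needed, because activation quantization parameters are computed at run time.
  4. Prepare the model for quantization.
  5. Convert the model.
  6. Trace and freeze the converted model with TorchScript. Activations are quantized dynamically whenever the model runs inference.
  7. Save the INT8 model to a .pt file.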
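What makes the dynamic mode dynamic is the split in timing. Weights are quantized once during convert, while activation quantization parameters are computed from each input at run time. The toy layer below sketches that split in plain Python. ToyDynamicQuantLinear and its helpers are hypothetical. For brevity it quantizes activations symmetrically to int8, whereas real implementations may use asymmetric (quint8) activations and fused integer kernels.

```python
from typing import List


def symmetric_scale(values: List[float], qmax: int = 127) -> float:
    """Scale that maps the largest magnitude in `values` to qmax."""
    return max(max(abs(v) for v in values) / qmax, 1e-12)


def quantize(values: List[float], scale: float) -> List[int]:
    return [max(-128, min(127, round(v / scale))) for v in values]


class ToyDynamicQuantLinear:
    """Hypothetical dynamically quantized linear layer (no bias), for illustration.

    Weights are quantized once, ahead of time (the role of the convert step).
    The activation scale is recomputed from every incoming input, which is
    why dynamic quantization needs no calibration dataset.
    """

    def __init__(self, weight: List[List[float]]):
        self.w_scales = [symmetric_scale(row) for row in weight]
        self.qweight = [quantize(row, s) for row, s in zip(weight, self.w_scales)]

    def __call__(self, x: List[float]) -> List[float]:
        x_scale = symmetric_scale(x)  # derived from this input, at run time
        qx = quantize(x, x_scale)
        out = []
        for qrow, w_scale in zip(self.qweight, self.w_scales):
            acc = sum(a * b for a, b in zip(qrow, qx))  # integer accumulation
            out.append(acc * w_scale * x_scale)  # rescale the result to float
        return out


if __name__ == "__main__":
    layer = ToyDynamicQuantLinear([[0.5, -0.2], [1.0, 2.5]])
    for x in ([1.0, 3.0], [10.0, 30.0]):
        # Exact float results would be [-0.1, 8.5] and [-1.0, 85.0].
        print(x, "->", [round(v, 4) for v in layer(x)])
```

The second input is ten times larger, yet no recalibration is needed. Its scale is simply recomputed when the layer is called.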

Note: You need to install the transformers Python package to run the following example.

import torch

#################### code changes ####################
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

######################################################

##### Example Model #####
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#########################

qconfig_mapping = ipex.quantization.default_dynamic_qconfig_mapping
# Alternatively, define your own qconfig:
# from torch.ao.quantization import PerChannelMinMaxObserver, PlaceholderObserver, QConfig, QConfigMapping
# qconfig = QConfig(
#     activation=PlaceholderObserver.with_args(dtype=torch.float, is_dynamic=True),
#     weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
# qconfig_mapping = QConfigMapping().set_global(qconfig)
prepared_model = prepare(model, qconfig_mapping, example_inputs=data)

converted_model = convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(
        converted_model, (data,), check_trace=False, strict=False
    )
    traced_model = torch.jit.freeze(traced_model)

traced_model.save("dynamic_quantized_model.pt")

print("Saved model to: dynamic_quantized_model.pt")

Large Language Model (LLM)

Intel® Extension for PyTorch* provides dedicated optimizations for running Large Language Models (LLMs) faster. Several data types are supported for different scenarios, including FP32, BF16, Smooth Quantization INT8, and Weight Only Quantization INT8/INT4 (prototype).
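In weight-only quantization, only the weights are stored in INT8 or INT4. Activations stay in floating point, and the weights are dequantized at run time or computed in the mode selected by the --lowp-mode option of the weight-only example below. As an illustration only, the sketch below does three things:

- It quantizes each output channel of a small weight matrix to asymmetric 4-bit codes.
- It packs two codes per byte.
- It dequantizes on the fly in a float matrix-vector product.

The function names, the per-row grouping, and the low-nibble-first packing are assumptions for this sketch. They are not the layout Intel® Extension for PyTorch* uses internally.

```python
from typing import List, Tuple


def quantize_int4_row(row: List[float]) -> Tuple[List[int], float, int]:
    """Asymmetric 4-bit quantization of one output channel to codes 0..15."""
    lo, hi = min(min(row), 0.0), max(max(row), 0.0)
    scale = max((hi - lo) / 15, 1e-12)
    zero_point = max(0, min(15, round(-lo / scale)))
    codes = [max(0, min(15, round(v / scale) + zero_point)) for v in row]
    return codes, scale, zero_point


def pack_nibbles(codes: List[int]) -> bytes:
    """Store two 4-bit codes per byte, low nibble first (illustrative layout)."""
    if len(codes) % 2:
        codes = codes + [0]  # pad to an even count
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


def unpack_nibbles(packed: bytes, n: int) -> List[int]:
    codes = []
    for byte in packed:
        codes.extend((byte & 0x0F, byte >> 4))
    return codes[:n]


def woq_matvec(packed_rows: List[bytes], params: List[Tuple[float, int]],
               n_in: int, x: List[float]) -> List[float]:
    """Weight-only quantized matrix-vector product: each weight row is
    dequantized on the fly and multiplied with the floating-point input x."""
    out = []
    for packed, (scale, zero_point) in zip(packed_rows, params):
        w = [(c - zero_point) * scale for c in unpack_nibbles(packed, n_in)]
        out.append(sum(wi * xi for wi, xi in zip(w, x)))
    return out


if __name__ == "__main__":
    weight = [[0.3, -0.15, 0.6], [-1.0, 0.5, 0.2]]  # hypothetical 2x3 weight
    quantized = [quantize_int4_row(row) for row in weight]
    packed_rows = [pack_nibbles(codes) for codes, _, _ in quantized]
    params = [(scale, zp) for _, scale, zp in quantized]

    x = [1.0, 2.0, -1.0]
    print("bytes per row:", [len(p) for p in packed_rows])  # 3 weights -> 2 bytes
    print("WOQ result:  ", woq_matvec(packed_rows, params, len(x), x))
    print("float result:", [sum(w * v for w, v in zip(row, x)) for row in weight])
```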

Note: You need to install the transformers==4.48.0 Python package to run the following examples. In addition, you may need to log in to your Hugging Face account to access the pretrained model files. For details, refer to the Hugging Face login documentation.

FP32/BF16

import torch

#################### code changes ####################
import intel_extension_for_pytorch as ipex

######################################################
import argparse
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# args
parser = argparse.ArgumentParser("Generation script (fp32/bf16 path)", add_help=False)
parser.add_argument(
    "--dtype", type=str, choices=["float32", "bfloat16"], default="float32",
    help="choose the weight dtype and whether to enable auto mixed precision or not",
)
parser.add_argument("--max-new-tokens", default=32, type=int, help="output max new tokens")
parser.add_argument("--prompt", default="What are we having for dinner?", type=str, help="input prompt")
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--batch-size", default=1, type=int, help="batch size")
args = parser.parse_args()
print(args)

# dtype
amp_enabled = True if args.dtype != "float32" else False
amp_dtype = getattr(torch, args.dtype)

# load model
model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id, torchscript=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=amp_dtype, config=config, low_cpu_mem_usage=True, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = model.eval()
model = model.to(memory_format=torch.channels_last)

# Intel(R) Extension for PyTorch*
#################### code changes ####################
model = ipex.llm.optimize(model, dtype=amp_dtype, inplace=True, deployment_mode=True)
######################################################

# generate args
num_beams = 1 if args.greedy else 4
# temperature only takes effect when do_sample=True
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=num_beams)

# input prompt
prompt = args.prompt
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
prompt = [prompt] * args.batch_size

# inference
with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast(enabled=amp_enabled):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_ids = model.generate(input_ids, max_new_tokens=args.max_new_tokens, **generate_kwargs)
    gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    input_tokens_lengths = [x.shape[0] for x in input_ids]
    output_tokens_lengths = [x.shape[0] for x in gen_ids]
    # generate() returns prompt + continuation, so the difference is the number of new tokens
    total_new_tokens = [o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)]
    print(gen_text, total_new_tokens, flush=True)

Weight Only Quantization INT8/INT4

import torch

#################### code changes ####################
import intel_extension_for_pytorch as ipex

######################################################
import argparse
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# args
parser = argparse.ArgumentParser("Generation script (weight only quantization path)", add_help=False)
parser.add_argument(
    "--dtype", type=str, choices=["float32", "bfloat16"], default="float32",
    help="choose the weight dtype and whether to enable auto mixed precision or not",
)
parser.add_argument("--max-new-tokens", default=32, type=int, help="output max new tokens")
parser.add_argument("--prompt", default="What are we having for dinner?", type=str, help="input prompt")
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--batch-size", default=1, type=int, help="batch size")

# Intel(R) Extension for PyTorch*
#################### code changes ####################
parser.add_argument(
    "--lowp-mode", choices=["AUTO", "BF16", "FP32", "INT8", "FP16"], default="AUTO", type=str,
    help="low precision mode for weight only quantization. "
    "It indicates data type for computation for speedup at the cost "
    "of accuracy. Unrelated to activation or weight data type. "
    "It is not supported yet to use lowp_mode=INT8 for INT8 weight, "
    "falling back to lowp_mode=BF16 implicitly in this case. "
    "If set to AUTO, lowp_mode is determined by weight data type: "
    "lowp_mode=BF16 is used for INT8 weight "
    "and lowp_mode=INT8 used for INT4 weight",
)
parser.add_argument(
    "--weight-dtype", choices=["INT8", "INT4"], default="INT8", type=str,
    help="weight data type for weight only quantization. Unrelated to activation"
    " data type or lowp-mode. If --low-precision-checkpoint is given, weight"
    " data type is always INT4 and this argument is not needed.",
)
parser.add_argument(
    "--low-precision-checkpoint", default="", type=str,
    help="Low precision checkpoint file generated by calibration, such as GPTQ. It contains"
    " modified weights, scales, zero points, etc. For better accuracy of weight only"
    " quantization with INT4 weight.",
)
######################################################
args = parser.parse_args()
print(args)

# dtype
amp_enabled = True if args.dtype != "float32" else False
amp_dtype = getattr(torch, args.dtype)

# load model
model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id, torchscript=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=amp_dtype, config=config, low_cpu_mem_usage=True, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = model.eval()
model = model.to(memory_format=torch.channels_last)

# Intel(R) Extension for PyTorch*
#################### code changes ####################
from intel_extension_for_pytorch.quantization import WoqWeightDtype

weight_dtype = WoqWeightDtype.INT4 if args.weight_dtype == "INT4" else WoqWeightDtype.INT8

if args.lowp_mode == "INT8":
    lowp_mode = ipex.quantization.WoqLowpMode.INT8
elif args.lowp_mode == "FP32":
    lowp_mode = ipex.quantization.WoqLowpMode.NONE
elif args.lowp_mode == "FP16":
    lowp_mode = ipex.quantization.WoqLowpMode.FP16
elif args.lowp_mode == "BF16":
    lowp_mode = ipex.quantization.WoqLowpMode.BF16
else:  # AUTO
    if args.low_precision_checkpoint != "" or weight_dtype == WoqWeightDtype.INT4:
        lowp_mode = ipex.quantization.WoqLowpMode.INT8
    else:
        lowp_mode = ipex.quantization.WoqLowpMode.BF16

qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=weight_dtype, lowp_mode=lowp_mode
)
if args.low_precision_checkpoint != "":
    low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
else:
    low_precision_checkpoint = None
model = ipex.llm.optimize(
    model.eval(), dtype=amp_dtype, quantization_config=qconfig,
    low_precision_checkpoint=low_precision_checkpoint, deployment_mode=True, inplace=True,
)
######################################################

# generate args
num_beams = 1 if args.greedy else 4
# temperature only takes effect when do_sample=True
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=num_beams)

# input prompt
prompt = args.prompt
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
prompt = [prompt] * args.batch_size

# inference
with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast(enabled=amp_enabled):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_ids = model.generate(input_ids, max_new_tokens=args.max_new_tokens, **generate_kwargs)
    gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    input_tokens_lengths = [x.shape[0] for x in input_ids]
    output_tokens_lengths = [x.shape[0] for x in gen_ids]
    total_new_tokens = [o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)]
    print(gen_text, total_new_tokens, flush=True)
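The --lowp-mode branch in the script above is easier to reason about as a pure function. The sketch below restates that selection logic in plain Python. It uses a hypothetical LowpMode enum instead of ipex.quantization.WoqLowpMode so that it runs without IPEX. effective_lowp_mode is also hypothetical. It only models the INT8-to-BF16 fallback that the help text says happens implicitly for INT8 weights.

```python
from enum import Enum


class LowpMode(Enum):
    """Hypothetical stand-in for the ipex.quantization.WoqLowpMode members used above."""
    NONE = "NONE"  # compute in FP32
    FP16 = "FP16"
    BF16 = "BF16"
    INT8 = "INT8"


# Explicit --lowp-mode choices, mapped as in the example script.
EXPLICIT_MODES = {
    "INT8": LowpMode.INT8,
    "FP32": LowpMode.NONE,
    "FP16": LowpMode.FP16,
    "BF16": LowpMode.BF16,
}


def resolve_lowp_mode(lowp_mode: str, weight_dtype: str, low_precision_checkpoint: str = "") -> LowpMode:
    """Reproduce the script's --lowp-mode selection, including the AUTO rule."""
    if lowp_mode in EXPLICIT_MODES:
        return EXPLICIT_MODES[lowp_mode]
    if lowp_mode != "AUTO":
        raise ValueError(f"unsupported --lowp-mode: {lowp_mode}")
    # AUTO: INT4 weights (or an INT4 checkpoint such as GPTQ output) compute
    # in INT8; plain INT8 weights compute in BF16.
    if low_precision_checkpoint != "" or weight_dtype == "INT4":
        return LowpMode.INT8
    return LowpMode.BF16


def effective_lowp_mode(mode: LowpMode, weight_dtype: str, low_precision_checkpoint: str = "") -> LowpMode:
    """Model the fallback described in the --lowp-mode help text:
    lowp_mode=INT8 is not yet supported for INT8 weights and becomes BF16."""
    weights_are_int8 = weight_dtype == "INT8" and low_precision_checkpoint == ""
    if mode is LowpMode.INT8 and weights_are_int8:
        return LowpMode.BF16
    return mode


if __name__ == "__main__":
    for args in [("AUTO", "INT8"), ("AUTO", "INT4"), ("INT8", "INT8"), ("FP32", "INT4")]:
        chosen = resolve_lowp_mode(*args)
        print(args, "->", chosen.name, "-> effective", effective_lowp_mode(chosen, args[1]).name)
```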

Note: Please check the LLM Best Known Practice Page for detailed environment setup and instructions for running LLM workloads.

C++

To work with libtorch, the C++ library of PyTorch, Intel® Extension for PyTorch* also provides a C++ dynamic library. The C++ library is intended for inference workloads only, such as service deployment. For regular development, use the Python interface. Compared with using plain libtorch, no specific code changes are required. Compilation follows the recommended CMake methodology, and detailed instructions can be found in the PyTorch tutorial.

During compilation, Intel optimizations are activated automatically once the C++ dynamic library of Intel® Extension for PyTorch* is linked.

The example code below works for all data types.

example-app.cpp

#include <torch/script.h>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::vector<torch::jit::IValue> inputs;
  torch::Tensor input = torch::rand({1, 3, 224, 224});
  inputs.push_back(input);

  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << std::endl;
  std::cout << "Execution finished" << std::endl;

  return 0;
}

CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

find_package(IPEX REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_IPEX_LIBRARIES}")

set_property(TARGET example-app PROPERTY CXX_STANDARD 17)

Command for compilation

$ cd examples/cpu/inference/cpp
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> ..
$ make

If Found IPEX is shown with a dynamic library path, the extension has been linked into the binary. You can verify this with the Linux command ldd.

$ cmake -DCMAKE_PREFIX_PATH=/workspace/libtorch ..
-- The C compiler identification is GNU XX.X.X
-- The CXX compiler identification is GNU XX.X.X
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
CMake Warning at /workspace/libtorch/share/cmake/Torch/TorchConfig.cmake:22 (message):
  static library kineto_LIBRARY-NOTFOUND not found.
Call Stack (most recent call first):
  /workspace/libtorch/share/cmake/Torch/TorchConfig.cmake:127 (append_torchlib_if_found)
  /workspace/libtorch/share/cmake/IPEX/IPEXConfig.cmake:84 (FIND_PACKAGE)
  CMakeLists.txt:4 (find_package)

-- Found Torch: /workspace/libtorch/lib/libtorch.so
-- Found IPEX: /workspace/libtorch/lib/libintel-ext-pt-cpu.so
-- Configuring done
-- Generating done
-- Build files have been written to: examples/cpu/inference/cpp/build

$ ldd example-app
        ...
        libtorch.so => /workspace/libtorch/lib/libtorch.so (0x00007f3cf98e0000)
        libc10.so => /workspace/libtorch/lib/libc10.so (0x00007f3cf985a000)
        libintel-ext-pt-cpu.so => /workspace/libtorch/lib/libintel-ext-pt-cpu.so (0x00007f3cf70fc000)
        libtorch_cpu.so => /workspace/libtorch/lib/libtorch_cpu.so (0x00007f3ce16ac000)
        ...
        libdnnl_graph.so.0 => /workspace/libtorch/lib/libdnnl_graph.so.0 (0x00007f3cde954000)
        ...
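The ldd check above can also be scripted, for example in a CI job. The sketch below assumes a Linux system where ldd is on PATH and prints lines in the usual "name => path (address)" form. parse_ldd and ipex_is_linked are hypothetical helpers. libintel-ext-pt-cpu.so is the library name shown in the output above.

```python
import subprocess
from typing import Dict


def parse_ldd(output: str) -> Dict[str, str]:
    """Map each library name in `ldd` output to the path it resolved to.

    Typical lines look like "libfoo.so => /path/to/libfoo.so (0x...)" or
    "libfoo.so => not found". Lines without "=>" (for example the vDSO or
    the dynamic loader itself) are skipped.
    """
    libs: Dict[str, str] = {}
    for line in output.splitlines():
        if "=>" not in line:
            continue
        name, _, rest = line.strip().partition("=>")
        libs[name.strip()] = rest.strip().split(" (")[0]
    return libs


def ipex_is_linked(binary: str, lib_name: str = "libintel-ext-pt-cpu.so") -> bool:
    """Run `ldd` on `binary` and report whether the IPEX library was resolved."""
    result = subprocess.run(["ldd", binary], capture_output=True, text=True, check=True)
    return parse_ldd(result.stdout).get(lib_name, "not found") != "not found"
```

Called as ipex_is_linked("./example-app") from the build directory, it returns True when ldd resolves libintel-ext-pt-cpu.so to a path, matching the manual check above.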