Unified (TensorFlow and PyTorch) — coremltools API Reference 8.0b1 documentation

coremltools.converters._converters_entry.convert(model, source='auto', inputs=None, outputs=None, classifier_config=None, minimum_deployment_target=None, convert_to=None, compute_precision=None, skip_model_load=False, compute_units=ComputeUnit.ALL, package_dir=None, debug=False, pass_pipeline: PassPipeline | None = None, states=None)

Convert a TensorFlow or PyTorch model to the Core ML model format as either a neural network or an ML program. Some parameters and requirements differ for TensorFlow and PyTorch conversions.

Parameters:

model

TensorFlow 1, TensorFlow 2, or PyTorch model in one of the following formats:

    * TensorFlow: a `tf.Graph` object, a frozen graph (`.pb`) file path, a `tf.keras.Model` object, an HDF5 (`.h5`) file path, or a SavedModel directory path  
    * PyTorch: a TorchScript object (`torch.jit.ScriptModule` or `torch.jit.TracedModule`), or a path to a `.pt` file  

source: str (optional)

One of [`auto`, `tensorflow`, `pytorch`, `milinternal`]. `auto` determines the framework automatically for most cases. Raises `ValueError` if it fails to determine the source framework.
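In most cases auto-detection suffices, but the source framework can be pinned explicitly. A minimal sketch, assuming a TorchScript model named `traced_model` (the input shape here is a placeholder):

```python
import coremltools as ct

# Pin the source framework instead of relying on auto-detection.
mlmodel = ct.convert(
    traced_model,
    source="pytorch",
    inputs=[ct.TensorType(shape=(1, 3, 224, 224))],
)
```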

inputs: list of `TensorType` or `ImageType`

For example, the following converts a PyTorch model that takes a single grayscale image (`H`: image height, `W`: image width):

```python
mlmodel = ct.convert(
    torch_model,
    inputs=[
        ct.ImageType(shape=(1, 1, H, W), color_layout=ct.colorlayout.GRAYSCALE_FLOAT16)
    ],
    minimum_deployment_target=ct.target.macOS13,
)
```

    * The number of elements in `inputs` must match the number of inputs of the PyTorch model.  
    * `inputs` may be a nested list or tuple.  
    * `TensorType` and `ImageType` must have the `shape` specified.  
    * If the `name` argument is specified with `TensorType` or `ImageType`, the converted Core ML model will have inputs with the same name (see the sketch after this list).  
    * If `dtype` is missing:  
             * For `minimum_deployment_target <= ct.target.macOS12`, it defaults to float 32.  
             * For `minimum_deployment_target >= ct.target.macOS13` with `compute_precision` set to float 16 (the default), it defaults to float 16.  
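A minimal sketch of the `name` and `dtype` options above (the name and shape are illustrative placeholders):

```python
import numpy as np
import coremltools as ct

# Name the input explicitly and request a float 16 input dtype.
# Float 16 I/O requires minimum_deployment_target >= macOS13/iOS16.
inputs = [
    ct.TensorType(name="x", shape=(1, 3, 224, 224), dtype=np.float16),
]
```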

outputs: list of `TensorType` or `ImageType` (optional)

For example, the following produces a model with an RGB image input and an RGB image output (`H`: image height, `W`: image width):

```python
mlmodel = ct.convert(
    torch_model,
    inputs=[ct.ImageType(shape=(1, 3, H, W), color_layout=ct.colorlayout.RGB)],
    outputs=[ct.ImageType(color_layout=ct.colorlayout.RGB)],
    minimum_deployment_target=ct.target.macOS13,
)
```

classifier_config: `ClassifierConfig` class (optional)

The configuration if the MLModel is intended to be a classifier.
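For instance, a minimal sketch of a classifier conversion, assuming a traced image model and illustrative class labels:

```python
import coremltools as ct

# Attach class labels so the converted model acts as a classifier.
config = ct.ClassifierConfig(class_labels=["cat", "dog", "bird"])
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.ImageType(shape=(1, 3, 224, 224))],
    classifier_config=config,
)
```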

minimum_deployment_target: `coremltools.target` enumeration (optional)

A member of the `coremltools.target` enum. The value of this parameter determines the type of the model representation produced by the converter. To learn about the differences between ML programs and neural networks, see ML Programs.

The deployment target must be consistent with `convert_to`: ML programs require iOS15/macOS12 or newer, while neural networks require iOS14/macOS11 or older. For example, the following combinations are invalid:

Invalid:

convert_to="mlprogram", minimum_deployment_target=coremltools.target.iOS14

Invalid:

convert_to="neuralnetwork", minimum_deployment_target=coremltools.target.iOS15

convert_to: str (optional)

Must be one of `['mlprogram', 'neuralnetwork', 'milinternal']`. The value of this parameter determines the type of the model representation produced by the converter. To learn about the differences between ML programs and neural networks, see ML Programs.

compute_precision: `coremltools.precision` enumeration or `ct.transform.FP16ComputePrecision()` (optional)

Use this argument to control the storage precision of the tensors in the ML program. Must be one of the following:

    * `coremltools.precision.FLOAT16` (default): casts the float 32 tensors in the program to float 16. The transform iterates through all the ops, looking at each op's inputs and outputs. If they are of type float 32, `cast` ops are injected to convert those tensors (also known as vars) to type float 16.  
    * `coremltools.precision.FLOAT32`: applies no transform; the program keeps float 32 precision.  
    * `coremltools.transform.FP16ComputePrecision(op_selector)`: applies the float 16 cast transform only to ops for which the user-provided `op_selector` function returns `True`. For example, an `op_selector` that skips `linear` ops casts all the float 32 tensors to float 16, except the input/output tensors of any `linear` op; see the sketch below.  
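A sketch of the `op_selector` variant just described, which keeps `linear` ops in float 32 (the traced model is a placeholder):

```python
import coremltools as ct

# Return False for ops whose inputs/outputs should stay float 32.
def op_selector(op):
    return op.op_type != "linear"

mlmodel = ct.convert(
    traced_model,
    convert_to="mlprogram",
    compute_precision=ct.transform.FP16ComputePrecision(op_selector),
)
```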

skip_model_load: bool

Set to `True` to prevent coremltools from calling into the Core ML framework to compile and load the model post-conversion. In that case, the returned model object cannot be used to make a prediction, but it can be saved with `model.save()`. Without this flag, converting to a newer model type on an older Mac may raise a runtime warning.

Example: Use this flag to suppress a runtime warning when converting to an ML program model on macOS 11, since an ML program can only be compiled and loaded on macOS 12+.

Defaults to `False`.
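A minimal sketch of converting and saving without loading (the model, inputs, and path are placeholders):

```python
import coremltools as ct

# Skip the compile-and-load step so conversion succeeds even on a macOS
# version that cannot load this model type; save the result for later use.
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, 3, 224, 224))],
    convert_to="mlprogram",
    skip_model_load=True,
)
mlmodel.save("MyModel.mlpackage")  # predict() is unavailable on this object
```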

compute_units: `coremltools.ComputeUnit`

The set of processing units the model can use to make predictions. After conversion, the model is loaded with the provided set of compute units and returned.

An enum with the following possible values:

    * `coremltools.ComputeUnit.ALL` (default): use all compute units available, including the Neural Engine.  
    * `coremltools.ComputeUnit.CPU_ONLY`: limit the model to use only the CPU.  
    * `coremltools.ComputeUnit.CPU_AND_GPU`: use both the CPU and GPU, but not the Neural Engine.  
    * `coremltools.ComputeUnit.CPU_AND_NE`: use both the CPU and Neural Engine, but not the GPU.  
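For example, to restrict the loaded model to the CPU (a minimal sketch; the model and inputs are placeholders):

```python
import coremltools as ct

# CPU-only execution can make numerical comparisons more reproducible.
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, 3, 224, 224))],
    compute_units=ct.ComputeUnit.CPU_ONLY,
)
```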

package_dir: str

Post conversion, the model is saved at a temporary location and loaded to form the `MLModel` object ready for prediction. If `package_dir` is provided, the model is saved at this location instead of a temporary directory; the path must end with the `.mlpackage` extension.
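A one-line sketch (the path, model, and inputs are placeholders):

```python
import coremltools as ct

# Save the converted package at an explicit .mlpackage location rather
# than a temporary directory.
mlmodel = ct.convert(traced_model, inputs=inputs, package_dir="./MyModel.mlpackage")
```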

debug: bool

This flag should generally be `False` except for debugging purposes. Setting this flag to `True` produces the following behavior:

    * For PyTorch conversion, it prints the list of supported and unsupported ops found in the model if the conversion fails due to an unsupported op.  
    * For TensorFlow conversion, it displays extra logging and visualizations.  

pass_pipeline: `PassPipeline`

Manage graph passes. You can control which graph passes to run, the order of the graph passes, and the options for each pass. See the details in the docstring of `PassPipeline` (`coremltools/converters/mil/mil/passes/pass_pipeline.py`).

A set of predefined pass pipelines is also provided that you can use directly, as in the sketch below.
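A minimal sketch of both uses, assuming a traced model (treat the specific pass name as an illustrative assumption):

```python
import coremltools as ct

# Customize the default pipeline: skip conv/batch-norm fusion.
pipeline = ct.PassPipeline()
pipeline.remove_passes({"common::fuse_conv_batchnorm"})
mlmodel = ct.convert(traced_model, pass_pipeline=pipeline)

# Or use a predefined pipeline, e.g. run no graph passes at all.
mlmodel = ct.convert(traced_model, pass_pipeline=ct.PassPipeline.EMPTY)
```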

states: list of `StateType` (optional)

Create a stateful `mlprogram` model by providing `StateType` entries in the `states` argument (for details see MIL Input Types). A stateful model is useful when converting a large language model with a KV-cache. The name of each `StateType` must match a key of the PyTorch `named_buffers()` method in the source traced model.

The following example converts a torch model with a buffer called state_1.

```python
import numpy as np
import torch
import coremltools as ct


class UpdateBufferModel(torch.nn.Module):
    def __init__(self):
        super(UpdateBufferModel, self).__init__()
        self.register_buffer(
            "state_1", torch.tensor(np.array([0, 0, 0], dtype=np.float32))
        )

    def forward(self, x):
        # In-place update of the model state
        self.state_1.add_(x)
        return self.state_1


model = UpdateBufferModel()
traced_model = torch.jit.trace(model, torch.tensor([1, 2, 3], dtype=torch.float32))

inputs = [
    ct.TensorType(shape=(3,)),  # matches the shape of the (3,) buffer above
]
states = [
    ct.StateType(
        wrapped_type=ct.TensorType(shape=(3,)),
        name="state_1",  # must match the named_buffers() key
    ),
]
mlmodel = ct.convert(
    traced_model,
    inputs=inputs,
    states=states,
    minimum_deployment_target=ct.target.iOS18,
)
```

Returns:

model: `coremltools.models.MLModel` or `coremltools.converters.mil.Program`

A Core ML `MLModel` object or MIL program object (see `convert_to`).

Examples

TensorFlow 1, 2 (`model` is a frozen graph):

```python
import numpy as np
import tensorflow as tf
import coremltools as ct

with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input")
    y = tf.nn.relu(x, name="output")
```

Automatically infer inputs and outputs:

```python
mlmodel = ct.convert(graph)
test_input = np.random.rand(1, 2, 3) - 0.5
results = mlmodel.predict({"input": test_input})
print(results["output"])
```

TensorFlow 2 (`model` is a `tf.keras` model path):

```python
x = tf.keras.Input(shape=(32,), name="input")
y = tf.keras.layers.Dense(16, activation="softmax")(x)
keras_model = tf.keras.Model(x, y)

keras_model.save(h5_path)  # h5_path: a path string ending in ".h5"
mlmodel = ct.convert(h5_path)

test_input = np.random.rand(2, 32)
results = mlmodel.predict({"input": test_input})
print(results["Identity"])
```

PyTorch:

TorchScript models:

```python
import torch
import torchvision
import coremltools as ct

model = torchvision.models.mobilenet_v2()
model.eval()
example_input = torch.rand(1, 3, 256, 256)
traced_model = torch.jit.trace(model, example_input)

input = ct.TensorType(name="input_name", shape=(1, 3, 256, 256))
mlmodel = ct.convert(traced_model, inputs=[input])
results = mlmodel.predict({"input_name": example_input.numpy()})
print(results["1651"])  # 1651 is the node name given by PyTorch's JIT
```

For more options, see Conversion Options.