tff.learning.models.functional_model_from_keras

Converts a tf.keras.Model to a tff.learning.models.FunctionalModel.

tff.learning.models.functional_model_from_keras(
    keras_model: Union[tf.keras.Model, Callable[[], tf.keras.Model]],
    loss_fn: tf.keras.losses.Loss,
    input_spec: Union[Sequence[Any], Mapping[str, Any]],
    metrics_constructor: Optional[Union[keras_utils.MetricConstructor,
        keras_utils.MetricsConstructor, keras_utils.MetricConstructors]] = None
) -> tff.learning.models.FunctionalModel

Used in the tutorials: Building Your Own Federated Learning Algorithm, Composing Learning Algorithms, High-performance simulations with TFF, Working with tff's ClientData.

This function currently does not support loss functions scaled by sample weights. Keras models with non-None sample weights will fail, because sample weights are not supported during model serialization and deserialization.
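To make the unsupported case concrete, the following plain-Python sketch (no TensorFlow involved; the loss values and weights are made up) contrasts an unweighted batch loss with one scaled by per-example sample weights. Dividing the weighted sum by the batch size is one common convention, assumed here only for illustration. Only the unweighted form corresponds to what this conversion supports.

```python
from typing import Sequence


def mean_loss(per_example_losses: Sequence[float]) -> float:
    """Unweighted batch loss: the plain mean of per-example losses."""
    return sum(per_example_losses) / len(per_example_losses)


def sample_weighted_loss(
    per_example_losses: Sequence[float], sample_weights: Sequence[float]
) -> float:
    """Per-example losses scaled by sample weights, divided by batch size."""
    if len(per_example_losses) != len(sample_weights):
        raise ValueError("Need exactly one sample weight per example.")
    weighted_sum = sum(
        loss * weight for loss, weight in zip(per_example_losses, sample_weights)
    )
    return weighted_sum / len(per_example_losses)


# Hypothetical per-example losses for a batch of three examples.
losses = [0.5, 1.0, 2.0]
# Hypothetical sample weights: the second example is ignored entirely.
weights = [1.0, 0.0, 2.0]

unweighted = mean_loss(losses)  # 3.5 / 3
weighted = sample_weighted_loss(losses, weights)  # 4.5 / 3 = 1.5
```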

Args
keras_model: A tf.keras.Model object, or a zero-argument callable that returns one. The model should be uncompiled; if it is compiled, its metrics, optimizer, and loss function are ignored. Note: for models with multiple outputs, all outputs are passed to loss_fn.
loss_fn: A tf.keras.losses.Loss object.
input_spec: A structure of tf.TensorSpec (a sequence or a mapping) defining the input to the model.
metrics_constructor: An optional callable that must be compatible with tff.learning.metrics.create_functional_metric_fns.
Returns
A tff.learning.models.FunctionalModel.
Raises
KerasFunctionalModelError: If any of the following conditions holds:
1) the Keras model contains a batch normalization layer;
2) the Keras model has non-trainable variables;
3) an error occurs when converting the Keras model;
4) the Keras model shares variables across layers;
5) the FunctionalModel is used outside of a tff.tensorflow.computation-decorated callable or a graph context;
6) the Keras model contains a loss function with non-None sample weights.

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.

Last updated 2024-09-20 UTC.