tff.learning.models.functional_model_from_keras | TensorFlow Federated
tff.learning.models.functional_model_from_keras
Converts a tf.keras.Model to a tff.learning.models.FunctionalModel.
tff.learning.models.functional_model_from_keras(
    keras_model: Union[tf.keras.Model, Callable[[], tf.keras.Model]],
    loss_fn: tf.keras.losses.Loss,
    input_spec: Union[Sequence[Any], Mapping[str, Any]],
    metrics_constructor: Optional[Union[keras_utils.MetricConstructor, keras_utils.MetricsConstructor, keras_utils.MetricConstructors]] = None
) -> tff.learning.models.FunctionalModel
Used in the notebooks
| Used in the tutorials |
|---|
| Building Your Own Federated Learning Algorithm |
| Composing Learning Algorithms |
| High-performance simulations with TFF |
| Working with tff's ClientData |
This method does not currently support loss functions scaled by sample weights. Keras models with non-None sample weights will fail, because sample weights are not supported during model serialization and deserialization.
| Args | Description |
|---|---|
| keras_model | A tf.keras.Model object, or a no-argument callable that returns one. The model should be uncompiled; if it is compiled, its metrics, optimizer, and loss function are ignored. Note: for models with multiple outputs, all outputs are passed to loss_fn. |
| loss_fn | A tf.keras.losses.Loss object. |
| input_spec | A structure of tf.TensorSpec defining the input to the model. |
| metrics_constructor | An optional callable that must be compatible with tff.learning.metrics.create_functional_metric_fns. |
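In practice, `input_spec` describes one batch of client data. A common pattern, sketched below, is to read it from the `element_spec` of a batched `tf.data.Dataset`. This assumes the spec follows the same `(x, y)` convention that `tff.learning.models.from_keras_model` documents (a two-element sequence, or a mapping with keys `'x'` and `'y'`); the synthetic NumPy arrays and shapes are placeholders.

```python
import collections

import numpy as np
import tensorflow as tf

# Placeholder client data: 8 examples of 784 float features, integer labels.
features = np.zeros((8, 784), dtype=np.float32)
labels = np.zeros((8, 1), dtype=np.int32)

# Batching leaves the batch dimension unknown (None) in the element spec.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

# Sequence form: a (x, y) tuple of tf.TensorSpec.
input_spec = dataset.element_spec

# Mapping form: the same specs under explicit 'x' and 'y' keys.
dict_dataset = dataset.map(lambda x, y: collections.OrderedDict(x=x, y=y))
dict_input_spec = dict_dataset.element_spec

print(input_spec)
print(dict_input_spec)
```

Deriving the spec from the dataset keeps shapes and dtypes in sync with the data the model will actually receive.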
| Returns |
|---|
| A tff.learning.models.FunctionalModel. |
| Raises | Description |
|---|---|
| KerasFunctionalModelError | If any of the following conditions holds: 1) the Keras model contains a batch normalization layer; 2) the Keras model has non-trainable variables; 3) an error occurs when converting the Keras model; 4) the Keras model shares variables across layers; 5) the FunctionalModel is used outside of a tff.tensorflow.computation-decorated callable or a graph context; 6) the Keras model contains a loss function with non-None sample weights. |
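Conditions 1 and 2 can be checked before attempting the conversion. The helper below, `find_conversion_blockers`, is hypothetical and not part of TFF. It is a sketch that uses only standard `tf.keras` attributes (`layers`, `non_trainable_variables`). It inspects top-level layers only, so batch normalization inside nested submodels, variable sharing, and the other listed conditions are not detected.

```python
import tensorflow as tf


def find_conversion_blockers(model: tf.keras.Model) -> list[str]:
    """Hypothetical pre-flight check for conditions 1 and 2 above."""
    blockers = []
    # Condition 1: batch normalization layers (top-level layers only).
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            blockers.append(f"batch normalization layer: {layer.name}")
    # Condition 2: any non-trainable variables, e.g. BatchNormalization's
    # moving mean and moving variance.
    if model.non_trainable_variables:
        blockers.append(
            f"{len(model.non_trainable_variables)} non-trainable variable(s)")
    return blockers


# A Dense-only model passes both checks.
dense_model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax(),
])

# Adding BatchNormalization triggers both checks.
bn_model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(10),
    tf.keras.layers.BatchNormalization(),
])

print(find_conversion_blockers(dense_model))
print(find_conversion_blockers(bn_model))
```

A check like this only gives earlier, more readable feedback. The converter itself remains the authority on which models it accepts.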
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2024-09-20 UTC.