tfm.core.train_lib.run_experiment | TensorFlow v2.16.1



Runs training and/or evaluation as configured by the experiment params.

tfm.core.train_lib.run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: tfm.core.base_task.Task,
    mode: str,
    params: tfm.core.base_trainer.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[tfm.core.base_trainer.Trainer] = None,
    controller_cls=orbit.Controller,
    summary_manager: Optional[orbit.utils.SummaryManager] = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
    enable_async_checkpointing: bool = False
) -> Tuple[tf.keras.Model, Mapping[str, Any]]
Args
distribution_strategy: The distribution strategy, a tf.distribute.Strategy instance.
task: A Task instance.
mode: A str specifying the mode. One of 'train', 'eval', 'train_and_eval', or 'continuous_eval'.
params: An ExperimentConfig instance.
model_dir: A str, the path in which to store model checkpoints and summaries.
run_post_eval: Whether to run a final evaluation once after training. If True, its metrics logs are returned.
save_summary: Whether to save train and validation summaries.
train_actions: An optional list of Orbit train actions.
eval_actions: An optional list of Orbit eval actions.
trainer: The base_trainer.Trainer instance. It should be created within strategy.scope().
controller_cls: The controller class that manages the train and eval process. Must be an orbit.Controller subclass.
summary_manager: A summary manager instance that overrides the default summary manager.
eval_summary_manager: An eval summary manager instance that overrides the default eval summary manager.
enable_async_checkpointing: Optional bool indicating whether to enable asynchronous checkpoint saving.
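To make the mode argument concrete, here is a minimal pure-Python sketch of how a mode string could be dispatched to controller methods. It assumes each mode maps onto one of orbit.Controller's train, evaluate, train_and_evaluate, and evaluate_continuously methods; RecordingController, dispatch, and VALID_MODES are hypothetical names, raising ValueError for an unknown mode is this sketch's own choice, and the real function's trainer, checkpoint, and summary setup is omitted entirely. This is an illustration, not the library's implementation.

```python
from typing import Any, Dict, List, Mapping

# The four modes documented for run_experiment.
VALID_MODES = ("train", "eval", "train_and_eval", "continuous_eval")


class RecordingController:
    """Hypothetical stand-in for orbit.Controller that only records calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def train(self, steps: int) -> None:
        self.calls.append(f"train(steps={steps})")

    def evaluate(self, steps: int) -> Dict[str, float]:
        self.calls.append(f"evaluate(steps={steps})")
        return {"placeholder_metric": 0.0}  # Stub value, not a real result.

    def train_and_evaluate(self, train_steps: int, eval_steps: int,
                           eval_interval: int) -> None:
        self.calls.append(
            f"train_and_evaluate(train_steps={train_steps}, "
            f"eval_steps={eval_steps}, eval_interval={eval_interval})")

    def evaluate_continuously(self, steps: int) -> None:
        self.calls.append(f"evaluate_continuously(steps={steps})")


def dispatch(controller: Any, mode: str, train_steps: int, eval_steps: int,
             eval_interval: int, run_post_eval: bool = False
             ) -> Mapping[str, Any]:
    """Hypothetical mode dispatch; returns post-eval logs or {}."""
    if mode not in VALID_MODES:
        raise ValueError(f"Unsupported mode {mode!r}; expected {VALID_MODES}")
    if mode == "train":
        controller.train(steps=train_steps)
    elif mode == "eval":
        controller.evaluate(steps=eval_steps)
    elif mode == "train_and_eval":
        controller.train_and_evaluate(train_steps=train_steps,
                                      eval_steps=eval_steps,
                                      eval_interval=eval_interval)
    else:  # "continuous_eval"
        controller.evaluate_continuously(steps=eval_steps)
    # Mirrors the documented return contract for eval_logs.
    if run_post_eval:
        return controller.evaluate(steps=eval_steps)
    return {}
```

Passing the controller in as an object keeps the sketch independent of Orbit; in the real API the controller_cls hook plays a similar role, which is presumably why it must subclass orbit.Controller.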
Returns
A 2-tuple (model, eval_logs), where model is the tf.keras.Model instance and eval_logs contains the eval metrics logs when run_post_eval is True, or an empty dict ({}) otherwise.
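A caller typically unpacks the returned tuple and checks whether eval_logs is empty before reading metrics. The helper below is a hypothetical sketch of that pattern (describe_result is not part of the library). It assumes every value in eval_logs is a scalar that float() accepts, such as a Python float, a NumPy scalar, or an eager scalar tensor; tasks that log non-scalar values would need different handling.

```python
from typing import Any, Mapping, Tuple


def describe_result(result: Tuple[Any, Mapping[str, Any]]) -> str:
    """Hypothetical helper that summarizes a (model, eval_logs) pair."""
    _model, eval_logs = result
    if not eval_logs:
        # Documented behavior: eval_logs is {} when run_post_eval is False.
        return "no post-training eval logs"
    # Assumes scalar metric values; sorted() gives a stable key order.
    return ", ".join(
        f"{name}={float(value):.4f}" for name, value in sorted(eval_logs.items()))
```

For example, describe_result((model, {})) reports that no post-training evaluation ran, which is what a call with the default run_post_eval=False would return.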

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.

Last updated 2024-02-02 UTC.