Models

The base classes PreTrainedModel, TFPreTrainedModel, and FlaxPreTrainedModel implement the common methods for loading/saving a model either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace’s AWS S3 repository).

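PreTrainedModel and TFPreTrainedModel also implement a few methods which are common among all the models to resize the input token embeddings when new tokens are added to the vocabulary (resize_token_embeddings) and to prune the attention heads of the model (prune_heads).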

The other methods that are common to each model are defined in ModuleUtilsMixin (for the PyTorch models) and modeling_tf_utils.TFModelUtilsMixin (for the TensorFlow models) or, for text generation, in GenerationMixin (for the PyTorch models), TFGenerationMixin (for the TensorFlow models) and FlaxGenerationMixin (for the Flax/JAX models).

PreTrainedModel

class transformers.PreTrainedModel

< source >

( config: PretrainedConfig *inputs **kwargs )

Base class for all models.

PreTrainedModel takes care of storing the configuration of the models and handles methods for loading, downloading and saving models as well as a few methods common to all models to resize the input embeddings and prune heads in the self-attention layers.

Class attributes (overridden by derived classes):

push_to_hub

< source >

( repo_id: str use_temp_dir: typing.Optional[bool] = None commit_message: typing.Optional[str] = None private: typing.Optional[bool] = None token: typing.Union[bool, str, NoneType] = None max_shard_size: typing.Union[int, str, NoneType] = '5GB' create_pr: bool = False safe_serialization: bool = True revision: typing.Optional[str] = None commit_description: typing.Optional[str] = None tags: typing.Optional[list[str]] = None **deprecated_kwargs )

Parameters

Upload the model file to the 🤗 Model Hub.

Examples:

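from transformers import AutoModel

model = AutoModel.from_pretrained("google-bert/bert-base-cased")

# Push the model to your namespace under the name "my-finetuned-bert".
model.push_to_hub("my-finetuned-bert")

# Push the model to an organization under the name "my-finetuned-bert".
model.push_to_hub("huggingface/my-finetuned-bert")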

add_model_tags

< source >

( tags: typing.Union[typing.List[str], str] )

Parameters

Add custom tags into the model that gets pushed to the Hugging Face Hub. Will not overwrite existing tags in the model.

Examples:

from transformers import AutoModel

model = AutoModel.from_pretrained("google-bert/bert-base-cased")

model.add_model_tags(["custom", "custom-bert"])

model.push_to_hub("my-custom-bert")
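The "will not overwrite existing tags" behaviour can be pictured with a small stand-alone sketch. The helper `merge_tags` below is hypothetical and is not part of the library; it only illustrates appending new tags while leaving existing ones in place and skipping duplicates.

```python
from typing import List, Optional, Union


def merge_tags(existing: Optional[List[str]], tags: Union[List[str], str]) -> List[str]:
    """Hypothetical helper: add new tags without dropping or duplicating old ones."""
    merged = list(existing or [])
    if isinstance(tags, str):  # a single tag may be given as a plain string
        tags = [tags]
    for tag in tags:
        if tag not in merged:
            merged.append(tag)
    return merged


current = ["bert"]
current = merge_tags(current, ["custom", "custom-bert"])
current = merge_tags(current, "custom")  # already present, so nothing changes
print(current)
```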

can_generate

< source >

( ) → bool

Whether this model can generate sequences with .generate().

Returns whether this model can generate sequences with .generate() from the GenerationMixin.

Under the hood, on classes where this function returns True, some generation-specific changes are triggered: for instance, the model instance will have a populated generation_config attribute.

Potentially dequantize the model in case it has been quantized by a quantization method that supports dequantization.

Removes the _require_grads_hook.

Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping the model weights fixed.

from_pretrained

< source >

( pretrained_model_name_or_path: typing.Union[str, os.PathLike, NoneType] *model_args config: typing.Union[transformers.configuration_utils.PretrainedConfig, str, os.PathLike, NoneType] = None cache_dir: typing.Union[str, os.PathLike, NoneType] = None ignore_mismatched_sizes: bool = False force_download: bool = False local_files_only: bool = False token: typing.Union[str, bool, NoneType] = None revision: str = 'main' use_safetensors: typing.Optional[bool] = None weights_only: bool = True **kwargs )

Parameters

Parameters for big model inference

Instantiate a pretrained PyTorch model from a pre-trained model configuration.

The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().

The warning Weights from XXX not initialized from pretrained model means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

The warning Weights from XXX not used in YYY means that the layer XXX is not used by YYY, therefore those weights are discarded.

Activate the special “offline-mode” to use this method in a firewalled environment.

Examples:

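from transformers import BertConfig, BertModel

# Download the model and configuration from huggingface.co and cache them.
model = BertModel.from_pretrained("google-bert/bert-base-uncased")

# Load a model saved with save_pretrained("./test/saved_model/") (illustrative path, not runnable as-is).
model = BertModel.from_pretrained("./test/saved_model/")

# Update the configuration during loading.
model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
assert model.config.output_attentions == True

# Load from a TensorFlow checkpoint instead of a PyTorch model (slower; illustrative paths).
config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)

# Load from a Flax checkpoint instead of a PyTorch model (slower).
model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)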
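For the offline mode mentioned above, a minimal sketch follows. It assumes that offline mode is switched on through the HF_HUB_OFFLINE environment variable and that the variable has to be set before transformers is imported; local_files_only is taken from the signature above. The dictionary name `load_kwargs` is purely illustrative.

```python
import os

# Assumption: offline mode is controlled by the HF_HUB_OFFLINE environment
# variable, which should be set before `transformers` is imported.
os.environ["HF_HUB_OFFLINE"] = "1"

# Keyword arguments one could then pass to from_pretrained();
# local_files_only appears in the signature above.
load_kwargs = {"local_files_only": True}
print(os.environ["HF_HUB_OFFLINE"], load_kwargs)
```

With both in place, loading would only succeed for files that are already in the local cache or on disk.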

get_compiled_call

< source >

( compile_config: CompileConfig )

Return a torch.compile'd version of self.__call__. This is useful to dynamically choose between the non-compiled and compiled forward during inference, especially to switch between prefill (where we don’t want to use the compiled version, to avoid recomputing the graph with new shapes) and iterative decoding (where we want the speed-ups of the compiled version with static shapes).

get_input_embeddings

< source >

( ) → nn.Module

A torch module mapping vocabulary to hidden states.

Returns the model’s input embeddings.

get_memory_footprint

< source >

( return_buffers = True )

Parameters

Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
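As a rough illustration of what a footprint "in bytes" adds up, the sketch below sums element count times element size over parameters and, optionally, buffers. It is a pure-Python stand-in under stated assumptions: the tensor names, shapes and byte sizes are made up, and `memory_footprint` here is a hypothetical function, not the library method.

```python
from math import prod

# Hypothetical model description: name -> (shape, bytes per element).
parameters = {
    "embeddings.weight": ((1000, 64), 4),  # float32
    "encoder.weight": ((64, 64), 2),       # float16
}
buffers = {
    "position_ids": ((1, 512), 8),         # int64
}


def memory_footprint(params, bufs, return_buffers=True):
    """Sum element count times element size over parameters (and buffers)."""
    total = sum(prod(shape) * size for shape, size in params.values())
    if return_buffers:
        total += sum(prod(shape) * size for shape, size in bufs.values())
    return total


with_buffers = memory_footprint(parameters, buffers)
without_buffers = memory_footprint(parameters, buffers, return_buffers=False)
print(with_buffers, without_buffers)
```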

get_output_embeddings

< source >

( ) → nn.Module

A torch module mapping hidden states to vocabulary.

Returns the model’s output embeddings.

Return the parameter or buffer given by target if it exists, otherwise throw an error. This combines get_parameter() and get_buffer() in a single handy function. Note that it only works if target is a leaf of the model.

Deactivates gradient checkpointing for the current model.

Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.

gradient_checkpointing_enable

< source >

( gradient_checkpointing_kwargs = None )

Parameters

Activates gradient checkpointing for the current model.

Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.

We pass the __call__ method of the modules instead of forward because __call__ attaches all the hooks of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

If needed prunes and maybe initializes weights. If using a custom PreTrainedModel, you need to implement any initialization logic in _init_weights.

A method executed at the end of each Transformer model initialization, to execute code that needs the model’s modules properly initialized (such as weight initialization).

prune_heads

< source >

( heads_to_prune: typing.Dict[int, typing.List[int]] )

Parameters

Prunes heads of the base model.

register_for_auto_class

< source >

( auto_class = 'AutoModel' )

Parameters

Register this class with a given auto class. This should only be used for custom models as the ones in the library are already mapped with an auto class.

This API is experimental and may have some slight breaking changes in the next releases.

resize_token_embeddings

< source >

( new_num_tokens: typing.Optional[int] = None pad_to_multiple_of: typing.Optional[int] = None mean_resizing: bool = True ) → torch.nn.Embedding

Parameters

Returns

torch.nn.Embedding

Pointer to the input tokens Embeddings Module of the model.

Resizes input token embeddings matrix of the model if new_num_tokens != config.vocab_size.

Takes care of tying the weight embeddings afterwards if the model class has a tie_weights() method.
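The pad_to_multiple_of argument in the signature above suggests rounding the requested vocabulary size up to a multiple. The sketch below shows that arithmetic only; `padded_vocab_size` is a hypothetical helper, and whether the library rounds in exactly this way is an assumption.

```python
from typing import Optional


def padded_vocab_size(new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> int:
    """Round new_num_tokens up to the next multiple of pad_to_multiple_of."""
    if pad_to_multiple_of is None:
        return new_num_tokens
    if pad_to_multiple_of <= 0:
        raise ValueError("pad_to_multiple_of must be a positive integer")
    # Integer ceiling division, then scale back up to a multiple.
    blocks = (new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of
    return blocks * pad_to_multiple_of


print(padded_vocab_size(30524, 64))
```

Sizes that are multiples of a hardware-friendly number (64 in this illustration) are a common choice for the embedding matrix.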

Reverts the transformation from to_bettertransformer() so that the original modeling is used, for example in order to save the model.

save_pretrained

< source >

( save_directory: typing.Union[str, os.PathLike] is_main_process: bool = True state_dict: typing.Optional[dict] = None save_function: typing.Callable = torch.save push_to_hub: bool = False max_shard_size: typing.Union[int, str] = '5GB' safe_serialization: bool = True variant: typing.Optional[str] = None token: typing.Union[str, bool, NoneType] = None save_peft_format: bool = True **kwargs )

Parameters

Save a model and its configuration file to a directory, so that it can be re-loaded using the from_pretrained() class method.
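To give an idea of what a max_shard_size limit implies, here is a sketch of a greedy split of weights into shards. Everything in it is an assumption made for illustration: the helpers `parse_size` and `plan_shards` are hypothetical, sizes use decimal units, and the library's actual sharding logic may differ.

```python
def parse_size(size):
    """Convert an int or a string such as '5GB' or '200MB' to bytes (decimal units)."""
    if isinstance(size, int):
        return size
    units = {"KB": 10**3, "MB": 10**6, "GB": 10**9}
    for suffix, factor in units.items():
        if size.upper().endswith(suffix):
            return int(size[: -len(suffix)]) * factor
    raise ValueError(f"Unrecognized size: {size!r}")


def plan_shards(weight_sizes, max_shard_size="5GB"):
    """Greedily group weights, in order, into shards of at most max_shard_size bytes."""
    limit = parse_size(max_shard_size)
    shards, current, current_size = [], [], 0
    for name, nbytes in weight_sizes.items():
        # Start a new shard when adding this weight would exceed the limit.
        if current and current_size + nbytes > limit:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += nbytes
    if current:
        shards.append(current)
    return shards


sizes = {"a": 400, "b": 500, "c": 300, "d": 2000, "e": 100}  # bytes, hypothetical
plan = plan_shards(sizes, max_shard_size=1000)
print(plan)
```

Note that in this sketch a single weight larger than the limit (here "d") still ends up in a shard of its own, so a shard can exceed the limit.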

set_input_embeddings

< source >

( value: Module )

Parameters

Set model’s input embeddings.

Tie the weights between the input embeddings and the output embeddings.

If the torchscript flag is set in the configuration, TorchScript can’t handle parameter sharing, so the weights are cloned instead.

warn_if_padding_and_no_attention_mask

< source >

( input_ids attention_mask )

Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.
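A check of this kind can be sketched in plain Python. The function `check_padding` below is hypothetical: it assumes that padding shows up as the pad token in the first or last position of a sequence, takes the pad token id as an explicit argument, and uses nested lists in place of tensors.

```python
import warnings

_already_warned = False  # module-level flag so the warning is emitted only once


def check_padding(input_ids, attention_mask, pad_token_id):
    """Return True if the batch looks padded while no attention mask was given."""
    global _already_warned
    if attention_mask is not None or pad_token_id is None:
        return False
    # Heuristic: padding normally sits at the start or the end of a sequence.
    looks_padded = any(row and pad_token_id in (row[0], row[-1]) for row in input_ids)
    if looks_padded and not _already_warned:
        warnings.warn("input_ids appear to contain padding, but no attention_mask was given")
        _already_warned = True
    return looks_padded


batch = [[101, 7592, 102, 0, 0], [101, 2088, 102, 2003, 102]]
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    first = check_padding(batch, None, pad_token_id=0)
    second = check_padding(batch, None, pad_token_id=0)
print(first, second, len(caught))
```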

Custom models should also define a _supports_assign_param_buffer attribute, which determines whether superfast init can apply to the particular model. A sign that your model needs this is a failing test_save_and_load_from_pretrained; if so, set the attribute to False.

ModuleUtilsMixin

class transformers.modeling_utils.ModuleUtilsMixin

< source >

( )

A few utilities for torch.nn.Modules, to be used as a mixin.

Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.

Increase in memory consumption is stored in a mem_rss_diff attribute for each module and can be reset to zero with model.reset_memory_hooks_state().

estimate_tokens

< source >

( input_dict: typing.Dict[str, typing.Union[torch.Tensor, typing.Any]] ) → int

Parameters

The total number of tokens.

Helper function to estimate the total number of tokens from the model inputs.

floating_point_ops

< source >

( input_dict: typing.Dict[str, typing.Union[torch.Tensor, typing.Any]] exclude_embeddings: bool = True ) → int

Parameters

The number of floating-point operations.

Get the number of (optionally, non-embedding) floating-point operations for the forward and backward passes of a batch with this transformer model. The default approximation neglects the quadratic dependency on the number of tokens (valid if 12 * d_model >> sequence_length), as laid out in section 2.1 of this paper. It should be overridden for transformers with parameter re-use, e.g. Albert or Universal Transformers, or when doing long-range modeling with very high sequence lengths.
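An approximation that is linear in the number of tokens is commonly written as 6 * tokens * parameters: roughly 2 operations per parameter per token in the forward pass and about twice that in the backward pass. The sketch below assumes this is the rule the default refers to; `estimate_flops` and the numbers used are illustrative only.

```python
def estimate_flops(num_tokens: int, num_parameters: int) -> int:
    """Forward + backward FLOPs under the 6 * tokens * parameters approximation."""
    return 6 * num_tokens * num_parameters


# Hypothetical batch: 8 sequences of 128 tokens, 85M non-embedding parameters.
tokens = 8 * 128
flops = estimate_flops(tokens, 85_000_000)
print(flops)
```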

get_extended_attention_mask

< source >

( attention_mask: Tensor input_shape: typing.Tuple[int] device: device = None dtype: torch.float32 = None )

Parameters

Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

get_head_mask

< source >

( head_mask: typing.Optional[torch.Tensor] num_hidden_layers: int is_attention_chunked: bool = False )

Parameters

Prepare the head mask if needed.

invert_attention_mask

< source >

( encoder_attention_mask: Tensor ) → torch.Tensor

Parameters

The inverted attention mask.

Invert an attention mask (e.g., switches 0. and 1.).
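One common way to make masked positions drop out of attention is to turn a 1/0 mask into an additive bias: 0.0 where a token is kept and a very large negative number where it is masked, so that a softmax assigns it essentially no weight. The sketch below shows that idea on nested lists. It is an assumption about the general technique, not a copy of the library code; `to_additive_mask` and `MIN_VALUE` are hypothetical names.

```python
# Most negative finite float32 value, used here as the "ignore" bias (assumption).
MIN_VALUE = -3.4028234663852886e38


def to_additive_mask(mask, min_value=MIN_VALUE):
    """Map 1 (keep) to 0.0 and 0 (mask) to min_value, row by row."""
    return [[0.0 if keep else min_value for keep in row] for row in mask]


additive = to_additive_mask([[1, 1, 0], [1, 0, 0]])
print(additive)
```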

num_parameters

< source >

( only_trainable: bool = False exclude_embeddings: bool = False ) → int

Parameters

The number of parameters.

Get number of (optionally, trainable or non-embeddings) parameters in the module.

TFPreTrainedModel

class transformers.TFPreTrainedModel

< source >

( config *inputs **kwargs )

Base class for all TF models.

TFPreTrainedModel takes care of storing the configuration of the models and handles methods for loading, downloading and saving models as well as a few methods common to all models to resize the input embeddings and prune heads in the self-attention layers.

Class attributes (overridden by derived classes):

push_to_hub

< source >

( repo_id: str use_temp_dir: Optional[bool] = None commit_message: Optional[str] = None private: Optional[bool] = None max_shard_size: Optional[Union[int, str]] = '10GB' token: Optional[Union[bool, str]] = None use_auth_token: Optional[Union[bool, str]] = None create_pr: bool = False **base_model_card_args )

Parameters

Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in repo_path_or_name.

Examples:

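from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")

# Push the model to your namespace under the name "my-finetuned-bert".
model.push_to_hub("my-finetuned-bert")

# Push the model to an organization under the name "my-finetuned-bert".
model.push_to_hub("huggingface/my-finetuned-bert")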

can_generate

< source >

( ) → bool

Whether this model can generate sequences with .generate().

Returns whether this model can generate sequences with .generate().

compile

< source >

( optimizer = 'rmsprop' loss = 'auto_with_warning' metrics = None loss_weights = None weighted_metrics = None run_eagerly = None steps_per_execution = None **kwargs )

This is a thin wrapper that sets the model’s loss output head as the loss if the user does not specify a loss function themselves.

create_model_card

< source >

( output_dir model_name: str language: Optional[str] = None license: Optional[str] = None tags: Optional[str] = None finetuned_from: Optional[str] = None tasks: Optional[str] = None dataset_tags: Optional[Union[str, List[str]]] = None dataset: Optional[Union[str, List[str]]] = None dataset_args: Optional[Union[str, List[str]]] = None )

Parameters

Creates a draft of a model card using the information available to the Trainer.

from_pretrained

< source >

( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] *model_args config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None cache_dir: Optional[Union[str, os.PathLike]] = None ignore_mismatched_sizes: bool = False force_download: bool = False local_files_only: bool = False token: Optional[Union[str, bool]] = None revision: str = 'main' use_safetensors: Optional[bool] = None **kwargs )

Parameters

Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

The warning Weights from XXX not initialized from pretrained model means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

The warning Weights from XXX not used in YYY means that the layer XXX is not used by YYY, therefore those weights are discarded.

Examples:

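from transformers import BertConfig, TFBertModel

# Download the model and configuration from huggingface.co and cache them.
model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")

# Load a model saved with save_pretrained("./test/saved_model/") (illustrative path, not runnable as-is).
model = TFBertModel.from_pretrained("./test/saved_model/")

# Update the configuration during loading.
model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
assert model.config.output_attentions == True

# Load from a PyTorch model file instead of a TensorFlow checkpoint (slower; illustrative paths).
config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json")
model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config)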

get_bias

< source >

( ) → tf.Variable

The weights representing the bias, None if not an LM model.

Dict of bias attached to an LM head. The key represents the name of the bias attribute.

get_head_mask

< source >

( head_mask: tf.Tensor | None num_hidden_layers: int )

Parameters

Prepare the head mask if needed.

get_input_embeddings

< source >

( ) → tf.Variable

The embeddings layer mapping vocabulary to hidden states.

Returns the model’s input embeddings layer.

get_lm_head

< source >

( ) → keras.layers.Layer

Returns

keras.layers.Layer

The LM head layer if the model has one, None if not.

The LM Head layer. This method must be overwritten by all the models that have a lm head.

get_output_embeddings

< source >

( ) → tf.Variable

The new weights mapping vocabulary to hidden states.

Returns the model’s output embeddings

get_output_layer_with_bias

< source >

( ) → keras.layers.Layer

Returns

keras.layers.Layer

The layer that handles the bias, None if not an LM model.

Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the embeddings

get_prefix_bias_name

< source >

( ) → str

The _prefix name of the bias.

Get the concatenated _prefix name of the bias from the model name to the parent layer

prepare_tf_dataset

< source >

( dataset: 'datasets.Dataset' batch_size: int = 8 shuffle: bool = True tokenizer: Optional['PreTrainedTokenizerBase'] = None collate_fn: Optional[Callable] = None collate_fn_args: Optional[Dict[str, Any]] = None drop_remainder: Optional[bool] = None prefetch: bool = True ) → Dataset

Parameters

A tf.data.Dataset which is ready to pass to the Keras API.

Wraps a HuggingFace Dataset as a tf.data.Dataset with collation and batching. This method is designed to create a “ready-to-use” dataset that can be passed directly to Keras methods like fit() without further modification. The method will drop columns from the dataset if they don’t match input names for the model. If you want to specify the column names to return rather than using the names that match this model, we recommend using Dataset.to_tf_dataset() instead.

prune_heads

< source >

( heads_to_prune )

Parameters

Prunes heads of the base model.

register_for_auto_class

< source >

( auto_class = 'TFAutoModel' )

Parameters

Register this class with a given auto class. This should only be used for custom models as the ones in the library are already mapped with an auto class.

This API is experimental and may have some slight breaking changes in the next releases.

resize_token_embeddings

< source >

( new_num_tokens: Optional[int] = None ) → tf.Variable or keras.layers.Embedding

Parameters

Returns

tf.Variable or keras.layers.Embedding

Pointer to the input tokens of the model.

Resizes input token embeddings matrix of the model if new_num_tokens != config.vocab_size.

Takes care of tying the weight embeddings afterwards if the model class has a tie_weights() method.

save_pretrained

< source >

( save_directory saved_model = False version = 1 push_to_hub = False signatures = None max_shard_size: Union[int, str] = '5GB' create_pr: bool = False safe_serialization: bool = False token: Optional[Union[str, bool]] = None **kwargs )

Parameters

Save a model and its configuration file to a directory, so that it can be re-loaded using the from_pretrained() class method.

serving

( inputs )

Parameters

Prepare the output of the saved model. Can be overridden if specific serving modifications are required.

set_bias

< source >

( value )

Parameters

Set all the bias in the LM head.

set_input_embeddings

< source >

( value )

Parameters

Set model’s input embeddings

set_output_embeddings

< source >

( value )

Parameters

Set model’s output embeddings

A modification of Keras’s default train_step that correctly handles matching outputs to labels for our models and supports directly training on the loss output head. In addition, it ensures input keys are copied to the labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure that they are available to the model during the forward pass.

TFModelUtilsMixin

class transformers.modeling_tf_utils.TFModelUtilsMixin

< source >

( )

A few utilities for keras.Model, to be used as a mixin.

num_parameters

< source >

( only_trainable: bool = False ) → int

Parameters

The number of parameters.

Get the number of (optionally, trainable) parameters in the model.

FlaxPreTrainedModel

class transformers.FlaxPreTrainedModel

< source >

( config: PretrainedConfig module: Module input_shape: typing.Tuple = (1, 1) seed: int = 0 dtype: dtype = <class 'jax.numpy.float32'> _do_init: bool = True )

Base class for all models.

FlaxPreTrainedModel takes care of storing the configuration of the models and handles methods for loading, downloading and saving models.

Class attributes (overridden by derived classes):

push_to_hub

< source >

( repo_id: str use_temp_dir: typing.Optional[bool] = None commit_message: typing.Optional[str] = None private: typing.Optional[bool] = None token: typing.Union[bool, str, NoneType] = None max_shard_size: typing.Union[int, str, NoneType] = '5GB' create_pr: bool = False safe_serialization: bool = True revision: typing.Optional[str] = None commit_description: typing.Optional[str] = None tags: typing.Optional[list[str]] = None **deprecated_kwargs )

Parameters

Upload the model checkpoint to the 🤗 Model Hub.

Examples:

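from transformers import FlaxAutoModel

model = FlaxAutoModel.from_pretrained("google-bert/bert-base-cased")

# Push the model to your namespace under the name "my-finetuned-bert".
model.push_to_hub("my-finetuned-bert")

# Push the model to an organization under the name "my-finetuned-bert".
model.push_to_hub("huggingface/my-finetuned-bert")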

Returns whether this model can generate sequences with .generate().

Returns

bool

Whether this model can generate sequences with .generate().

from_pretrained

< source >

( pretrained_model_name_or_path: typing.Union[str, os.PathLike] dtype: dtype = <class 'jax.numpy.float32'> *model_args config: typing.Union[transformers.configuration_utils.PretrainedConfig, str, os.PathLike, NoneType] = None cache_dir: typing.Union[str, os.PathLike, NoneType] = None ignore_mismatched_sizes: bool = False force_download: bool = False local_files_only: bool = False token: typing.Union[str, bool, NoneType] = None revision: str = 'main' **kwargs )

Parameters

Instantiate a pretrained flax model from a pre-trained model configuration.

The warning Weights from XXX not initialized from pretrained model means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

The warning Weights from XXX not used in YYY means that the layer XXX is not used by YYY, therefore those weights are discarded.

Examples:

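from transformers import BertConfig, FlaxBertModel

# Download the model and configuration from huggingface.co and cache them.
model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")

# Load a model saved with save_pretrained("./test/saved_model/") (illustrative path, not runnable as-is).
model = FlaxBertModel.from_pretrained("./test/saved_model/")

# Load from a PyTorch checkpoint instead of a Flax one (slower; illustrative paths).
config = BertConfig.from_json_file("./pt_model/config.json")
model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)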

load_flax_sharded_weights

< source >

( shard_files ) → Dict

Parameters

A nested dictionary of the model parameters, in the expected format for flax models : {'model': {'params': {'...'}}}.

This is the same as flax.serialization.from_bytes (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.

This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being loaded in the model.

register_for_auto_class

< source >

( auto_class = 'FlaxAutoModel' )

Parameters

Register this class with a given auto class. This should only be used for custom models as the ones in the library are already mapped with an auto class.

This API is experimental and may have some slight breaking changes in the next releases.

save_pretrained

< source >

( save_directory: typing.Union[str, os.PathLike] params = None push_to_hub = False max_shard_size = '10GB' token: typing.Union[str, bool, NoneType] = None safe_serialization: bool = False **kwargs )

Parameters

Save a model and its configuration file to a directory, so that it can be re-loaded using the from_pretrained() class method.

to_bf16

< source >

( params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict] mask: typing.Any = None )

Parameters

Cast the floating-point params to jax.numpy.bfloat16. This returns a new params tree and does not cast the params in place.

This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

Examples:

from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")

model.params = model.to_bf16(model.params)

from flax import traverse_util

model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
flat_params = traverse_util.flatten_dict(model.params)
# Cast everything except the LayerNorm bias and scale parameters.
mask = {
    path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_bf16(model.params, mask)
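To see how a mask like the one in the example above selects parameters, the sketch below builds it on a toy parameter tree. The `flatten_dict` function is a pure-Python stand-in written for this illustration (it only mimics the idea of flax.traverse_util.flatten_dict), and the tree contents are made up.

```python
def flatten_dict(tree, prefix=()):
    """Stand-in for a flatten helper: nested dict -> {path tuple: leaf}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_dict(value, path))
        else:
            flat[path] = value
    return flat


# Toy parameter tree (hypothetical names and values).
params = {
    "encoder": {
        "dense": {"kernel": 1.0, "bias": 0.0},
        "LayerNorm": {"scale": 1.0, "bias": 0.0},
    }
}

flat_params = flatten_dict(params)
# True means "cast this leaf"; LayerNorm scale and bias are left untouched.
mask = {
    path: path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")
    for path in flat_params
}
print(mask)
```

The last two path components identify the owning layer and the parameter name, which is why the comparison uses a two-element slice.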

to_fp16

< source >

( params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict] mask: typing.Any = None )

Parameters

Cast the floating-point params to jax.numpy.float16. This returns a new params tree and does not cast the params in place.

This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

Examples:

from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")

model.params = model.to_fp16(model.params)

from flax import traverse_util

model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
flat_params = traverse_util.flatten_dict(model.params)
# Cast everything except the LayerNorm bias and scale parameters.
mask = {
    path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    for path in flat_params
}
mask = traverse_util.unflatten_dict(mask)
model.params = model.to_fp16(model.params, mask)

to_fp32

< source >

( params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict] mask: typing.Any = None )

Parameters

Cast the floating-point params to jax.numpy.float32. This method can be used to explicitly convert the model parameters to fp32 precision. This returns a new params tree and does not cast the params in place.

Examples:

from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")

model.params = model.to_fp16(model.params)

model.params = model.to_fp32(model.params)

Pushing to the Hub

class transformers.utils.PushToHubMixin

< source >

( )

A Mixin containing the functionality to push a model or tokenizer to the hub.

push_to_hub

< source >

( repo_id: str use_temp_dir: typing.Optional[bool] = None commit_message: typing.Optional[str] = None private: typing.Optional[bool] = None token: typing.Union[bool, str, NoneType] = None max_shard_size: typing.Union[int, str, NoneType] = '5GB' create_pr: bool = False safe_serialization: bool = True revision: typing.Optional[str] = None commit_description: typing.Optional[str] = None tags: typing.Optional[list[str]] = None **deprecated_kwargs )

Parameters

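Upload the model or tokenizer files to the 🤗 Model Hub.

Examples:

from transformers import AutoModel

model = AutoModel.from_pretrained("google-bert/bert-base-cased")

# Push the model to your namespace under the name "my-finetuned-bert".
model.push_to_hub("my-finetuned-bert")

# Push the model to an organization under the name "my-finetuned-bert".
model.push_to_hub("huggingface/my-finetuned-bert")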

Sharded checkpoints

transformers.modeling_utils.load_sharded_checkpoint

< source >

( model folder strict = True prefer_safe = True ) → NamedTuple

Parameters

A named tuple with missing_keys and unexpected_keys fields

This is the same as torch.nn.Module.load_state_dict but for a sharded checkpoint.

This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being loaded in the model.
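The shard-by-shard pattern described above can be sketched with plain dictionaries standing in for state dicts. The function `load_sharded` and the `LoadResult` tuple are hypothetical names; the sketch assumes each shard is produced lazily (here by a generator) so that only one shard is held in memory at a time, and it reports key mismatches only after all shards were read.

```python
from collections import namedtuple

LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])


def load_sharded(model_state, shards, strict=True):
    """Copy shard contents into model_state one shard at a time."""
    loaded, unexpected = set(), []
    for shard in shards:  # in practice each shard would be read from disk here
        for key, value in shard.items():
            if key in model_state:
                model_state[key] = value
                loaded.add(key)
            else:
                unexpected.append(key)
        del shard  # drop the reference before the next shard is produced
    missing = [key for key in model_state if key not in loaded]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    return LoadResult(missing, unexpected)


def shard_reader():
    """Hypothetical lazy reader yielding one shard (a dict) at a time."""
    yield {"a": 1}
    yield {"b": 2, "x": 9}


state = {"a": None, "b": None, "c": None}
result = load_sharded(state, shard_reader(), strict=False)
print(result, state)
```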
