Rework how PreTrainedModel.from_pretrained handles its arguments by anlsh · Pull Request #866 · huggingface/transformers

The unification of the from_pretrained functions belonging to various modules (GPT2PreTrainedModel, OpenAIGPTPreTrainedModel, BertPreTrainedModel) brought changes to the function's argument handling which don't cause any issues within the repository itself (afaik), but which have the potential to break a variety of downstream code (e.g. my own).

In the previous release (pytorch-pretrained-bert v0.6.2), the from_pretrained functions took *args and **kwargs and passed them directly to the relevant model's constructor (perhaps with some processing along the way). For a typical example, see from_pretrained's signature in modeling.py here https://github.com/huggingface/pytorch-transformers/blob/b832d5bb8a6dfc5965015b828e577677eace601e/pytorch_pretrained_bert/modeling.py#L526

and the relevant usage of said arguments (after some small modifications) https://github.com/huggingface/pytorch-transformers/blob/b832d5bb8a6dfc5965015b828e577677eace601e/pytorch_pretrained_bert/modeling.py#L600
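
Roughly, the old flow looked like the following toy sketch (a paraphrase for illustration only, not the actual v0.6.2 implementation; the config loading here is faked):

class OldStylePreTrainedModel:
    def __init__(self, config, *inputs, **kwargs):
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        kwargs.pop("state_dict", None)   # consumed by the loading machinery
        kwargs.pop("cache_dir", None)    # consumed by the loading machinery
        config = {"name": pretrained_model_name_or_path}  # stand-in for the real config load
        # Any leftover *inputs / **kwargs reach the (sub)class constructor unchanged:
        return cls(config, *inputs, **kwargs)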

In the latest release, the function's signature remains unchanged, but the *args and most of the **kwargs parameters, in particular pretty much anything not explicitly accessed in
https://github.com/huggingface/pytorch-transformers/blob/b33a385091de604afb566155ec03329b84c96926/pytorch_transformers/modeling_utils.py#L354-L358

are ignored. If a key of kwargs is shared with the relevant model's configuration file, then its value is still used to override said key (see the relevant logic here), but the current architecture breaks, for example, the following pattern, which was previously possible:

class UsefulSubclass(BertForSequenceClassification):
    def __init__(self, *args, useful_argument, **kwargs):
        super().__init__(*args, **kwargs)
        # ... logic using useful_argument ...

...
bert = UsefulSubclass.from_pretrained(model_name, useful_argument=42)

What's more, if these arguments have default values declared in __init__, then the entire pattern is broken silently, because those default values will never be overridden via pretrained instantiation. End users might therefore continue running experiments passing different values of useful_argument to from_pretrained, unaware that nothing is actually being changed.
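
To make the silent failure concrete, here is a minimal sketch assuming the current (1.0) behaviour described above; the default value of 7 is made up for illustration:

from pytorch_transformers import BertForSequenceClassification

class SilentSubclass(BertForSequenceClassification):
    def __init__(self, *args, useful_argument=7, **kwargs):
        super().__init__(*args, **kwargs)
        self.useful_argument = useful_argument

# useful_argument is not a key of the BERT config, so under the current
# behaviour it never reaches __init__ and the default of 7 is silently used.
bert = SilentSubclass.from_pretrained("bert-base-uncased", useful_argument=42)
assert bert.useful_argument == 7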

As evidenced by issue #833, I'm not the only one whose code was broken. This commit implements a compromise between the old and new behaviors. From my docstring:

If config is None, then **kwargs will be passed to the model.
If config is *not* None, then kwargs will be used to
override any keys shared with the default configuration for the
given pretrained_model_name_or_path, and only the unshared
key/value pairs will be passed to the model.
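
In code terms, the split can be pictured with a small helper along these lines (the helper name is hypothetical and only illustrates the docstring above; it is not the literal diff in this PR):

def split_config_and_model_kwargs(config, **kwargs):
    """Override config attributes for shared keys; return the rest for the model."""
    model_kwargs = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)   # shared key: overrides the configuration
        else:
            model_kwargs[key] = value     # unshared key: forwarded to the model's __init__
    return model_kwargs

With such a split, UsefulSubclass.from_pretrained(model_name, useful_argument=42) would forward useful_argument to __init__ again, since it is not a configuration key.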

It would actually be ideal to avoid mixing configuration and model parameters entirely (via some sort of model_args parameter, for example); however, this fix has the advantages of

  1. Not breaking code written during the pytorch-pretrained-bert era
  2. Preserving (to the extent possible) the usage of from_pretrained's **kwargs parameter introduced with pytorch-transformers

I have also included various other (smaller) changes in this pull request: