From 4e2c1f3a4da7436a861dd444035e6e27a3f1c6b7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 8 Sep 2022 17:46:03 +0200 Subject: [PATCH] Add config docs (#429) * advance * finish * finish --- docs/source/api/configuration.mdx | 21 +++----- src/diffusers/__init__.py | 1 + src/diffusers/configuration_utils.py | 79 +++++++++++++++++++++++++--- src/diffusers/modeling_utils.py | 12 ++--- 4 files changed, 83 insertions(+), 30 deletions(-) diff --git a/docs/source/api/configuration.mdx b/docs/source/api/configuration.mdx index 5c435dc8..45176f55 100644 --- a/docs/source/api/configuration.mdx +++ b/docs/source/api/configuration.mdx @@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Models +# Configuration -Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models. -The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$. -The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub. +In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are +passed to the respective `__init__` methods in a JSON-configuration file. -## API +TODO(PVP) - add example and better info here -Models should provide the `def forward` function and initialization of the model. -All saving, loading, and utilities should be in the base ['ModelMixin'] class. - -## Examples - -- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3. -- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991). -- TODO: mention VAE / SDE score estimation \ No newline at end of file +## ConfigMixin +[[autodoc]] ConfigMixin + - from_config + - save_config diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 215bfdf3..5d7015e5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,6 +9,7 @@ from .utils import ( __version__ = "0.3.0.dev0" +from .configuration_utils import ConfigMixin from .modeling_utils import ModelMixin from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from .onnx_utils import OnnxRuntimeModel diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 053ccd64..fbe75f3f 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -37,9 +37,16 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json") class ConfigMixin: r""" - Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as - methods for loading/downloading/saving configurations. + Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overriden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overriden by parent class). """ config_name = None ignore_for_config = [] @@ -74,8 +81,6 @@ class ConfigMixin: Args: save_directory (`str` or `os.PathLike`): Directory where the configuration JSON file will be saved (will be created if it does not exist). - kwargs: - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") @@ -90,6 +95,64 @@ class ConfigMixin: @classmethod def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model. + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) @@ -298,10 +361,10 @@ class FrozenDict(OrderedDict): def register_to_config(init): - """ - Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically - sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be - registered in the config, use the `ignore_for_config` class variable + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! """ diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 39e326de..fb613614 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -119,8 +119,6 @@ class ModelMixin(torch.nn.Module): [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading and saving models. - Class attributes: - - **config_name** ([`str`]) -- A filename under which the model should be stored when calling [`~modeling_utils.ModelMixin.save_pretrained`]. """ @@ -200,10 +198,9 @@ class ModelMixin(torch.nn.Module): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a - user or organization name, like `dbmdz/bert-base-german-cased`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], - e.g., `./my_model_directory/`. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the @@ -236,9 +233,6 @@ class ModelMixin(torch.nn.Module): problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the [`ConfigMixin`] of the model (after it being loaded). - Passing `use_auth_token=True`` is required when you want to use a private model.