parent
5e6417e988
commit
4e2c1f3a4d
|
@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
|||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Models
|
||||
# Configuration
|
||||
|
||||
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
|
||||
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
|
||||
The models are built on the base class [`ModelMixin`], which is a `torch.nn.Module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
|
||||
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
|
||||
passed to the respective `__init__` methods in a JSON-configuration file.
|
||||
|
||||
## API
|
||||
TODO(PVP) - add example and better info here
|
||||
|
||||
Models should provide the `def forward` function and initialization of the model.
|
||||
All saving, loading, and utilities should be in the base [`ModelMixin`] class.
|
||||
|
||||
## Examples
|
||||
|
||||
- The [`UNetModel`] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
|
||||
- Extensions of the [`UNetModel`] include the [`UNetGlideModel`] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the [`UNetGradTTS`] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, [`UNetLDMModel`] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the [`TemporalUNet`] used for time-series prediction in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
|
||||
- TODO: mention VAE / SDE score estimation
|
||||
## ConfigMixin
|
||||
[[autodoc]] ConfigMixin
|
||||
- from_config
|
||||
- save_config
|
||||
|
|
|
@ -9,6 +9,7 @@ from .utils import (
|
|||
|
||||
__version__ = "0.3.0.dev0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
|
|
|
@ -37,9 +37,16 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json")
|
|||
|
||||
class ConfigMixin:
|
||||
r"""
|
||||
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
||||
methods for loading/downloading/saving configurations.
|
||||
Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
|
||||
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
|
||||
- [`~ConfigMixin.from_config`]
|
||||
- [`~ConfigMixin.save_config`]
|
||||
|
||||
Class attributes:
|
||||
- **config_name** (`str`) -- A filename under which the config should be stored when calling
|
||||
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
overridden by parent class).
|
||||
"""
|
||||
config_name = None
|
||||
ignore_for_config = []
|
||||
|
@ -74,8 +81,6 @@ class ConfigMixin:
|
|||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
@ -90,6 +95,64 @@ class ConfigMixin:
|
|||
|
||||
@classmethod
|
||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||
r"""
|
||||
Instantiate a Python class from a pre-defined JSON-file.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
|
||||
organization name, like `google/ddpm-celebahq-256`.
|
||||
- A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
|
||||
`./my_model_directory/`.
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
|
||||
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
|
||||
checkpoint with 3 labels).
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `use_auth_token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
||||
use this method in a firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||
|
@ -298,10 +361,10 @@ class FrozenDict(OrderedDict):
|
|||
|
||||
|
||||
def register_to_config(init):
|
||||
"""
|
||||
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
|
||||
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
|
||||
registered in the config, use the `ignore_for_config` class variable
|
||||
r"""
|
||||
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
||||
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
|
||||
shouldn't be registered in the config, use the `ignore_for_config` class variable
|
||||
|
||||
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
|
||||
"""
|
||||
|
|
|
@ -119,8 +119,6 @@ class ModelMixin(torch.nn.Module):
|
|||
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
||||
and saving models.
|
||||
|
||||
Class attributes:
|
||||
|
||||
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
|
||||
[`~modeling_utils.ModelMixin.save_pretrained`].
|
||||
"""
|
||||
|
@ -200,10 +198,9 @@ class ModelMixin(torch.nn.Module):
|
|||
Can be either:
|
||||
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
||||
user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
|
||||
e.g., `./my_model_directory/`.
|
||||
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
||||
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
||||
`./my_model_directory/`.
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
|
@ -236,9 +233,6 @@ class ModelMixin(torch.nn.Module):
|
|||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the [`ConfigMixin`] of the model (after it being loaded).
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `use_auth_token=True` is required when you want to use a private model.
|
||||
|
|
Loading…
Reference in New Issue