[Variant] Add "variant" as input kwarg so to have better UX when downloading no_ema or fp16 weights (#2305)

* [Variant] Add variant loading mechanism

* clean

* improve further

* up

* add tests

* add some first tests

* up

* up

* use path splittetx

* add deprecate

* deprecation warnings

* improve docs

* up

* up

* up

* fix tests

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* correct code format

* fix warning

* finish

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update docs/source/en/using-diffusers/loading.mdx

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review

Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* correct loading docs

* finish

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Patrick von Platen 2023-02-16 12:02:58 +02:00 committed by GitHub
parent e3ddbe25ed
commit e5810e686e
7 changed files with 771 additions and 104 deletions


@ -23,31 +23,50 @@ In the following we explain in-detail how to easily load:
## Loading pipelines
The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256).
The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [Runway's Stable Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5).
```python
from diffusers import DiffusionPipeline
repo_id = "CompVis/ldm-text2im-large-256"
ldm = DiffusionPipeline.from_pretrained(repo_id)
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id)
```
Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`LDMTextToImagePipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `ldm`.
The pipeline instance can then be called using [`LDMTextToImagePipeline.__call__`] (i.e., `ldm("image of a astronaut riding a horse")`) for text-to-image generation.
Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`StableDiffusionPipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `pipe`.
The pipeline instance can then be called using [`StableDiffusionPipeline.__call__`] (i.e., `pipe("image of an astronaut riding a horse")`) for text-to-image generation.
Instead of using the generic [`DiffusionPipeline`] class for loading, you can also load the appropriate pipeline class directly. The code snippet above yields the same instance as when doing:
```python
from diffusers import LDMTextToImagePipeline
from diffusers import StableDiffusionPipeline
repo_id = "CompVis/ldm-text2im-large-256"
ldm = LDMTextToImagePipeline.from_pretrained(repo_id)
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(repo_id)
```
Diffusion pipelines like `LDMTextToImagePipeline` often consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vqvae"` and "bert", tokenizers or schedulers. These components can interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`LDMTextToImagePipeline`] or [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
<Tip>
Many checkpoints, such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) can be used for multiple tasks, *e.g.* *text-to-image* or *image-to-image*.
If you want to use those checkpoints for a task that is different from the default one, you have to load them directly from the corresponding task-specific pipeline class:
```python
from diffusers import StableDiffusionImg2ImgPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
```
</Tip>
Diffusion pipelines like `StableDiffusionPipeline` or `StableDiffusionImg2ImgPipeline` consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vae"` and `"text_encoder"`, tokenizers or schedulers.
These components often interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
The purpose of the [pipeline classes](./api/overview#diffusers-summary) is to wrap the complexity of these diffusion systems and give the user an easy-to-use API while staying flexible for customization, as will be shown later.
### Loading pipelines that require access request
<!---
THE FOLLOWING CAN BE UNCOMMENTED ONCE WE HAVE NEW MODELS WITH ACCESS REQUIREMENT
# Loading pipelines that require access request
Due to the capabilities of diffusion models to generate extremely realistic images, there is a certain danger that such models might be misused for unwanted applications, *e.g.* generating pornography or violent images.
In order to minimize the possibility of such unsolicited use cases, some of the most powerful diffusion models require users to acknowledge a license before being able to use the model. If the user does not agree to the license, the pipeline cannot be downloaded.
@ -94,6 +113,7 @@ stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_auth_token="<y
```
The final option to use pipelines that require access without having to rely on the Hugging Face Hub is to load the pipeline locally as explained in the next section.
-->
### Loading pipelines locally
@ -101,9 +121,9 @@ If you prefer to have complete control over the pipeline and its corresponding f
we recommend loading pipelines locally.
To load a diffusion pipeline locally, you first need to manually download the whole folder structure on your local disk and then pass a local path to the [`DiffusionPipeline.from_pretrained`]. Let's again look at an example for
[CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256).
[Runway's Stable Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5).
First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main):
First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main):
```
git lfs install
@ -178,105 +198,324 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
Note how the above code snippet makes use of [`DiffusionPipeline.components`].
### Loading variants
Diffusion Pipeline checkpoints can offer variants of the "main" diffusion pipeline checkpoint.
Such checkpoint variants are usually variations of the checkpoint that have advantages for specific use-cases and that are so similar to the "main" checkpoint that they **should not** be put in a new checkpoint.
A variation of a checkpoint has to have **exactly** the same serialization format and **exactly** the same model structure, including all weights having the same tensor shapes.
Examples of variations are different floating point types and non-EMA weights, *e.g.* "fp16", "bf16", and "non_ema" are common variations.
#### Let's first talk about what's **not** a checkpoint variant
Checkpoint variants do **not** include different serialization formats (such as [safetensors](https://huggingface.co/docs/diffusers/main/en/using-diffusers/using_safetensors)) as weights in different serialization formats are
identical to the weights of the "main" checkpoint, just loaded in a different framework.
Also variants do not correspond to different model structures, *e.g.* [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) is not a variant of [stable-diffusion-2-0](https://huggingface.co/stabilityai/stable-diffusion-2) since the model structure is different (Stable Diffusion 1-5 uses a different `CLIPTextModel` compared to Stable Diffusion 2.0).
Pipeline checkpoints that are identical in model structure, but have been trained on different datasets or with vastly different training setups, and thus correspond to different official releases (such as [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)), should probably be stored in individual repositories instead of as variations of each other.
#### So what are checkpoint variants then?
Checkpoint variants usually consist of the checkpoint stored in a "*low-precision, low-storage*" dtype so that less bandwidth is required to download them, or of *non-exponential-averaged* weights that should be used when continuing fine-tuning from the checkpoint.
Both use cases have clear advantages when their weights are considered variants: they share the same serialization format as the reference weights, and they correspond to a specialization of the "main" checkpoint which does not warrant a new model repository.
A checkpoint stored in [torch's half-precision / float16 format](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) requires only half the bandwidth and storage when downloading the checkpoint,
**but** cannot be used when continuing training or when running the checkpoint on CPU.
Similarly the *non-exponential-averaged* (or non-EMA) version of the checkpoint should be used when continuing fine-tuning of the model checkpoint, **but** should not be used when using the checkpoint for inference.
#### How to save and load variants
Saving a diffusion pipeline as a variant can be done by providing [`DiffusionPipeline.save_pretrained`] with the `variant` argument.
The `variant` extends the weight name by the provided variation, by changing the default weight name from `diffusion_pytorch_model.bin` to `diffusion_pytorch_model.{variant}.bin` or from `diffusion_pytorch_model.safetensors` to `diffusion_pytorch_model.{variant}.safetensors`. By doing so, one creates a variant of the pipeline checkpoint that can be loaded **instead** of the "main" pipeline checkpoint.
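The renaming rule described above can be sketched in a few lines. This is a standalone illustration that mirrors the `_add_variant` helper added in this commit; the public name `add_variant` is only used here for the example and is not part of the `diffusers` API:

```python
from typing import Optional


def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """Insert `variant` right before the file extension of `weights_name`."""
    if variant is None:
        # No variant requested: keep the "main" weight name untouched.
        return weights_name
    stem, extension = weights_name.rsplit(".", 1)
    return f"{stem}.{variant}.{extension}"


for name in ("diffusion_pytorch_model.bin", "diffusion_pytorch_model.safetensors"):
    print(add_variant(name, "fp16"))
```

The variant is always placed between the base name and the extension, so the serialization format (`.bin` or `.safetensors`) stays recognizable from the file name.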
Let's have a look at how we could create a float16 variant of a pipeline. First, we load
the "main" variant of a checkpoint (stored in `float32` precision) in half-precision format, using `torch_dtype=torch.float16`.
```py
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
Now all model components of the pipeline are stored in half-precision dtype. We can now save the
pipeline under a `"fp16"` variant as follows:
```py
pipe.save_pretrained("./stable-diffusion-v1-5", variant="fp16")
```
If we don't save into an existing `stable-diffusion-v1-5` folder the new folder would look as follows:
```
stable-diffusion-v1-5
├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json
│   └── pytorch_model.fp16.bin
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── pytorch_model.fp16.bin
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   └── diffusion_pytorch_model.fp16.bin
└── vae
    ├── config.json
    └── diffusion_pytorch_model.fp16.bin
```
As one can see, all model files now have a `.fp16.bin` extension instead of just `.bin`.
The variant now has to be loaded by also passing a `variant="fp16"` to [`DiffusionPipeline.from_pretrained`], e.g.:
```py
DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16)
```
works just fine, while:
```py
DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", torch_dtype=torch.float16)
```
throws an Exception:
```
OSError: Error no file named diffusion_pytorch_model.bin found in directory ./stable-diffusion-v1-5/vae.
```
This is expected as we don't have any "non-variant" checkpoint files saved locally.
However, the whole idea of pipeline variants is that they can co-exist with the "main" variant,
so one would typically also save the "main" variant in the same folder. Let's do this:
```py
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.save_pretrained("./stable-diffusion-v1-5")
```
and upload the pipeline to the Hub under [diffusers/stable-diffusion-variants](https://huggingface.co/diffusers/stable-diffusion-variants).
The file structure [on the Hub](https://huggingface.co/diffusers/stable-diffusion-variants/tree/main) now looks as follows:
```
├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json
│   ├── pytorch_model.bin
│   └── pytorch_model.fp16.bin
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   ├── pytorch_model.bin
│   └── pytorch_model.fp16.bin
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   ├── diffusion_pytorch_model.bin
│   └── diffusion_pytorch_model.fp16.bin
└── vae
    ├── config.json
    ├── diffusion_pytorch_model.bin
    └── diffusion_pytorch_model.fp16.bin
```
We can now download both the "main" and the "fp16" variant from the Hub. Both:
```py
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants")
```
and
```py
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="fp16")
```
work.
<Tip>
Note that Diffusers never downloads more checkpoints than needed. E.g. when downloading
the "main" variant, none of the "fp16.bin" files are downloaded and cached.
Only when the user specifies `variant="fp16"` are those files downloaded and cached.
</Tip>
Finally, there are cases where only some of the checkpoint files of the pipeline are of a certain
variation. E.g. it's usually only the UNet checkpoint that has both an *exponential-mean-averaged* (EMA) and a *non-exponential-mean-averaged* (non-EMA) version. All other model components, e.g. the text encoder, safety checker or variational auto-encoder, usually don't have such a variation.
In such a case, one would upload just the UNet's checkpoint file with a `non_ema` version format (as done [here](https://huggingface.co/diffusers/stable-diffusion-variants/blob/main/unet/diffusion_pytorch_model.non_ema.bin)) and upon calling:
```python
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="non_ema")
```
the model will use only the "non_ema" checkpoint variant if it is available - otherwise it'll load the
"main" variation. In the above example, `variant="non_ema"` would therefore download the following file structure:
```
├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json
│   └── pytorch_model.bin
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── pytorch_model.bin
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   └── diffusion_pytorch_model.non_ema.bin
└── vae
    ├── config.json
    └── diffusion_pytorch_model.bin
```
In a nutshell, using `variant="{variant}"` will download all files that match the `{variant}`, and if such a file variant is not present for a model component, it will download the "main" variant instead. If neither a "main" nor a `{variant}` variant is available, an error will be thrown.
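The fallback rule can be illustrated with a small, self-contained sketch. It is not the `diffusers` implementation: the function `pick_weight_files` and the in-memory file listing are hypothetical, and only `.bin` files are considered to keep the example short:

```python
from typing import Dict, List, Optional


def pick_weight_files(repo_files: Dict[str, List[str]], variant: Optional[str] = None) -> Dict[str, str]:
    """For each component folder, prefer the variant weights and fall back to "main"."""
    chosen = {}
    for component, filenames in repo_files.items():
        weights = [f for f in filenames if f.endswith(".bin")]
        if not weights:
            # Components without weights (e.g. scheduler, tokenizer) are skipped.
            continue
        variant_files = [f for f in weights if variant is not None and f.endswith(f".{variant}.bin")]
        # "main" weights have no variant infix, i.e. exactly one dot in the name.
        main_files = [f for f in weights if f.count(".") == 1]
        if variant_files:
            chosen[component] = variant_files[0]
        elif main_files:
            chosen[component] = main_files[0]
        else:
            raise OSError(f"No 'main' or '{variant}' weights found for component '{component}'.")
    return chosen


repo = {
    "unet": ["config.json", "diffusion_pytorch_model.bin", "diffusion_pytorch_model.non_ema.bin"],
    "vae": ["config.json", "diffusion_pytorch_model.bin"],
    "scheduler": ["scheduler_config.json"],
}
print(pick_weight_files(repo, variant="non_ema"))
```

With `variant="non_ema"`, the UNet resolves to its `non_ema` file while the VAE falls back to its "main" file, matching the file structure shown above.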
### How does loading work?
As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
- Download the latest version of the folder structure required to run the `repo_id` with `diffusers` and cache them. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] will simply reuse the cache and **not** re-download the files.
- Load the cached weights into the _correct_ pipeline class - one of the [officially supported pipeline classes](./api/overview#diffusers-summary) - and return an instance of the class. The _correct_ pipeline class is retrieved from the `model_index.json` file.
The underlying folder structure of diffusion pipelines correspond 1-to-1 to their corresponding class instances, *e.g.* [`LDMTextToImagePipeline`] for [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256)
This can be understood better by looking at an example. Let's print out pipeline class instance `pipeline` we just defined:
The underlying folder structure of diffusion pipelines corresponds 1-to-1 to their corresponding class instances, *e.g.* [`StableDiffusionPipeline`] for [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5).
This can be better understood by looking at an example. Let's load a pipeline class instance `pipe` and print it:
```python
from diffusers import DiffusionPipeline
repo_id = "CompVis/ldm-text2im-large-256"
ldm = DiffusionPipeline.from_pretrained(repo_id)
print(ldm)
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id)
print(pipe)
```
*Output*:
```
LDMTextToImagePipeline {
"bert": [
"latent_diffusion",
"LDMBertModel"
StableDiffusionPipeline {
"feature_extractor": [
"transformers",
"CLIPFeatureExtractor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"DDIMScheduler"
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"BertTokenizer"
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vqvae": [
"vae": [
"diffusers",
"AutoencoderKL"
]
}
```
First, we see that the official pipeline is the [`LDMTextToImagePipeline`], and second we see that the `LDMTextToImagePipeline` consists of 5 components:
- `"bert"` of class `LDMBertModel` as defined [in the pipeline](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L664)
- `"scheduler"` of class [`DDIMScheduler`]
- `"tokenizer"` of class `BertTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
- `"unet"` of class [`UNet2DConditionModel`]
- `"vqvae"` of class [`AutoencoderKL`]
First, we see that the official pipeline is the [`StableDiffusionPipeline`], and second we see that the `StableDiffusionPipeline` consists of 7 components:
- `"feature_extractor"` of class `CLIPFeatureExtractor` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPFeatureExtractor).
- `"safety_checker"` as defined [here](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32).
- `"scheduler"` of class [`PNDMScheduler`].
- `"text_encoder"` of class `CLIPTextModel` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel).
- `"tokenizer"` of class `CLIPTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer).
- `"unet"` of class [`UNet2DConditionModel`].
- `"vae"` of class [`AutoencoderKL`].
Let's now compare the pipeline instance to the folder structure of the model repository `CompVis/ldm-text2im-large-256`. Looking at the folder structure of [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main) on the Hub, we can see it matches 1-to-1 the printed out instance of `LDMTextToImagePipeline` above:
Let's now compare the pipeline instance to the folder structure of the model repository `runwayml/stable-diffusion-v1-5`. Looking at the folder structure of [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) on the Hub and excluding model and saving format variants, we can see it matches 1-to-1 the printed out instance of `StableDiffusionPipeline` above:
```
.
├── bert
├── feature_extractor
│   └── preprocessor_config.json
├── model_index.json
├── safety_checker
│   ├── config.json
│   └── pytorch_model.bin
├── model_index.json
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── pytorch_model.bin
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.txt
│   └── vocab.json
├── unet
│   ├── config.json
│   └── diffusion_pytorch_model.bin
└── vqvae
│   ── diffusion_pytorch_model.bin
└── vae
├── config.json
└── diffusion_pytorch_model.bin
── diffusion_pytorch_model.bin
```
As we can see each attribute of the instance of `LDMTextToImagePipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"bert"`, `"scheduler"`, `"tokenizer"`, `"unet"`, `"vqvae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both:
Each attribute of the instance of `StableDiffusionPipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"feature_extractor"`, `"safety_checker"`, `"scheduler"`, `"text_encoder"`, `"tokenizer"`, `"unet"`, `"vae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both:
- which pipeline class should be loaded, and
- what sub-classes from which library are stored in which subfolders
In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefore defined as follows:
In the case of `runwayml/stable-diffusion-v1-5` the `model_index.json` is therefore defined as follows:
```
{
"_class_name": "LDMTextToImagePipeline",
"_diffusers_version": "0.0.4",
"bert": [
"latent_diffusion",
"LDMBertModel"
"_class_name": "StableDiffusionPipeline",
"_diffusers_version": "0.6.0",
"feature_extractor": [
"transformers",
"CLIPFeatureExtractor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"DDIMScheduler"
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"BertTokenizer"
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vqvae": [
"vae": [
"diffusers",
"AutoencoderKL"
]
@ -292,10 +531,36 @@ In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefo
"class"
]
```
- The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42)
- The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42)
- The `"library"` field corresponds to the name of the library, *e.g.* `diffusers` or `transformers` from which the `"class"` should be loaded
- The `"class"` field corresponds to the name of the class, *e.g.* [`BertTokenizer`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) or [`UNet2DConditionModel`]
- The `"class"` field corresponds to the name of the class, *e.g.* [`CLIPTokenizer`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer) or [`UNet2DConditionModel`]
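To make the `"name": ["library", "class"]` layout concrete, here is a minimal sketch that parses a shortened `model_index.json` with the standard `json` module. The helper `list_components` is hypothetical and only meant for illustration; it treats every key starting with an underscore as metadata rather than as a component:

```python
import json

# Shortened copy of the `model_index.json` shown above.
model_index = json.loads(
    """
    {
      "_class_name": "StableDiffusionPipeline",
      "_diffusers_version": "0.6.0",
      "scheduler": ["diffusers", "PNDMScheduler"],
      "tokenizer": ["transformers", "CLIPTokenizer"],
      "unet": ["diffusers", "UNet2DConditionModel"]
    }
    """
)


def list_components(index: dict) -> dict:
    """Map each subfolder/attribute name to its (library, class) pair."""
    return {name: tuple(value) for name, value in index.items() if not name.startswith("_")}


components = list_components(model_index)
for name, (library, class_name) in components.items():
    print(f"{name}: load `{class_name}` from `{library}`")
```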
<!--
TODO(Patrick) - Make sure to uncomment this part as soon as things are deprecated.
#### Using `revision` to load pipeline variants is deprecated
Previously the `revision` argument of [`DiffusionPipeline.from_pretrained`] was heavily used to
load model variants, e.g.:
```python
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16")
```
However, this behavior is now deprecated since the "revision" argument should (just as it's done in GitHub) better be used to load model checkpoints from a specific commit or branch in development.
The above example is therefore deprecated and won't be supported anymore for `diffusers >= 1.0.0`.
<Tip warning={true}>
If you load diffusers pipelines or models with `revision="fp16"` or `revision="non_ema"`,
please make sure to update your code and use `variant="fp16"` or `variant="non_ema"` respectively
instead.
</Tip>
-->
## Loading models
@ -310,19 +575,19 @@ Let's look at an example:
```python
from diffusers import UNet2DConditionModel
repo_id = "CompVis/ldm-text2im-large-256"
repo_id = "runwayml/stable-diffusion-v1-5"
model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
```
Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/unet).
Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet).
As explained in [Loading customized pipelines]("./using-diffusers/loading#loading-customized-pipelines"), one can pass a loaded model to a diffusion pipeline, via [`DiffusionPipeline.from_pretrained`]:
```python
from diffusers import DiffusionPipeline
repo_id = "CompVis/ldm-text2im-large-256"
ldm = DiffusionPipeline.from_pretrained(repo_id, unet=model)
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id, unet=model)
```
If the model files can be found directly at the root level, which is usually only the case for some very simple diffusion models, such as [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32), we don't
@ -335,6 +600,18 @@ repo_id = "google/ddpm-cifar10-32"
model = UNet2DModel.from_pretrained(repo_id)
```
As motivated in [How to save and load variants?](#how-to-save-and-load-variants), models can load and
save variants. To load a model variant, one should pass the `variant` function argument to [`ModelMixin.from_pretrained`]. Analogously, to save a model variant, one should pass the `variant` function argument to [`ModelMixin.save_pretrained`]:
```python
from diffusers import UNet2DConditionModel
model = UNet2DConditionModel.from_pretrained(
"diffusers/stable-diffusion-variants", subfolder="unet", variant="non_ema"
)
model.save_pretrained("./local-unet", variant="non_ema")
```
## Loading schedulers
Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.


@ -16,18 +16,21 @@
import inspect
import os
import warnings
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from packaging import version
from requests import HTTPError
from torch import Tensor, device
from .. import __version__
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
@ -89,12 +92,12 @@ def get_parameter_dtype(parameter: torch.nn.Module):
return first_tuple[1].dtype
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
try:
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
return torch.load(checkpoint_file, map_location="cpu")
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
@ -141,6 +144,15 @@ def _load_state_dict_into_model(model_to_load, state_dict):
return error_msgs
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
@ -250,6 +262,7 @@ class ModelMixin(torch.nn.Module):
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = False,
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
@ -268,6 +281,8 @@ class ModelMixin(torch.nn.Module):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
@ -292,6 +307,7 @@ class ModelMixin(torch.nn.Module):
state_dict = model_to_save.state_dict()
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
@ -371,6 +387,9 @@ class ModelMixin(torch.nn.Module):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
<Tip>
@ -401,6 +420,7 @@ class ModelMixin(torch.nn.Module):
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
@ -488,7 +508,7 @@ class ModelMixin(torch.nn.Module):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
@ -504,7 +524,7 @@ class ModelMixin(torch.nn.Module):
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
@ -538,7 +558,7 @@ class ModelMixin(torch.nn.Module):
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file)
state_dict = load_state_dict(model_file, variant=variant)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
@ -587,7 +607,7 @@ class ModelMixin(torch.nn.Module):
)
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file)
state_dict = load_state_dict(model_file, variant=variant)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
@ -800,8 +820,38 @@ def _get_model_file(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
# 1. First check if deprecated way of loading from branches is used
if (
revision in DEPRECATED_REVISION_ARGS
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
and version.parse(version.parse(__version__).base_version) >= version.parse("0.15.0")
):
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=_add_variant(weights_name, revision),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
return model_file
except: # noqa: E722
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
FutureWarning,
)
try:
# Load from URL or cache if already cached
# 2. Load model file as usual
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,

View File

@ -17,6 +17,8 @@
import importlib
import inspect
import os
import re
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
@ -31,15 +33,16 @@ from tqdm.auto import tqdm
import diffusers
from .. import __version__
from ..configuration_utils import ConfigMixin
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
deprecate,
@ -56,6 +59,11 @@ from ..utils import (
if is_transformers_available():
import transformers
from transformers import PreTrainedModel
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
INDEX_FILE = "diffusion_pytorch_model.bin"
@ -120,15 +128,16 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray
def is_safetensors_compatible(info) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
def is_safetensors_compatible(filenames, variant=None) -> bool:
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames:
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin":
if raw == f"pytorch_model{_variant}.bin":
# transformers specific
sf_filename = os.path.join(prefix, "model.safetensors")
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames:
@ -137,6 +146,41 @@ def is_safetensors_compatible(info) -> bool:
return is_safetensors_compatible
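The check above can be read as: safetensors are only preferred when every `.bin` file has a safetensors counterpart. The following is a self-contained sketch of that idea, not the library's code. The function name `safetensors_complete_sketch` is hypothetical, and `posixpath` is used on the assumption that Hub file names always use `/` as separator.

```python
import posixpath
from typing import Iterable, Optional


def safetensors_complete_sketch(filenames: Iterable[str], variant: Optional[str] = None) -> bool:
    """Return True only if every `.bin` file has a matching `.safetensors` file."""
    filenames = set(filenames)
    if not any(f.endswith(".safetensors") for f in filenames):
        return False
    for pt_filename in (f for f in filenames if f.endswith(".bin")):
        suffix = f".{variant}" if (variant is not None and variant in pt_filename) else ""
        prefix, raw = posixpath.split(pt_filename)
        if raw == f"pytorch_model{suffix}.bin":
            # transformers-style checkpoints are called `model.safetensors`
            sf_filename = posixpath.join(prefix, f"model{suffix}.safetensors")
        else:
            sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
        if sf_filename not in filenames:
            return False
    return True


complete = [
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model.safetensors",
    "text_encoder/pytorch_model.bin",
    "text_encoder/model.safetensors",
]
print(safetensors_complete_sketch(complete))       # -> True
print(safetensors_complete_sketch(complete[:-1]))  # -> False
```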
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
filenames = set(sibling.rfilename for sibling in info.siblings)
weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# pytorch_model, diffusion_pytorch_model, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
variant_file_regex = (
re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
if variant is not None
else None
)
non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
if variant is not None:
variant_filenames = set(f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None)
else:
variant_filenames = set()
non_variant_filenames = set(f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None)
usable_filenames = set(variant_filenames)
for f in non_variant_filenames:
variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
if variant_filename not in usable_filenames:
usable_filenames.add(f)
return usable_filenames, variant_filenames
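The selection rule in `variant_compatible_siblings` is easier to see without the regular expressions: a variant file is used where one exists, and the non-variant file is the fallback otherwise. Below is a simplified sketch under stated assumptions. `select_files_sketch` and the two-entry `WEIGHT_NAMES` tuple are illustrative only; the real function also covers Flax, ONNX and `transformers` file names.

```python
from typing import Iterable, Optional, Set, Tuple

# Assumed weight file names, for illustration only.
WEIGHT_NAMES = ("diffusion_pytorch_model.bin", "diffusion_pytorch_model.safetensors")


def select_files_sketch(filenames: Iterable[str], variant: Optional[str] = None) -> Tuple[Set[str], Set[str]]:
    filenames = set(filenames)
    non_variant = {f for f in filenames if f.split("/")[-1] in WEIGHT_NAMES}
    variant_files: Set[str] = set()
    if variant is not None:
        for f in filenames:
            parts = f.split("/")[-1].split(".")
            # A variant file looks like <stem>.<variant>.<extension>.
            if len(parts) == 3 and parts[1] == variant and f"{parts[0]}.{parts[2]}" in WEIGHT_NAMES:
                variant_files.add(f)
    usable = set(variant_files)
    for f in non_variant:
        stem, extension = f.rsplit(".", 1)
        # Fall back to the non-variant file only if no variant file replaces it.
        if f"{stem}.{variant}.{extension}" not in usable:
            usable.add(f)
    return usable, variant_files


files = [
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model.fp16.bin",
    "vae/diffusion_pytorch_model.bin",
    "unet/config.json",
]
usable, variants = select_files_sketch(files, variant="fp16")
print(sorted(usable))
# -> ['unet/diffusion_pytorch_model.fp16.bin', 'vae/diffusion_pytorch_model.bin']
```

Here the `vae` folder has no `fp16` file, so its default weights are kept. This is the "mixture of variant and non-variant filenames" case that the pipeline warns about.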
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
@ -194,6 +238,7 @@ class DiffusionPipeline(ConfigMixin):
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
variant: Optional[str] = None,
):
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipeline's configuration file to
@ -205,6 +250,8 @@ class DiffusionPipeline(ConfigMixin):
Directory to which to save. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
self.save_config(save_directory)
@ -246,12 +293,15 @@ class DiffusionPipeline(ConfigMixin):
# Call the save method with the argument safe_serialization only if it's supported
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
save_method(
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
)
else:
save_method(os.path.join(save_directory, pipeline_component_name))
save_kwargs["safe_serialization"] = safe_serialization
if save_method_accept_variant:
save_kwargs["variant"] = variant
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
@ -403,6 +453,9 @@ class DiffusionPipeline(ConfigMixin):
Can be used to overwrite loadable and saveable variables, *i.e.* the pipeline components, of the
specific pipeline class. The overwritten components are then passed directly to the pipeline's
`__init__` method. See the example below for more information.
variant (`str`, *optional*):
If specified, load weights from the `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
<Tip>
@ -454,6 +507,7 @@ class DiffusionPipeline(ConfigMixin):
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
return_cached_folder = kwargs.pop("return_cached_folder", False)
variant = kwargs.pop("variant", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@ -468,28 +522,87 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
)
# make sure we only download sub-folders and `diffusers` filenames
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [
WEIGHTS_NAME,
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
ONNX_WEIGHTS_NAME,
cls.config_name,
]
# make sure we don't download flax weights
ignore_patterns = ["*.msgpack"]
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
if not local_files_only:
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant)
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.10.0"):
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=None,
)
comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision)
comp_model_filenames = [
".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames
]
if set(comp_model_filenames) == set(model_filenames):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant='{revision}'`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
else:
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
FutureWarning,
)
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
# also allow downloading config.jsons with the model
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors"]
allow_patterns += [
FLAX_WEIGHTS_NAME,
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}]\nIf this behavior is not expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}]\nIf this behavior is not expected, please check your folder structure."
)
else:
# allow everything since it has to be downloaded anyway
ignore_patterns = allow_patterns = None
if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__
@ -501,21 +614,6 @@ class DiffusionPipeline(ConfigMixin):
user_agent = http_user_agent(user_agent)
if is_safetensors_available() and not local_files_only:
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
else:
# as a safety mechanism we also don't download safetensors if
# not all safetensors files are there
ignore_patterns.append("*.safetensors")
else:
ignore_patterns.append("*.safetensors")
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@ -533,6 +631,16 @@ class DiffusionPipeline(ConfigMixin):
cached_folder = pretrained_model_name_or_path
config_dict = cls.load_config(cached_folder)
# retrieve which subfolders should load variants
model_variants = {}
if variant is not None:
for folder in os.listdir(cached_folder):
folder_path = os.path.join(cached_folder, folder)
is_folder = os.path.isdir(folder_path) and folder in config_dict
variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path) if "." in path)
if variant_exists:
model_variants[folder] = variant
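To illustrate the loop above, here is a small runnable sketch that builds a throwaway folder layout and detects which component folders contain a variant file. The helper name `find_model_variants_sketch` and the folder contents are made up for the example. Unlike the original loop, it skips file names with fewer than three dot-separated parts instead of indexing into them directly.

```python
import os
import tempfile
from typing import Dict, Iterable


def find_model_variants_sketch(root: str, component_names: Iterable[str], variant: str) -> Dict[str, str]:
    """Map each component folder that holds a `<stem>.<variant>.<ext>` file to the variant."""
    component_names = set(component_names)
    model_variants: Dict[str, str] = {}
    for folder in sorted(os.listdir(root)):
        folder_path = os.path.join(root, folder)
        if not (os.path.isdir(folder_path) and folder in component_names):
            continue
        for name in os.listdir(folder_path):
            parts = name.split(".")
            if len(parts) > 2 and parts[1] == variant:
                model_variants[folder] = variant
                break
    return model_variants


with tempfile.TemporaryDirectory() as root:
    layout = [("unet", "diffusion_pytorch_model.fp16.bin"), ("vae", "diffusion_pytorch_model.bin")]
    for folder, filename in layout:
        os.makedirs(os.path.join(root, folder))
        with open(os.path.join(root, folder, filename), "w"):
            pass  # an empty placeholder file is enough for the name check
    found = find_model_variants_sketch(root, {"unet", "vae"}, "fp16")

print(found)
# -> {'unet': 'fp16'}
```

Components that do not appear in the result are later loaded without a `variant` argument, which is how a pipeline can mix variant and default weights.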
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if custom_pipeline is not None:
@ -717,10 +825,11 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
and transformers_version >= version.parse("4.20.0")
)
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
@ -728,9 +837,23 @@ class DiffusionPipeline(ConfigMixin):
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["variant"] = model_variants.pop(name, None)
if from_flax:
loading_kwargs["from_flax"] = True
# the following can be deleted once the minimum required `transformers` version
# is higher than 4.27
if (
is_transformers_model
and loading_kwargs["variant"] is not None
and transformers_version < version.parse("4.27.0")
):
raise ImportError(
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
)
elif is_transformers_model and loading_kwargs["variant"] is None:
loading_kwargs.pop("variant")
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
if not (from_flax and is_transformers_model):
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

View File

@ -20,6 +20,7 @@ from packaging import version
from .. import __version__
from .constants import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,

View File

@ -30,3 +30,4 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
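For callers, the practical consequence of `DEPRECATED_REVISION_ARGS` is that a weight flavor previously requested as a branch name should now be passed as `variant`. The sketch below shows one way a caller-side shim could translate old arguments. `resolve_revision_sketch` is a hypothetical helper and does not reproduce what the library does internally.

```python
import warnings
from typing import Optional, Tuple

DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]


def resolve_revision_sketch(
    revision: Optional[str], variant: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
    """Return a (revision, variant) pair, moving a deprecated revision into `variant`."""
    if variant is None and revision in DEPRECATED_REVISION_ARGS:
        warnings.warn(
            f"`revision='{revision}'` is deprecated for selecting weights; "
            f"use `variant='{revision}'` instead.",
            FutureWarning,
        )
        # Fall back to the default branch and select the files by variant.
        return None, revision
    return revision, variant


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    migrated = resolve_revision_sketch("fp16")

print(migrated)
# -> (None, 'fp16')
print(resolve_revision_sketch("main"))
# -> ('main', None)
```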

View File

@ -16,10 +16,12 @@
import inspect
import tempfile
import unittest
import unittest.mock as mock
from typing import Dict, List, Tuple
import numpy as np
import torch
from requests.exceptions import HTTPError
from diffusers.models import ModelMixin, UNet2DConditionModel
from diffusers.training_utils import EMAModel
@ -34,6 +36,30 @@ class ModelUtilsTest(unittest.TestCase):
# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
def test_cached_files_are_used_when_no_internet(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
orig_model = UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
)
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("requests.request", return_value=response_mock):
# Load the model again; with the Hub unreachable it must come from the local cache.
model = UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
)
for p1, p2 in zip(orig_model.parameters(), model.parameters()):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
class ModelTesterMixin:
def test_from_save_pretrained(self):
@ -66,6 +92,44 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_variant(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16")
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
# non-variant cannot be loaded
with self.assertRaises(OSError) as error_context:
self.model_class.from_pretrained(tmpdirname)
# make sure that the error message states which file is missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)
new_model.to(torch_device)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
_ = model(**self.dummy_input)
_ = new_model(**self.dummy_input)
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.sample
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_dtype(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@ -21,6 +21,7 @@ import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
import numpy as np
import PIL
@ -28,6 +29,7 @@ import safetensors.torch
import torch
from parameterized import parameterized
from PIL import Image
from requests.exceptions import HTTPError
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@ -166,6 +168,155 @@ class DownloadTests(unittest.TestCase):
assert np.max(np.abs(out - out_2)) < 1e-3
def test_cached_files_are_used_when_no_internet(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
orig_pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("requests.request", return_value=response_mock):
# Load the pipeline again; with the Hub unreachable it must come from the local cache.
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, local_files_only=True
)
comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}
for m1, m2 in zip(orig_comps.values(), comps.values()):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
def test_download_from_variant_folder(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
)
all_root_files = [
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
assert not any(f.endswith(other_format) for f in files)
# no variants
assert not any(len(f.split(".")) == 3 for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_variant_all(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
this_format = ".safetensors" if safe_avail else ".bin"
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
)
all_root_files = [
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# unet, vae, text_encoder, safety_checker
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 4
# all checkpoints should have variant ending
assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
assert not any(f.endswith(other_format) for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_variant_partly(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
other_format = ".bin" if safe_avail else ".safetensors"
this_format = ".safetensors" if safe_avail else ".bin"
variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
)
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist]
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet"))
# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files)
diffusers.utils.import_utils._safetensors_available = True
def test_download_broken_variant(self):
for safe_avail in [False, True]:
import diffusers
diffusers.utils.import_utils._safetensors_available = safe_avail
# text encoder is missing both the non-variant and the "no_ema" variant weights, so the following can't work
for variant in [None, "no_ema"]:
with self.assertRaises(OSError) as error_context:
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-broken-variants",
cache_dir=tmpdirname,
variant=variant,
)
assert "Error no file name" in str(error_context.exception)
# text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname:
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
)
assert pipe is not None
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
all_root_files = [t[-1] for t in os.walk(snapshots)]
files = [item for sublist in all_root_files for item in sublist]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
diffusers.utils.import_utils._safetensors_available = True
class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):