[Variant] Add "variant" as input kwarg so to have better UX when downloading no_ema or fp16 weights (#2305)
* [Variant] Add variant loading mechanism * clean * improve further * up * add tests * add some first tests * up * up * use path splittetx * add deprecate * deprecation warnings * improve docs * up * up * up * fix tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * correct code format * fix warning * finish * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update docs/source/en/using-diffusers/loading.mdx Co-authored-by: Suraj Patil <surajp815@gmail.com> * Apply suggestions from code review Co-authored-by: Will Berman <wlbberman@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com> * correct loading docs * finish --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Will Berman <wlbberman@gmail.com>
This commit is contained in:
parent
e3ddbe25ed
commit
e5810e686e
|
@ -23,31 +23,50 @@ In the following we explain in-detail how to easily load:
|
|||
|
||||
## Loading pipelines
|
||||
|
||||
The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256).
|
||||
The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [Runway's Stable Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5).
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = DiffusionPipeline.from_pretrained(repo_id)
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`LDMTextToImagePipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `ldm`.
|
||||
The pipeline instance can then be called using [`LDMTextToImagePipeline.__call__`] (i.e., `ldm("image of a astronaut riding a horse")`) for text-to-image generation.
|
||||
Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`StableDiffusionPipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `pipe`.
|
||||
The pipeline instance can then be called using [`StableDiffusionPipeline.__call__`] (i.e., `pipe("image of a astronaut riding a horse")`) for text-to-image generation.
|
||||
|
||||
Instead of using the generic [`DiffusionPipeline`] class for loading, you can also load the appropriate pipeline class directly. The code snippet above yields the same instance as when doing:
|
||||
|
||||
```python
|
||||
from diffusers import LDMTextToImagePipeline
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = LDMTextToImagePipeline.from_pretrained(repo_id)
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
Diffusion pipelines like `LDMTextToImagePipeline` often consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vqvae"` and "bert", tokenizers or schedulers. These components can interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`LDMTextToImagePipeline`] or [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
|
||||
<Tip>
|
||||
|
||||
Many checkpoints, such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) can be used for multiple tasks, *e.g.* *text-to-image* or *image-to-image*.
|
||||
If you want to use those checkpoints for a task that is different from the default one, you have to load it directly from the corresponding task-specific pipeline class:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
Diffusion pipelines like `StableDiffusionPipeline` or `StableDiffusionImg2ImgPipeline` consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vae"` and `"text_encoder"`, tokenizers or schedulers.
|
||||
These components often interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
|
||||
The purpose of the [pipeline classes](./api/overview#diffusers-summary) is to wrap the complexity of these diffusion systems and give the user an easy-to-use API while staying flexible for customization, as will be shown later.
|
||||
|
||||
### Loading pipelines that require access request
|
||||
<!---
|
||||
THE FOLLOWING CAN BE UNCOMMENTED ONCE WE HAVE NEW MODELS WITH ACCESS REQUIREMENT
|
||||
|
||||
# Loading pipelines that require access request
|
||||
|
||||
Due to the capabilities of diffusion models to generate extremely realistic images, there is a certain danger that such models might be misused for unwanted applications, *e.g.* generating pornography or violent images.
|
||||
In order to minimize the possibility of such unsolicited use cases, some of the most powerful diffusion models require users to acknowledge a license before being able to use the model. If the user does not agree to the license, the pipeline cannot be downloaded.
|
||||
|
@ -94,6 +113,7 @@ stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_auth_token="<y
|
|||
```
|
||||
|
||||
The final option to use pipelines that require access without having to rely on the Hugging Face Hub is to load the pipeline locally as explained in the next section.
|
||||
-->
|
||||
|
||||
### Loading pipelines locally
|
||||
|
||||
|
@ -101,9 +121,9 @@ If you prefer to have complete control over the pipeline and its corresponding f
|
|||
we recommend loading pipelines locally.
|
||||
|
||||
To load a diffusion pipeline locally, you first need to manually download the whole folder structure on your local disk and then pass a local path to the [`DiffusionPipeline.from_pretrained`]. Let's again look at an example for
|
||||
[CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256).
|
||||
[Runway's Stable Diffusion Diffusion model](https://huggingface.co/runwayml/stable-diffusion-v1-5).
|
||||
|
||||
First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main):
|
||||
First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main):
|
||||
|
||||
```
|
||||
git lfs install
|
||||
|
@ -178,105 +198,324 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
|
|||
|
||||
Note how the above code snippet makes use of [`DiffusionPipeline.components`].
|
||||
|
||||
### Loading variants
|
||||
|
||||
Diffusion Pipeline checkpoints can offer variants of the "main" diffusion pipeline checkpoint.
|
||||
Such checkpoint variants are usually variations of the checkpoint that have advantages for specific use-cases and that are so similar to the "main" checkpoint that they **should not** be put in a new checkpoint.
|
||||
A variation of a checkpoint has to have **exactly** the same serialization format and **exactly** the same model structure, including all weights having the same tensor shapes.
|
||||
|
||||
Examples of variations are different floating point types and non-ema weights. I.e. "fp16", "bf16", and "no_ema" are common variations.
|
||||
|
||||
#### Let's first talk about whats **not** checkpoint variant,
|
||||
|
||||
Checkpoint variants do **not** include different serialization formats (such as [safetensors](https://huggingface.co/docs/diffusers/main/en/using-diffusers/using_safetensors)) as weights in different serialization formats are
|
||||
identical to the weights of the "main" checkpoint, just loaded in a different framework.
|
||||
|
||||
Also variants do not correspond to different model structures, *e.g.* [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) is not a variant of [stable-diffusion-2-0](https://huggingface.co/stabilityai/stable-diffusion-2) since the model structure is different (Stable Diffusion 1-5 uses a different `CLIPTextModel` compared to Stable Diffusion 2.0).
|
||||
|
||||
Pipeline checkpoints that are identical in model structure, but have been trained on different datasets, trained with vastly different training setups and thus correspond to different official releases (such as [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)) should probably be stored in individual repositories instead of as variations of eachother.
|
||||
|
||||
#### So what are checkpoint variants then?
|
||||
|
||||
Checkpoint variants usually consist of the checkpoint stored in "*low-precision, low-storage*" dtype so that less bandwith is required to download them, or of *non-exponential-averaged* weights that shall be used when continuing fine-tuning from the checkpoint.
|
||||
Both use cases have clear advantages when their weights are considered variants: they share the same serialization format as the reference weights, and they correspond to a specialization of the "main" checkpoint which does not warrant a new model repository.
|
||||
A checkpoint stored in [torch's half-precision / float16 format](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) requires only half the bandwith and storage when downloading the checkpoint,
|
||||
**but** cannot be used when continuing training or when running the checkpoint on CPU.
|
||||
Similarly the *non-exponential-averaged* (or non-EMA) version of the checkpoint should be used when continuing fine-tuning of the model checkpoint, **but** should not be used when using the checkpoint for inference.
|
||||
|
||||
#### How to save and load variants
|
||||
|
||||
Saving a diffusion pipeline as a variant can be done by providing [`DiffusionPipeline.save_pretrained`] with the `variant` argument.
|
||||
The `variant` extends the weight name by the provided variation, by changing the default weight name from `diffusion_pytorch_model.bin` to `diffusion_pytorch_model.{variant}.bin` or from `diffusion_pytorch_model.safetensors` to `diffusion_pytorch_model.{variant}.safetensors`. By doing so, one creates a variant of the pipeline checkpoint that can be loaded **instead** of the "main" pipeline checkpoint.
|
||||
|
||||
Let's have a look at how we could create a float16 variant of a pipeline. First, we load
|
||||
the "main" variant of a checkpoint (stored in `float32` precision) into mixed precision format, using `torch_dtype=torch.float16`.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
Now all model components of the pipeline are stored in half-precision dtype. We can now save the
|
||||
pipeline under a `"fp16"` variant as follows:
|
||||
|
||||
```py
|
||||
pipe.save_pretrained("./stable-diffusion-v1-5", variant="fp16")
|
||||
```
|
||||
|
||||
If we don't save into an existing `stable-diffusion-v1-5` folder the new folder would look as follows:
|
||||
|
||||
```
|
||||
stable-diffusion-v1-5
|
||||
├── feature_extractor
|
||||
│ └── preprocessor_config.json
|
||||
├── model_index.json
|
||||
├── safety_checker
|
||||
│ ├── config.json
|
||||
│ └── pytorch_model.fp16.bin
|
||||
├── scheduler
|
||||
│ └── scheduler_config.json
|
||||
├── text_encoder
|
||||
│ ├── config.json
|
||||
│ └── pytorch_model.fp16.bin
|
||||
├── tokenizer
|
||||
│ ├── merges.txt
|
||||
│ ├── special_tokens_map.json
|
||||
│ ├── tokenizer_config.json
|
||||
│ └── vocab.json
|
||||
├── unet
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.fp16.bin
|
||||
└── vae
|
||||
├── config.json
|
||||
└── diffusion_pytorch_model.fp16.bin
|
||||
```
|
||||
|
||||
As one can see, all model files now have a `.fp16.bin` extension instead of just `.bin`.
|
||||
The variant now has to be loaded by also passing a `variant="fp16"` to [`DiffusionPipeline.from_pretrained`], e.g.:
|
||||
|
||||
|
||||
```py
|
||||
DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
works just fine, while:
|
||||
|
||||
```py
|
||||
DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
throws an Exception:
|
||||
```
|
||||
OSError: Error no file named diffusion_pytorch_model.bin found in directory ./stable-diffusion-v1-45/vae since we **only** stored the model
|
||||
```
|
||||
|
||||
This is expected as we don't have any "non-variant" checkpoint files saved locally.
|
||||
However, the whole idea of pipeline variants is that they can co-exist with the "main" variant,
|
||||
so one would typically also save the "main" variant in the same folder. Let's do this:
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipe.save_pretrained("./stable-diffusion-v1-5")
|
||||
```
|
||||
|
||||
and upload the pipeline to the Hub under [diffusers/stable-diffusion-variants](https://huggingface.co/diffusers/stable-diffusion-variants).
|
||||
The file structure [on the Hub](https://huggingface.co/diffusers/stable-diffusion-variants/tree/main) now looks as follows:
|
||||
|
||||
```
|
||||
├── feature_extractor
|
||||
│ └── preprocessor_config.json
|
||||
├── model_index.json
|
||||
├── safety_checker
|
||||
│ ├── config.json
|
||||
│ ├── pytorch_model.bin
|
||||
│ └── pytorch_model.fp16.bin
|
||||
├── scheduler
|
||||
│ └── scheduler_config.json
|
||||
├── text_encoder
|
||||
│ ├── config.json
|
||||
│ ├── pytorch_model.bin
|
||||
│ └── pytorch_model.fp16.bin
|
||||
├── tokenizer
|
||||
│ ├── merges.txt
|
||||
│ ├── special_tokens_map.json
|
||||
│ ├── tokenizer_config.json
|
||||
│ └── vocab.json
|
||||
├── unet
|
||||
│ ├── config.json
|
||||
│ ├── diffusion_pytorch_model.bin
|
||||
│ ├── diffusion_pytorch_model.fp16.bin
|
||||
└── vae
|
||||
├── config.json
|
||||
├── diffusion_pytorch_model.bin
|
||||
└── diffusion_pytorch_model.fp16.bin
|
||||
```
|
||||
|
||||
We can now both download the "main" and the "fp16" variant from the Hub. Both:
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants")
|
||||
```
|
||||
|
||||
and
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="fp16")
|
||||
```
|
||||
|
||||
works.
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that Diffusers never downloads more checkpoints than needed. E.g. when downloading
|
||||
the "main" variant, none of the "fp16.bin" files are downloaded and cached.
|
||||
Only when the user specifies `variant="fp16"` are those files downloaded and cached.
|
||||
|
||||
</Tip>
|
||||
|
||||
Finally, there are cases where only some of the checkpoint files of the pipeline are of a certain
|
||||
variation. E.g. it's usually only the UNet checkpoint that has both a *exponential-mean-averaged* (EMA) and a *non-exponential-mean-averaged* (non-EMA) version. All other model components, e.g. the text encoder, safety checker or variational auto-encoder usually don't have such a variation.
|
||||
In such a case, one would upload just the UNet's checkpoint file with a `non_ema` version format (as done [here](https://huggingface.co/diffusers/stable-diffusion-variants/blob/main/unet/diffusion_pytorch_model.non_ema.bin)) and upon calling:
|
||||
|
||||
```python
|
||||
pipe = DiffusionPipeline.from_pretrained("diffusers/stable-diffusion-variants", variant="non_ema")
|
||||
```
|
||||
|
||||
the model will use only the "non_ema" checkpoint variant if it is available - otherwise it'll load the
|
||||
"main" variation. In the above example, `variant="non_ema"` would therefore download the following file structure:
|
||||
|
||||
```
|
||||
├── feature_extractor
|
||||
│ └── preprocessor_config.json
|
||||
├── model_index.json
|
||||
├── safety_checker
|
||||
│ ├── config.json
|
||||
│ ├── pytorch_model.bin
|
||||
├── scheduler
|
||||
│ └── scheduler_config.json
|
||||
├── text_encoder
|
||||
│ ├── config.json
|
||||
│ ├── pytorch_model.bin
|
||||
├── tokenizer
|
||||
│ ├── merges.txt
|
||||
│ ├── special_tokens_map.json
|
||||
│ ├── tokenizer_config.json
|
||||
│ └── vocab.json
|
||||
├── unet
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.non_ema.bin
|
||||
└── vae
|
||||
├── config.json
|
||||
├── diffusion_pytorch_model.bin
|
||||
```
|
||||
|
||||
In a nutshell, using `variant="{variant}"` will download all files that match the `{variant}` and if for a model component such a file variant is not present it will download the "main" variant. If neither a "main" or `{variant}` variant is available, an error will the thrown.
|
||||
|
||||
### How does loading work?
|
||||
|
||||
As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
|
||||
- Download the latest version of the folder structure required to run the `repo_id` with `diffusers` and cache them. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] will simply reuse the cache and **not** re-download the files.
|
||||
- Load the cached weights into the _correct_ pipeline class – one of the [officially supported pipeline classes](./api/overview#diffusers-summary) - and return an instance of the class. The _correct_ pipeline class is thereby retrieved from the `model_index.json` file.
|
||||
|
||||
The underlying folder structure of diffusion pipelines correspond 1-to-1 to their corresponding class instances, *e.g.* [`LDMTextToImagePipeline`] for [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
||||
This can be understood better by looking at an example. Let's print out pipeline class instance `pipeline` we just defined:
|
||||
The underlying folder structure of diffusion pipelines correspond 1-to-1 to their corresponding class instances, *e.g.* [`StableDiffusionPipeline`] for [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
This can be better understood by looking at an example. Let's load a pipeline class instance `pipe` and print it:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = DiffusionPipeline.from_pretrained(repo_id)
|
||||
print(ldm)
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id)
|
||||
print(pipe)
|
||||
```
|
||||
|
||||
*Output*:
|
||||
```
|
||||
LDMTextToImagePipeline {
|
||||
"bert": [
|
||||
"latent_diffusion",
|
||||
"LDMBertModel"
|
||||
StableDiffusionPipeline {
|
||||
"feature_extractor": [
|
||||
"transformers",
|
||||
"CLIPFeatureExtractor"
|
||||
],
|
||||
"safety_checker": [
|
||||
"stable_diffusion",
|
||||
"StableDiffusionSafetyChecker"
|
||||
],
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"DDIMScheduler"
|
||||
"PNDMScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"BertTokenizer"
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vqvae": [
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
First, we see that the official pipeline is the [`LDMTextToImagePipeline`], and second we see that the `LDMTextToImagePipeline` consists of 5 components:
|
||||
- `"bert"` of class `LDMBertModel` as defined [in the pipeline](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L664)
|
||||
- `"scheduler"` of class [`DDIMScheduler`]
|
||||
- `"tokenizer"` of class `BertTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
|
||||
- `"unet"` of class [`UNet2DConditionModel`]
|
||||
- `"vqvae"` of class [`AutoencoderKL`]
|
||||
First, we see that the official pipeline is the [`StableDiffusionPipeline`], and second we see that the `StableDiffusionPipeline` consists of 7 components:
|
||||
- `"feature_extractor"` of class `CLIPFeatureExtractor` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPFeatureExtractor).
|
||||
- `"safety_checker"` as defined [here](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32).
|
||||
- `"scheduler"` of class [`PNDMScheduler`].
|
||||
- `"text_encoder"` of class `CLIPTextModel` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel).
|
||||
- `"tokenizer"` of class `CLIPTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
- `"unet"` of class [`UNet2DConditionModel`].
|
||||
- `"vae"` of class [`AutoencoderKL`].
|
||||
|
||||
Let's now compare the pipeline instance to the folder structure of the model repository `CompVis/ldm-text2im-large-256`. Looking at the folder structure of [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main) on the Hub, we can see it matches 1-to-1 the printed out instance of `LDMTextToImagePipeline` above:
|
||||
Let's now compare the pipeline instance to the folder structure of the model repository `runwayml/stable-diffusion-v1-5`. Looking at the folder structure of [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) on the Hub and excluding model and saving format variants, we can see it matches 1-to-1 the printed out instance of `StableDiffusionPipeline` above:
|
||||
|
||||
```
|
||||
.
|
||||
├── bert
|
||||
├── feature_extractor
|
||||
│ └── preprocessor_config.json
|
||||
├── model_index.json
|
||||
├── safety_checker
|
||||
│ ├── config.json
|
||||
│ └── pytorch_model.bin
|
||||
├── model_index.json
|
||||
├── scheduler
|
||||
│ └── scheduler_config.json
|
||||
├── text_encoder
|
||||
│ ├── config.json
|
||||
│ └── pytorch_model.bin
|
||||
├── tokenizer
|
||||
│ ├── merges.txt
|
||||
│ ├── special_tokens_map.json
|
||||
│ ├── tokenizer_config.json
|
||||
│ └── vocab.txt
|
||||
│ └── vocab.json
|
||||
├── unet
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
└── vqvae
|
||||
│ ├── diffusion_pytorch_model.bin
|
||||
└── vae
|
||||
├── config.json
|
||||
└── diffusion_pytorch_model.bin
|
||||
├── diffusion_pytorch_model.bin
|
||||
```
|
||||
|
||||
As we can see each attribute of the instance of `LDMTextToImagePipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"bert"`, `"scheduler"`, `"tokenizer"`, `"unet"`, `"vqvae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both:
|
||||
Each attribute of the instance of `StableDiffusionPipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"feature_extractor"`, `"safety_checker"`, `"scheduler"`, `"text_encoder"`, `"tokenizer"`, `"unet"`, `"vae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both:
|
||||
- which pipeline class should be loaded, and
|
||||
- what sub-classes from which library are stored in which subfolders
|
||||
|
||||
In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefore defined as follows:
|
||||
In the case of `runwayml/stable-diffusion-v1-5` the `model_index.json` is therefore defined as follows:
|
||||
|
||||
```
|
||||
{
|
||||
"_class_name": "LDMTextToImagePipeline",
|
||||
"_diffusers_version": "0.0.4",
|
||||
"bert": [
|
||||
"latent_diffusion",
|
||||
"LDMBertModel"
|
||||
"_class_name": "StableDiffusionPipeline",
|
||||
"_diffusers_version": "0.6.0",
|
||||
"feature_extractor": [
|
||||
"transformers",
|
||||
"CLIPFeatureExtractor"
|
||||
],
|
||||
"safety_checker": [
|
||||
"stable_diffusion",
|
||||
"StableDiffusionSafetyChecker"
|
||||
],
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"DDIMScheduler"
|
||||
"PNDMScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"BertTokenizer"
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vqvae": [
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
|
@ -292,10 +531,36 @@ In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefo
|
|||
"class"
|
||||
]
|
||||
```
|
||||
- The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42)
|
||||
- The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42)
|
||||
- The `"library"` field corresponds to the name of the library, *e.g.* `diffusers` or `transformers` from which the `"class"` should be loaded
|
||||
- The `"class"` field corresponds to the name of the class, *e.g.* [`BertTokenizer`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) or [`UNet2DConditionModel`]
|
||||
- The `"class"` field corresponds to the name of the class, *e.g.* [`CLIPTokenizer`](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer) or [`UNet2DConditionModel`]
|
||||
|
||||
<!--
|
||||
TODO(Patrick) - Make sure to uncomment this part as soon as things are deprecated.
|
||||
|
||||
#### Using `revision` to load pipeline variants is deprecated
|
||||
|
||||
Previously the `revision` argument of [`DiffusionPipeline.from_pretrained`] was heavily used to
|
||||
load model variants, e.g.:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16")
|
||||
```
|
||||
|
||||
However, this behavior is now deprecated since the "revision" argument should (just as it's done in GitHub) better be used to load model checkpoints from a specific commit or branch in development.
|
||||
|
||||
The above example is therefore deprecated and won't be supported anymore for `diffusers >= 1.0.0`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If you load diffusers pipelines or models with `revision="fp16"` or `revision="non_ema"`,
|
||||
please make sure to update to code and use `variant="fp16"` or `variation="non_ema"` respectively
|
||||
instead.
|
||||
|
||||
</Tip>
|
||||
-->
|
||||
|
||||
## Loading models
|
||||
|
||||
|
@ -310,19 +575,19 @@ Let's look at an example:
|
|||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
|
||||
```
|
||||
|
||||
Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/unet).
|
||||
Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet).
|
||||
|
||||
As explained in [Loading customized pipelines]("./using-diffusers/loading#loading-customized-pipelines"), one can pass a loaded model to a diffusion pipeline, via [`DiffusionPipeline.from_pretrained`]:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = DiffusionPipeline.from_pretrained(repo_id, unet=model)
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, unet=model)
|
||||
```
|
||||
|
||||
If the model files can be found directly at the root level, which is usually only the case for some very simple diffusion models, such as [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32), we don't
|
||||
|
@ -335,6 +600,18 @@ repo_id = "google/ddpm-cifar10-32"
|
|||
model = UNet2DModel.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
As motivated in [How to save and load variants?](#how-to-save-and-load-variants), models can load and
|
||||
save variants. To load a model variant, one should pass the `variant` function argument to [`ModelMixin.from_pretrained`]. Analogous, to save a model variant, one should pass the `variant` function argument to [`ModelMixin.save_pretrained`]:
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
model = UNet2DConditionModel.from_pretrained(
|
||||
"diffusers/stable-diffusion-variants", subfolder="unet", variant="non_ema"
|
||||
)
|
||||
model.save_pretrained("./local-unet", variant="non_ema")
|
||||
```
|
||||
|
||||
## Loading schedulers
|
||||
|
||||
Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
|
||||
|
|
|
@ -16,18 +16,21 @@
|
|||
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from packaging import version
|
||||
from requests import HTTPError
|
||||
from torch import Tensor, device
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_CACHE,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
HF_HUB_OFFLINE,
|
||||
|
@ -89,12 +92,12 @@ def get_parameter_dtype(parameter: torch.nn.Module):
|
|||
return first_tuple[1].dtype
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
||||
"""
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
|
||||
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
else:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
|
@ -141,6 +144,15 @@ def _load_state_dict_into_model(model_to_load, state_dict):
|
|||
return error_msgs
|
||||
|
||||
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None:
|
||||
splits = weights_name.split(".")
|
||||
splits = splits[:-1] + [variant] + splits[-1:]
|
||||
weights_name = ".".join(splits)
|
||||
|
||||
return weights_name
|
||||
|
||||
|
||||
class ModelMixin(torch.nn.Module):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
@ -250,6 +262,7 @@ class ModelMixin(torch.nn.Module):
|
|||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
|
@ -268,6 +281,8 @@ class ModelMixin(torch.nn.Module):
|
|||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
if safe_serialization and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
||||
|
@ -292,6 +307,7 @@ class ModelMixin(torch.nn.Module):
|
|||
state_dict = model_to_save.state_dict()
|
||||
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, weights_name))
|
||||
|
@ -371,6 +387,9 @@ class ModelMixin(torch.nn.Module):
|
|||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
|
||||
<Tip>
|
||||
|
||||
|
@ -401,6 +420,7 @@ class ModelMixin(torch.nn.Module):
|
|||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
|
@ -488,7 +508,7 @@ class ModelMixin(torch.nn.Module):
|
|||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=SAFETENSORS_WEIGHTS_NAME,
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
|
@ -504,7 +524,7 @@ class ModelMixin(torch.nn.Module):
|
|||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=WEIGHTS_NAME,
|
||||
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
|
@ -538,7 +558,7 @@ class ModelMixin(torch.nn.Module):
|
|||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file)
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
# move the params from meta device to cpu
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
if len(missing_keys) > 0:
|
||||
|
@ -587,7 +607,7 @@ class ModelMixin(torch.nn.Module):
|
|||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file)
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
|
@ -800,8 +820,38 @@ def _get_model_file(
|
|||
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
# 1. First check if deprecated way of loading from branches is used
|
||||
if (
|
||||
revision in DEPRECATED_REVISION_ARGS
|
||||
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
|
||||
and version.parse(version.parse(__version__).base_version) >= version.parse("0.15.0")
|
||||
):
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=_add_variant(weights_name, revision),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
warnings.warn(
|
||||
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return model_file
|
||||
except: # noqa: E722
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name)}' so that the correct variant file can be added.",
|
||||
FutureWarning,
|
||||
)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
# 2. Load model file as usual
|
||||
model_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=weights_name,
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
@ -31,15 +33,16 @@ from tqdm.auto import tqdm
|
|||
|
||||
import diffusers
|
||||
|
||||
from .. import __version__
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_CACHE,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
HF_HUB_OFFLINE,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
|
@ -56,6 +59,11 @@ from ..utils import (
|
|||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
|
@ -120,15 +128,16 @@ class AudioPipelineOutput(BaseOutput):
|
|||
audios: np.ndarray
|
||||
|
||||
|
||||
def is_safetensors_compatible(info) -> bool:
|
||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
||||
def is_safetensors_compatible(filenames, variant=None) -> bool:
|
||||
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
|
||||
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
|
||||
|
||||
for pt_filename in pt_filenames:
|
||||
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
|
||||
prefix, raw = os.path.split(pt_filename)
|
||||
if raw == "pytorch_model.bin":
|
||||
if raw == f"pytorch_model{_variant}.bin":
|
||||
# transformers specific
|
||||
sf_filename = os.path.join(prefix, "model.safetensors")
|
||||
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
|
||||
else:
|
||||
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
|
||||
if is_safetensors_compatible and sf_filename not in filenames:
|
||||
|
@ -137,6 +146,41 @@ def is_safetensors_compatible(info) -> bool:
|
|||
return is_safetensors_compatible
|
||||
|
||||
|
||||
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
|
||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
||||
weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
# .bin, .safetensors, ...
|
||||
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
||||
|
||||
variant_file_regex = (
|
||||
re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
|
||||
if variant is not None
|
||||
else None
|
||||
)
|
||||
non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
|
||||
|
||||
if variant is not None:
|
||||
variant_filenames = set(f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None)
|
||||
else:
|
||||
variant_filenames = set()
|
||||
|
||||
non_variant_filenames = set(f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None)
|
||||
|
||||
usable_filenames = set(variant_filenames)
|
||||
for f in non_variant_filenames:
|
||||
variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
|
||||
if variant_filename not in usable_filenames:
|
||||
usable_filenames.add(f)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
@ -194,6 +238,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
safe_serialization: bool = False,
|
||||
variant: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
|
||||
|
@ -205,6 +250,8 @@ class DiffusionPipeline(ConfigMixin):
|
|||
Directory to which to save. Will be created if it doesn't exist.
|
||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
self.save_config(save_directory)
|
||||
|
||||
|
@ -246,12 +293,15 @@ class DiffusionPipeline(ConfigMixin):
|
|||
# Call the save method with the argument safe_serialization only if it's supported
|
||||
save_method_signature = inspect.signature(save_method)
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
save_method_accept_variant = "variant" in save_method_signature.parameters
|
||||
|
||||
save_kwargs = {}
|
||||
if save_method_accept_safe:
|
||||
save_method(
|
||||
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
|
||||
)
|
||||
else:
|
||||
save_method(os.path.join(save_directory, pipeline_component_name))
|
||||
save_kwargs["safe_serialization"] = safe_serialization
|
||||
if save_method_accept_variant:
|
||||
save_kwargs["variant"] = variant
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
if torch_device is None:
|
||||
|
@ -403,6 +453,9 @@ class DiffusionPipeline(ConfigMixin):
|
|||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
||||
`__init__` method. See example below for more information.
|
||||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
|
||||
<Tip>
|
||||
|
||||
|
@ -454,6 +507,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
||||
variant = kwargs.pop("variant", None)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
|
@ -468,28 +522,87 @@ class DiffusionPipeline(ConfigMixin):
|
|||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
# make sure we only download sub-folders and `diffusers` filenames
|
||||
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [
|
||||
WEIGHTS_NAME,
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
cls.config_name,
|
||||
]
|
||||
|
||||
# make sure we don't download flax weights
|
||||
ignore_patterns = ["*.msgpack"]
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
||||
|
||||
if not local_files_only:
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant)
|
||||
model_folder_names = set([os.path.split(f)[0] for f in model_filenames])
|
||||
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
) >= version.parse("0.10.0"):
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=None,
|
||||
)
|
||||
comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision)
|
||||
comp_model_filenames = [
|
||||
".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames
|
||||
]
|
||||
|
||||
if set(comp_model_filenames) == set(model_filenames):
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
|
||||
# also allow downloading config.jsons with the model
|
||||
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
|
||||
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors"]
|
||||
allow_patterns += [
|
||||
FLAX_WEIGHTS_NAME,
|
||||
SCHEDULER_CONFIG_NAME,
|
||||
CONFIG_NAME,
|
||||
cls.config_name,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
if custom_pipeline is not None:
|
||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", ".onnx"]
|
||||
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")])
|
||||
safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")])
|
||||
if (
|
||||
len(safetensors_variant_filenames) > 0
|
||||
and safetensors_model_filenames != safetensors_variant_filenames
|
||||
):
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")])
|
||||
bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")])
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
logger.warn(
|
||||
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
||||
)
|
||||
|
||||
else:
|
||||
# allow everything since it has to be downloaded anyways
|
||||
ignore_patterns = allow_patterns = None
|
||||
|
||||
if cls != DiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
|
@ -501,21 +614,6 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
if is_safetensors_available() and not local_files_only:
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
if is_safetensors_compatible(info):
|
||||
ignore_patterns.append("*.bin")
|
||||
else:
|
||||
# as a safety mechanism we also don't download safetensors if
|
||||
# not all safetensors files are there
|
||||
ignore_patterns.append("*.safetensors")
|
||||
else:
|
||||
ignore_patterns.append("*.safetensors")
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
|
@ -533,6 +631,16 @@ class DiffusionPipeline(ConfigMixin):
|
|||
cached_folder = pretrained_model_name_or_path
|
||||
config_dict = cls.load_config(cached_folder)
|
||||
|
||||
# retrieve which subfolders should load variants
|
||||
model_variants = {}
|
||||
if variant is not None:
|
||||
for folder in os.listdir(cached_folder):
|
||||
folder_path = os.path.join(cached_folder, folder)
|
||||
is_folder = os.path.isdir(folder_path) and folder in config_dict
|
||||
variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
|
||||
if variant_exists:
|
||||
model_variants[folder] = variant
|
||||
|
||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
if custom_pipeline is not None:
|
||||
|
@ -717,10 +825,11 @@ class DiffusionPipeline(ConfigMixin):
|
|||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
|
@ -728,9 +837,23 @@ class DiffusionPipeline(ConfigMixin):
|
|||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
# the following can be deleted once the minimum required `transformers` version
|
||||
# is higher than 4.27
|
||||
if (
|
||||
is_transformers_model
|
||||
and loading_kwargs["variant"] is not None
|
||||
and transformers_version < version.parse("4.27.0")
|
||||
):
|
||||
raise ImportError(
|
||||
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
|
||||
)
|
||||
elif is_transformers_model and loading_kwargs["variant"] is None:
|
||||
loading_kwargs.pop("variant")
|
||||
|
||||
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
||||
if not (from_flax and is_transformers_model):
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
|
|
@ -20,6 +20,7 @@ from packaging import version
|
|||
from .. import __version__
|
||||
from .constants import (
|
||||
CONFIG_NAME,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_CACHE,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
|
|
|
@ -30,3 +30,4 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
|||
DIFFUSERS_CACHE = default_cache_path
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from diffusers.models import ModelMixin, UNet2DConditionModel
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
@ -34,6 +36,30 @@ class ModelUtilsTest(unittest.TestCase):
|
|||
# make sure that error message states what keys are missing
|
||||
assert "conv_out.bias" in str(error_context.exception)
|
||||
|
||||
def test_cached_files_are_used_when_no_internet(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
orig_model = UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
|
||||
)
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("requests.request", return_value=response_mock):
|
||||
# Download this model to make sure it's in the cache.
|
||||
model = UNet2DConditionModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
|
||||
)
|
||||
|
||||
for p1, p2 in zip(orig_model.parameters(), model.parameters()):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
assert False, "Parameters not the same!"
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
def test_from_save_pretrained(self):
|
||||
|
@ -66,6 +92,44 @@ class ModelTesterMixin:
|
|||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
|
||||
|
||||
# non-variant cannot be loaded
|
||||
with self.assertRaises(OSError) as error_context:
|
||||
self.model_class.from_pretrained(tmpdirname)
|
||||
|
||||
# make sure that error message states what keys are missing
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps" and isinstance(model, ModelMixin):
|
||||
_ = model(**self.dummy_input)
|
||||
_ = new_model(**self.dummy_input)
|
||||
|
||||
image = model(**inputs_dict)
|
||||
if isinstance(image, dict):
|
||||
image = image.sample
|
||||
|
||||
new_image = new_model(**inputs_dict)
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.sample
|
||||
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import shutil
|
|||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
@ -28,6 +29,7 @@ import safetensors.torch
|
|||
import torch
|
||||
from parameterized import parameterized
|
||||
from PIL import Image
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
|
@ -166,6 +168,155 @@ class DownloadTests(unittest.TestCase):
|
|||
|
||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||
|
||||
def test_cached_files_are_used_when_no_internet(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
orig_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||
)
|
||||
orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("requests.request", return_value=response_mock):
|
||||
# Download this model to make sure it's in the cache.
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, local_files_only=True
|
||||
)
|
||||
comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}
|
||||
|
||||
for m1, m2 in zip(orig_comps.values(), comps.values()):
|
||||
for p1, p2 in zip(m1.parameters(), m2.parameters()):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
assert False, "Parameters not the same!"
|
||||
|
||||
def test_download_from_variant_folder(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
|
||||
other_format = ".bin" if safe_avail else ".safetensors"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
|
||||
)
|
||||
all_root_files = [
|
||||
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
|
||||
]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a variant file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
# no variants
|
||||
assert not any(len(f.split(".")) == 3 for f in files)
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_download_variant_all(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
|
||||
other_format = ".bin" if safe_avail else ".safetensors"
|
||||
this_format = ".safetensors" if safe_avail else ".bin"
|
||||
variant = "fp16"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||
)
|
||||
all_root_files = [
|
||||
t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))
|
||||
]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a non-variant file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||
# unet, vae, text_encoder, safety_checker
|
||||
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 4
|
||||
# all checkpoints should have variant ending
|
||||
assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_download_variant_partly(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
|
||||
other_format = ".bin" if safe_avail else ".safetensors"
|
||||
this_format = ".safetensors" if safe_avail else ".bin"
|
||||
variant = "no_ema"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
|
||||
)
|
||||
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
|
||||
all_root_files = [t[-1] for t in os.walk(snapshots)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
unet_files = os.listdir(os.path.join(snapshots, os.listdir(snapshots)[0], "unet"))
|
||||
|
||||
# Some of the downloaded files should be a non-variant file, check:
|
||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
|
||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||
# only unet has "no_ema" variant
|
||||
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
|
||||
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
|
||||
# vae, safety_checker and text_encoder should have no variant
|
||||
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
|
||||
assert not any(f.endswith(other_format) for f in files)
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_download_broken_variant(self):
|
||||
for safe_avail in [False, True]:
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = safe_avail
|
||||
# text encoder is missing no variant and "no_ema" variant weights, so the following can't work
|
||||
for variant in [None, "no_ema"]:
|
||||
with self.assertRaises(OSError) as error_context:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-broken-variants",
|
||||
cache_dir=tmpdirname,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
assert "Error no file name" in str(error_context.exception)
|
||||
|
||||
# text encoder has fp16 variants so we can load it
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
|
||||
)
|
||||
assert pipe is not None
|
||||
|
||||
snapshots = os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots")
|
||||
all_root_files = [t[-1] for t in os.walk(snapshots)]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a non-variant file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
|
||||
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
|
||||
# only unet has "no_ema" variant
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
|
||||
class CustomPipelineTests(unittest.TestCase):
|
||||
def test_load_custom_pipeline(self):
|
||||
|
|
Loading…
Reference in New Issue