[Better scheduler docs] Improve usage examples of schedulers (#890)

* [Better scheduler docs] Improve usage examples of schedulers
* finish
* fix warnings and add test
* finish
* more replacements
* adapt fast tests hf token
* correct more
* Apply suggestions from code review
* Integrate compatibility with euler

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

parent a1ea8c01c3
commit c18941b01a
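
The core of the change: examples stop hard-coding scheduler hyperparameters and instead load them from the checkpoint with `from_config`. A minimal sketch of the new pattern, mirroring the diffs below:

```python
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

# Before: LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
# After: pull the same hyperparameters from the repo's scheduler config.
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=lms,
)
```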
@@ -42,6 +42,8 @@ jobs:
           python utils/print_env.py

       - name: Run all fast tests on CPU
+        env:
+          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
         run: |
           python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/

@@ -91,6 +93,8 @@ jobs:

       - name: Run all fast tests on MPS
         shell: arch -arch arm64 bash {0}
+        env:
+          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
         run: |
           ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/

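Both jobs now export the Hub token into the test environment. As a hedged sketch (not part of this diff), a test that needs a gated checkpoint would read it back and forward it to the Hub; `use_auth_token` is the parameter the examples in this commit already use:

```python
import os

from diffusers import StableDiffusionPipeline

# Hypothetical usage: the workflow's new `env:` block surfaces the secret as an
# environment variable, which gated-checkpoint tests can pass through to the Hub.
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=token)
```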
@@ -142,11 +142,7 @@ it before the pipeline and pass it to `from_pretrained`.

 ```python
 from diffusers import LMSDiscreteScheduler

-lms = LMSDiscreteScheduler(
-    beta_start=0.00085,
-    beta_end=0.012,
-    beta_schedule="scaled_linear"
-)
+lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

 pipe = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
@@ -121,7 +121,7 @@ you could use it as follows:

 ```python
 >>> from diffusers import LMSDiscreteScheduler

->>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+>>> scheduler = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

 >>> generator = StableDiffusionPipeline.from_pretrained(
 ...     "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
@@ -469,9 +469,7 @@ def main(args):
         eps=args.adam_epsilon,
     )

-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
-    )
+    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -372,11 +372,7 @@ def main():
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-
-    # TODO (patil-suraj): load scheduler using args
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
-    )
+    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -609,9 +605,7 @@ def main():
             vae=vae,
             unet=unet,
             tokenizer=tokenizer,
-            scheduler=PNDMScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-            ),
+            scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
             safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
@@ -419,13 +419,7 @@ def main():
         eps=args.adam_epsilon,
     )

-    # TODO (patil-suraj): load scheduler using args
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=1000,
-    )
+    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

     train_dataset = TextualInversionDataset(
         data_root=args.train_data_dir,
@@ -558,9 +552,7 @@ def main():
             vae=vae,
             unet=unet,
             tokenizer=tokenizer,
-            scheduler=PNDMScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-            ),
+            scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
             safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
@@ -16,6 +16,7 @@
 """ ConfigMixin base class and utilities."""
 import dataclasses
 import functools
+import importlib
 import inspect
 import json
 import os
@@ -48,9 +49,13 @@ class ConfigMixin:
          [`~ConfigMixin.save_config`] (should be overridden by parent class).
        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
          overridden by parent class).
+       - **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+         `from_config` can be used from a class different than the one used to save the config (should be overridden
+         by parent class).
     """
     config_name = None
     ignore_for_config = []
+    _compatible_classes = []

     def register_to_config(self, **kwargs):
         if self.config_name is None:
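As a minimal sketch of how a subclass is meant to use these hooks (class and file names here are illustrative; the schedulers further down in this commit do exactly this):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config

# Illustrative subclass: declaring _compatible_classes lets from_config load this
# class from a config that was saved by DDIMScheduler or PNDMScheduler.
class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"
    _compatible_classes = ["DDIMScheduler", "PNDMScheduler"]

    @register_to_config
    def __init__(self, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
        pass
```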
@@ -280,9 +285,14 @@ class ConfigMixin:

         return config_dict

+    @staticmethod
+    def _get_init_keys(cls):
+        return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
     @classmethod
     def extract_init_dict(cls, config_dict, **kwargs):
-        expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
+        # 1. Retrieve expected config attributes from __init__ signature
+        expected_keys = cls._get_init_keys(cls)
         expected_keys.remove("self")
         # remove general kwargs if present in dict
         if "kwargs" in expected_keys:
@@ -292,9 +302,36 @@ class ConfigMixin:
             for arg in cls._flax_internal_args:
                 expected_keys.remove(arg)

+        # 2. Remove attributes that cannot be expected from expected config attributes
         # remove keys to be ignored
         if len(cls.ignore_for_config) > 0:
             expected_keys = expected_keys - set(cls.ignore_for_config)

+        # load diffusers library to import compatible and original scheduler
+        diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+        # remove attributes from compatible classes that orig cannot expect
+        compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
+        # filter out None potentially undefined dummy classes
+        compatible_classes = [c for c in compatible_classes if c is not None]
+        expected_keys_comp_cls = set()
+        for c in compatible_classes:
+            expected_keys_c = cls._get_init_keys(c)
+            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+        # remove attributes from orig class that cannot be expected
+        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+        if orig_cls_name != cls.__name__:
+            orig_cls = getattr(diffusers_library, orig_cls_name)
+            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+
+        # remove private attributes
+        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
         init_dict = {}
         for key in expected_keys:
             if key in kwargs:
@@ -304,8 +341,7 @@ class ConfigMixin:
                 # use value from config dict
                 init_dict[key] = config_dict.pop(key)

-        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
-
+        # 4. Give nice warning if unexpected values have been passed
         if len(config_dict) > 0:
             logger.warning(
                 f"The config attributes {config_dict} were passed to {cls.__name__}, "
@@ -313,14 +349,16 @@ class ConfigMixin:
                 f"{cls.config_name} configuration file."
             )

-        unused_kwargs = {**config_dict, **kwargs}
+        # 5. Give nice info if config attributes are initialized to default because they have not been passed

         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
             logger.info(
                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )

+        # 6. Define unused keyword arguments
+        unused_kwargs = {**config_dict, **kwargs}
+
         return init_dict, unused_kwargs

     @classmethod
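The compatibility filtering above is plain set arithmetic over `__init__` signatures. A standalone sketch of steps 2–3 with toy classes (no diffusers imports needed; names are illustrative):

```python
import inspect

def get_init_keys(cls):
    # same idea as the _get_init_keys helper added above
    return set(inspect.signature(cls.__init__).parameters)

class Orig:  # stand-in for the class named in the config's _class_name
    def __init__(self, a=1, b=2, shared=0):
        pass

class Loading:  # stand-in for the class calling from_config
    def __init__(self, a=1, extra=3, shared=0):
        pass

config = {"a": 10, "b": 20, "shared": 5, "unexpected": True}

expected = get_init_keys(Loading) - {"self"}
# drop keys only the original class understands ("b"), keep shared ones
foreign_only = get_init_keys(Orig) - expected - {"self"}
config = {k: v for k, v in config.items() if k not in foreign_only}

init_kwargs = {k: v for k, v in config.items() if k in expected}
leftover = {k: v for k, v in config.items() if k not in expected}
print(init_kwargs)  # {'a': 10, 'shared': 5}
print(leftover)     # {'unexpected': True} -> warned about in step 4
```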
@@ -272,7 +272,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
         >>> # Download pipeline, but overwrite scheduler
         >>> from diffusers import LMSDiscreteScheduler

-        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
         >>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
         ```
         """
@@ -360,7 +360,7 @@ class DiffusionPipeline(ConfigMixin):
         >>> # Download pipeline, but overwrite scheduler
         >>> from diffusers import LMSDiscreteScheduler

-        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
         >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
         ```
         """
@@ -602,7 +602,7 @@ class DiffusionPipeline(ConfigMixin):
         ...     StableDiffusionInpaintPipeline,
         ... )

-        >>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        >>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
         >>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
         >>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
         ```
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
 # make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline, DDIMScheduler

-scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

 pipe = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
@@ -91,11 +91,7 @@ image.save("astronaut_rides_horse.png")
 # make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

-lms = LMSDiscreteScheduler(
-    beta_start=0.00085,
-    beta_end=0.012,
-    beta_schedule="scaled_linear"
-)
+lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

 pipe = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
@@ -5,6 +5,7 @@ import numpy as np

 from transformers import CLIPFeatureExtractor, CLIPTokenizer

+from ...configuration_utils import FrozenDict
 from ...onnx_utils import OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@ -36,6 +37,34 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||||
feature_extractor: CLIPFeatureExtractor,
|
feature_extractor: CLIPFeatureExtractor,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||||
|
deprecation_message = (
|
||||||
|
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||||
|
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||||
|
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||||
|
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||||
|
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||||
|
" file"
|
||||||
|
)
|
||||||
|
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||||
|
new_config = dict(scheduler.config)
|
||||||
|
new_config["steps_offset"] = 1
|
||||||
|
scheduler._internal_dict = FrozenDict(new_config)
|
||||||
|
|
||||||
|
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||||
|
deprecation_message = (
|
||||||
|
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||||
|
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||||
|
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||||
|
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||||
|
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||||
|
)
|
||||||
|
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||||
|
new_config = dict(scheduler.config)
|
||||||
|
new_config["clip_sample"] = False
|
||||||
|
scheduler._internal_dict = FrozenDict(new_config)
|
||||||
|
|
||||||
self.register_modules(
|
self.register_modules(
|
||||||
vae_encoder=vae_encoder,
|
vae_encoder=vae_encoder,
|
||||||
vae_decoder=vae_decoder,
|
vae_decoder=vae_decoder,
|
||||||
|
|
|
@@ -90,6 +90,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -104,6 +104,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -80,6 +80,19 @@ class StableDiffusionPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -91,6 +91,19 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -90,6 +90,19 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -96,6 +96,19 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
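Each of the six pipelines patches outdated scheduler configs the same way. A condensed sketch of the pattern (the standalone helper function is mine; `FrozenDict` and `deprecate` are the calls used above):

```python
from diffusers.configuration_utils import FrozenDict
from diffusers.utils import deprecate

def patch_legacy_scheduler_config(scheduler):
    # Old Stable Diffusion checkpoints shipped clip_sample=True; warn once via
    # deprecate(), then rewrite the frozen config so the fix sticks.
    if getattr(scheduler.config, "clip_sample", False) is True:
        deprecate("clip_sample not set", "1.0.0", "`clip_sample` should be set to False", standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)
```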
@@ -109,6 +109,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

     """

+    _compatible_classes = [
+        "PNDMScheduler",
+        "DDPMScheduler",
+        "LMSDiscreteScheduler",
+        "EulerDiscreteScheduler",
+        "EulerAncestralDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
@@ -102,6 +102,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

     """

+    _compatible_classes = [
+        "DDIMScheduler",
+        "PNDMScheduler",
+        "LMSDiscreteScheduler",
+        "EulerDiscreteScheduler",
+        "EulerAncestralDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
@@ -67,6 +67,14 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

     """

+    _compatible_classes = [
+        "DDIMScheduler",
+        "DDPMScheduler",
+        "LMSDiscreteScheduler",
+        "PNDMScheduler",
+        "EulerDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
@@ -68,6 +68,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

     """

+    _compatible_classes = [
+        "DDIMScheduler",
+        "DDPMScheduler",
+        "LMSDiscreteScheduler",
+        "PNDMScheduler",
+        "EulerAncestralDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
@@ -67,6 +67,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

     """

+    _compatible_classes = [
+        "DDIMScheduler",
+        "DDPMScheduler",
+        "PNDMScheduler",
+        "EulerDiscreteScheduler",
+        "EulerAncestralDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
@@ -88,6 +88,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

     """

+    _compatible_classes = [
+        "DDIMScheduler",
+        "DDPMScheduler",
+        "LMSDiscreteScheduler",
+        "EulerDiscreteScheduler",
+        "EulerAncestralDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
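With every scheduler listing the other five as compatible, `from_config` can now cross-load any of them from the same checkpoint config; the tests added at the end of this commit exercise exactly this. For example:

```python
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

# The v1-5 checkpoint saves a PNDM scheduler config, but Euler declares PNDM
# compatible, so the same config file instantiates an Euler scheduler; keys only
# the source class understands are silently dropped.
euler = EulerDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=euler)
```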
@@ -644,13 +644,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         sd_pipe = sd_pipe.to(torch_device)
         sd_pipe.set_progress_bar_config(disable=None)

-        scheduler = DDIMScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+        scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
         sd_pipe.scheduler = scheduler

         prompt = "A painting of a squirrel eating a burger"
@@ -523,9 +523,8 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         init_image = init_image.resize((768, 512))
         expected_image = np.array(expected_image, dtype=np.float32) / 255.0

-        lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
-
         model_id = "CompVis/stable-diffusion-v1-4"
+        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
             scheduler=lms,
@@ -366,8 +366,8 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         )
         expected_image = np.array(expected_image, dtype=np.float32) / 255.0

-        pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
         model_id = "runwayml/stable-diffusion-inpainting"
+        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id, safety_checker=None, scheduler=pndm, device_map="auto"
         )
@@ -407,9 +407,8 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
         )
         expected_image = np.array(expected_image, dtype=np.float32) / 255.0

-        lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
-
         model_id = "CompVis/stable-diffusion-v1-4"
+        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             scheduler=lms,
@ -13,10 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import diffusers
|
||||||
|
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, logging
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.utils.testing_utils import CaptureLogger
|
||||||
|
|
||||||
|
|
||||||
class SampleObject(ConfigMixin):
|
class SampleObject(ConfigMixin):
|
||||||
|
@@ -34,6 +39,37 @@ class SampleObject(ConfigMixin):
         pass


+class SampleObject2(ConfigMixin):
+    config_name = "config.json"
+
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        f=[1, 3],
+    ):
+        pass
+
+
+class SampleObject3(ConfigMixin):
+    config_name = "config.json"
+
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        e=[1, 3],
+        f=[1, 3],
+    ):
+        pass
+
+
 class ConfigTester(unittest.TestCase):
     def test_load_not_from_mixin(self):
         with self.assertRaises(ValueError):
@@ -97,3 +133,151 @@ class ConfigTester(unittest.TestCase):
         assert config.pop("c") == (2, 5)  # instantiated as tuple
         assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
         assert config == new_config
+
+    def test_save_load_from_different_config(self):
+        obj = SampleObject()
+
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SampleObject", SampleObject)
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+            with CaptureLogger(logger) as cap_logger_1:
+                new_obj_1 = SampleObject2.from_config(tmpdirname)
+
+            # now save a config parameter that is not expected
+            with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
+                data = json.load(f)
+                data["unexpected"] = True
+
+            with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
+                json.dump(data, f)
+
+            with CaptureLogger(logger) as cap_logger_2:
+                new_obj_2 = SampleObject.from_config(tmpdirname)
+
+            with CaptureLogger(logger) as cap_logger_3:
+                new_obj_3 = SampleObject2.from_config(tmpdirname)
+
+        assert new_obj_1.__class__ == SampleObject2
+        assert new_obj_2.__class__ == SampleObject
+        assert new_obj_3.__class__ == SampleObject2
+
+        assert cap_logger_1.out == ""
+        assert (
+            cap_logger_2.out
+            == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
+            " be ignored. Please verify your config.json configuration file.\n"
+        )
+        assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
+
+    def test_save_load_compatible_schedulers(self):
+        SampleObject2._compatible_classes = ["SampleObject"]
+        SampleObject._compatible_classes = ["SampleObject2"]
+
+        obj = SampleObject()
+
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SampleObject", SampleObject)
+        setattr(diffusers, "SampleObject2", SampleObject2)
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+
+            # now save a config parameter that is expected by another class, but not origin class
+            with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
+                data = json.load(f)
+                data["f"] = [0, 0]
+                data["unexpected"] = True
+
+            with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
+                json.dump(data, f)
+
+            with CaptureLogger(logger) as cap_logger:
+                new_obj = SampleObject.from_config(tmpdirname)
+
+        assert new_obj.__class__ == SampleObject
+
+        assert (
+            cap_logger.out
+            == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
+            " be ignored. Please verify your config.json configuration file.\n"
+        )
+
+    def test_save_load_from_different_config_comp_schedulers(self):
+        SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
+        SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
+        SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
+
+        obj = SampleObject()
+
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SampleObject", SampleObject)
+        setattr(diffusers, "SampleObject2", SampleObject2)
+        setattr(diffusers, "SampleObject3", SampleObject3)
+        logger = logging.get_logger("diffusers.configuration_utils")
+        logger.setLevel(diffusers.logging.INFO)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+
+            with CaptureLogger(logger) as cap_logger_1:
+                new_obj_1 = SampleObject.from_config(tmpdirname)
+
+            with CaptureLogger(logger) as cap_logger_2:
+                new_obj_2 = SampleObject2.from_config(tmpdirname)
+
+            with CaptureLogger(logger) as cap_logger_3:
+                new_obj_3 = SampleObject3.from_config(tmpdirname)
+
+        assert new_obj_1.__class__ == SampleObject
+        assert new_obj_2.__class__ == SampleObject2
+        assert new_obj_3.__class__ == SampleObject3
+
+        assert cap_logger_1.out == ""
+        assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
+        assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
+
+    def test_load_ddim_from_pndm(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            ddim = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+
+        assert ddim.__class__ == DDIMScheduler
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+    def test_load_ddim_from_euler(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            euler = EulerDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+
+        assert euler.__class__ == EulerDiscreteScheduler
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+    def test_load_ddim_from_euler_ancestral(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            euler = EulerAncestralDiscreteScheduler.from_config(
+                "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
+            )
+
+        assert euler.__class__ == EulerAncestralDiscreteScheduler
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+    def test_load_pndm(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            pndm = PNDMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+
+        assert pndm.__class__ == PNDMScheduler
+        # no warning should be thrown
+        assert cap_logger.out == ""