[Better scheduler docs] Improve usage examples of schedulers (#890)

* [Better scheduler docs] Improve usage examples of schedulers

* finish

* fix warnings and add test

* finish

* more replacements

* adapt fast tests hf token

* correct more

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Integrate compatibility with euler

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Patrick von Platen 2022-10-31 17:26:30 +01:00 committed by GitHub
parent a1ea8c01c3
commit c18941b01a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 402 additions and 53 deletions

View File

@ -42,6 +42,8 @@ jobs:
python utils/print_env.py python utils/print_env.py
- name: Run all fast tests on CPU - name: Run all fast tests on CPU
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: | run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/ python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/
@ -91,6 +93,8 @@ jobs:
- name: Run all fast tests on MPS - name: Run all fast tests on MPS
shell: arch -arch arm64 bash {0} shell: arch -arch arm64 bash {0}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: | run: |
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/ ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/

View File

@ -142,11 +142,7 @@ it before the pipeline and pass it to `from_pretrained`.
```python ```python
from diffusers import LMSDiscreteScheduler from diffusers import LMSDiscreteScheduler
lms = LMSDiscreteScheduler( lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear"
)
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-v1-5",

View File

@ -121,7 +121,7 @@ you could use it as follows:
```python ```python
>>> from diffusers import LMSDiscreteScheduler >>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") >>> scheduler = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
>>> generator = StableDiffusionPipeline.from_pretrained( >>> generator = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN

View File

@ -469,9 +469,7 @@ def main(args):
eps=args.adam_epsilon, eps=args.adam_epsilon,
) )
noise_scheduler = DDPMScheduler( noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
train_dataset = DreamBoothDataset( train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir, instance_data_root=args.instance_data_dir,

View File

@ -372,11 +372,7 @@ def main():
weight_decay=args.adam_weight_decay, weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon, eps=args.adam_epsilon,
) )
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
# Get the datasets: you can either provide your own training and evaluation files (see below) # Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@ -609,9 +605,7 @@ def main():
vae=vae, vae=vae,
unet=unet, unet=unet,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=PNDMScheduler( scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
) )

View File

@ -419,13 +419,7 @@ def main():
eps=args.adam_epsilon, eps=args.adam_epsilon,
) )
# TODO (patil-suraj): load scheduler using args noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
train_dataset = TextualInversionDataset( train_dataset = TextualInversionDataset(
data_root=args.train_data_dir, data_root=args.train_data_dir,
@ -558,9 +552,7 @@ def main():
vae=vae, vae=vae,
unet=unet, unet=unet,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=PNDMScheduler( scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
) )

View File

@ -16,6 +16,7 @@
""" ConfigMixin base class and utilities.""" """ ConfigMixin base class and utilities."""
import dataclasses import dataclasses
import functools import functools
import importlib
import inspect import inspect
import json import json
import os import os
@ -48,9 +49,13 @@ class ConfigMixin:
[`~ConfigMixin.save_config`] (should be overridden by parent class). [`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class). overridden by parent class).
- **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
""" """
config_name = None config_name = None
ignore_for_config = [] ignore_for_config = []
_compatible_classes = []
def register_to_config(self, **kwargs): def register_to_config(self, **kwargs):
if self.config_name is None: if self.config_name is None:
@ -280,9 +285,14 @@ class ConfigMixin:
return config_dict return config_dict
@staticmethod
def _get_init_keys(cls):
return set(dict(inspect.signature(cls.__init__).parameters).keys())
@classmethod @classmethod
def extract_init_dict(cls, config_dict, **kwargs): def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) # 1. Retrieve expected config attributes from __init__ signature
expected_keys = cls._get_init_keys(cls)
expected_keys.remove("self") expected_keys.remove("self")
# remove general kwargs if present in dict # remove general kwargs if present in dict
if "kwargs" in expected_keys: if "kwargs" in expected_keys:
@ -292,9 +302,36 @@ class ConfigMixin:
for arg in cls._flax_internal_args: for arg in cls._flax_internal_args:
expected_keys.remove(arg) expected_keys.remove(arg)
# 2. Remove attributes that cannot be expected from expected config attributes
# remove keys to be ignored # remove keys to be ignored
if len(cls.ignore_for_config) > 0: if len(cls.ignore_for_config) > 0:
expected_keys = expected_keys - set(cls.ignore_for_config) expected_keys = expected_keys - set(cls.ignore_for_config)
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
# remove attributes from compatible classes that orig cannot expect
compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
# filter out None potentially undefined dummy classes
compatible_classes = [c for c in compatible_classes if c is not None]
expected_keys_comp_cls = set()
for c in compatible_classes:
expected_keys_c = cls._get_init_keys(c)
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
# remove attributes from orig class that cannot be expected
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
if orig_cls_name != cls.__name__:
orig_cls = getattr(diffusers_library, orig_cls_name)
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
init_dict = {} init_dict = {}
for key in expected_keys: for key in expected_keys:
if key in kwargs: if key in kwargs:
@ -304,8 +341,7 @@ class ConfigMixin:
# use value from config dict # use value from config dict
init_dict[key] = config_dict.pop(key) init_dict[key] = config_dict.pop(key)
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} # 4. Give nice warning if unexpected values have been passed
if len(config_dict) > 0: if len(config_dict) > 0:
logger.warning( logger.warning(
f"The config attributes {config_dict} were passed to {cls.__name__}, " f"The config attributes {config_dict} were passed to {cls.__name__}, "
@ -313,14 +349,16 @@ class ConfigMixin:
f"{cls.config_name} configuration file." f"{cls.config_name} configuration file."
) )
unused_kwargs = {**config_dict, **kwargs} # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
passed_keys = set(init_dict.keys()) passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0: if len(expected_keys - passed_keys) > 0:
logger.info( logger.info(
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
) )
# 6. Define unused keyword arguments
unused_kwargs = {**config_dict, **kwargs}
return init_dict, unused_kwargs return init_dict, unused_kwargs
@classmethod @classmethod

View File

@ -272,7 +272,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> # Download pipeline, but overwrite scheduler >>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler >>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler) >>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
``` ```
""" """

View File

@ -360,7 +360,7 @@ class DiffusionPipeline(ConfigMixin):
>>> # Download pipeline, but overwrite scheduler >>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler >>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler) >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
``` ```
""" """
@ -602,7 +602,7 @@ class DiffusionPipeline(ConfigMixin):
... StableDiffusionInpaintPipeline, ... StableDiffusionInpaintPipeline,
... ) ... )
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") >>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components) >>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components) >>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
``` ```

View File

@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login` # make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, DDIMScheduler from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-v1-5",
@ -91,11 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login` # make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler( lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear"
)
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-v1-5",

View File

@ -5,6 +5,7 @@ import numpy as np
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...onnx_utils import OnnxRuntimeModel from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@ -36,6 +37,34 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules( self.register_modules(
vae_encoder=vae_encoder, vae_encoder=vae_encoder,
vae_decoder=vae_decoder, vae_decoder=vae_decoder,

View File

@ -90,6 +90,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@ -104,6 +104,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@ -80,6 +80,19 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None:
logger.warn( logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@ -91,6 +91,19 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None:
logger.warn( logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@ -90,6 +90,19 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None:
logger.warn( logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@ -96,6 +96,19 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None: if safety_checker is None:
logger.warn( logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@ -109,6 +109,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"PNDMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@ -102,6 +102,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@ -67,6 +67,14 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@ -68,6 +68,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@ -67,6 +67,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@ -88,6 +88,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@ -644,13 +644,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device) sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
scheduler = DDIMScheduler( scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
sd_pipe.scheduler = scheduler sd_pipe.scheduler = scheduler
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"

View File

@ -523,9 +523,8 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512)) init_image = init_image.resize((768, 512))
expected_image = np.array(expected_image, dtype=np.float32) / 255.0 expected_image = np.array(expected_image, dtype=np.float32) / 255.0
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, model_id,
scheduler=lms, scheduler=lms,

View File

@ -366,8 +366,8 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
) )
expected_image = np.array(expected_image, dtype=np.float32) / 255.0 expected_image = np.array(expected_image, dtype=np.float32) / 255.0
pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
model_id = "runwayml/stable-diffusion-inpainting" model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, safety_checker=None, scheduler=pndm, device_map="auto" model_id, safety_checker=None, scheduler=pndm, device_map="auto"
) )

View File

@ -407,9 +407,8 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
) )
expected_image = np.array(expected_image, dtype=np.float32) / 255.0 expected_image = np.array(expected_image, dtype=np.float32) / 255.0
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, model_id,
scheduler=lms, scheduler=lms,

View File

@ -13,10 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import os
import tempfile import tempfile
import unittest import unittest
import diffusers
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, logging
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.testing_utils import CaptureLogger
class SampleObject(ConfigMixin): class SampleObject(ConfigMixin):
@ -34,6 +39,37 @@ class SampleObject(ConfigMixin):
pass pass
class SampleObject2(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
f=[1, 3],
):
pass
class SampleObject3(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
e=[1, 3],
f=[1, 3],
):
pass
class ConfigTester(unittest.TestCase): class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self): def test_load_not_from_mixin(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -97,3 +133,151 @@ class ConfigTester(unittest.TestCase):
assert config.pop("c") == (2, 5) # instantiated as tuple assert config.pop("c") == (2, 5) # instantiated as tuple
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config assert config == new_config
def test_save_load_from_different_config(self):
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject2.from_config(tmpdirname)
# now save a config parameter that is not expected
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject2.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject2
assert new_obj_2.__class__ == SampleObject
assert new_obj_3.__class__ == SampleObject2
assert cap_logger_1.out == ""
assert (
cap_logger_2.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
def test_save_load_compatible_schedulers(self):
SampleObject2._compatible_classes = ["SampleObject"]
SampleObject._compatible_classes = ["SampleObject2"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
# now save a config parameter that is expected by another class, but not origin class
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["f"] = [0, 0]
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger:
new_obj = SampleObject.from_config(tmpdirname)
assert new_obj.__class__ == SampleObject
assert (
cap_logger.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
def test_save_load_from_different_config_comp_schedulers(self):
SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
setattr(diffusers, "SampleObject3", SampleObject3)
logger = logging.get_logger("diffusers.configuration_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject2.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject3.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject
assert new_obj_2.__class__ == SampleObject2
assert new_obj_3.__class__ == SampleObject3
assert cap_logger_1.out == ""
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
def test_load_ddim_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
ddim = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
assert ddim.__class__ == DDIMScheduler
# no warning should be thrown
assert cap_logger.out == ""
def test_load_ddim_from_euler(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
euler = EulerDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
assert euler.__class__ == EulerDiscreteScheduler
# no warning should be thrown
assert cap_logger.out == ""
def test_load_ddim_from_euler_ancestral(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
euler = EulerAncestralDiscreteScheduler.from_config(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
assert euler.__class__ == EulerAncestralDiscreteScheduler
# no warning should be thrown
assert cap_logger.out == ""
def test_load_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
pndm = PNDMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
assert pndm.__class__ == PNDMScheduler
# no warning should be thrown
assert cap_logger.out == ""