286 lines
9.7 KiB
Python
Executable File
286 lines
9.7 KiB
Python
Executable File
# coding=utf-8
|
|
# Copyright 2022 HuggingFace Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import json
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
import diffusers
|
|
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, logging
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
from diffusers.utils.testing_utils import CaptureLogger
|
|
|
|
|
|
class SampleObject(ConfigMixin):
    """Minimal ``ConfigMixin`` fixture whose ``__init__`` arguments become its config.

    ``register_to_config`` records the ``__init__`` arguments (defaults included) on
    ``self.config``; the method body itself does nothing. The list default for ``e``
    is intentional: the tests check that a list-valued default is stored in, and
    round-trips through, the saved config.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3]):
        """Accept the fixture config values; all bookkeeping is done by the decorator."""
|
|
|
|
|
|
class SampleObject2(ConfigMixin):
    """``ConfigMixin`` fixture sharing ``a``-``d`` with ``SampleObject`` but taking ``f`` instead of ``e``.

    Used to exercise loading a config written by one class into a different class.
    ``register_to_config`` records the ``__init__`` arguments on ``self.config``; the
    list default for ``f`` is an intentional fixture value.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", f=[1, 3]):
        """Accept the fixture config values; all bookkeeping is done by the decorator."""
|
|
|
|
|
|
class SampleObject3(ConfigMixin):
    """``ConfigMixin`` fixture accepting both ``e`` and ``f`` (the union of the other two fixtures).

    ``register_to_config`` records the ``__init__`` arguments on ``self.config``; the
    list defaults for ``e`` and ``f`` are intentional fixture values.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3], f=[1, 3]):
        """Accept the fixture config values; all bookkeeping is done by the decorator."""
|
|
|
|
|
|
class ConfigTester(unittest.TestCase):
    """Tests for ``ConfigMixin`` / ``register_to_config`` config saving and loading.

    Several tests temporarily patch global state: they register the sample classes on
    the ``diffusers`` namespace, set ``_compatible_classes`` on the sample classes, or
    raise the verbosity of the ``diffusers.configuration_utils`` logger. Every such
    change is undone with ``addCleanup`` so results do not depend on test execution
    order and no state leaks into other tests running in the same process.
    """

    def _patch_attr(self, target, name, value):
        """Set ``target.<name> = value`` for the duration of the current test only.

        Args:
            target: Module or class whose attribute is patched.
            name: Attribute name.
            value: Value to assign while the test runs.

        On cleanup the attribute gets its previous value back if ``target`` defined it
        directly; otherwise it is deleted, so an inherited value becomes visible again.
        """
        own_attrs = vars(target)
        had_own_attr = name in own_attrs
        previous = own_attrs.get(name)
        setattr(target, name, value)
        if had_own_attr:
            self.addCleanup(setattr, target, name, previous)
        else:
            self.addCleanup(delattr, target, name)

    def _set_logger_level(self, logger, level):
        """Set ``logger``'s own level for the current test and restore it afterwards."""
        previous_level = logger.level
        logger.setLevel(level)
        self.addCleanup(logger.setLevel, previous_level)

    def test_load_not_from_mixin(self):
        with self.assertRaises(ValueError):
            ConfigMixin.from_config("dummy_path")

    def test_register_to_config(self):
        obj = SampleObject()
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # init ignore private arguments
        obj = SampleObject(_name_or_path="lalala")
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can override default
        obj = SampleObject(c=6)
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can use positional arguments.
        obj = SampleObject(1, c=6)
        config = obj.config
        assert config["a"] == 1
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config

        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(tmpdirname)
            new_config = new_obj.config

        # unfreeze configs
        config = dict(config)
        new_config = dict(new_config)

        assert config.pop("c") == (2, 5)  # instantiated as tuple
        assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
        assert config == new_config

    def test_save_load_from_different_config(self):
        obj = SampleObject()

        # mock add obj class to `diffusers` (undone after the test)
        self._patch_attr(diffusers, "SampleObject", SampleObject)
        logger = logging.get_logger("diffusers.configuration_utils")

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            with CaptureLogger(logger) as cap_logger_1:
                new_obj_1 = SampleObject2.from_config(tmpdirname)

            # now save a config parameter that is not expected
            with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
                data = json.load(f)
                data["unexpected"] = True

            with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
                json.dump(data, f)

            with CaptureLogger(logger) as cap_logger_2:
                new_obj_2 = SampleObject.from_config(tmpdirname)

            with CaptureLogger(logger) as cap_logger_3:
                new_obj_3 = SampleObject2.from_config(tmpdirname)

        assert new_obj_1.__class__ == SampleObject2
        assert new_obj_2.__class__ == SampleObject
        assert new_obj_3.__class__ == SampleObject2

        assert cap_logger_1.out == ""
        assert (
            cap_logger_2.out
            == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
            " be ignored. Please verify your config.json configuration file.\n"
        )
        assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out

    def test_save_load_compatible_schedulers(self):
        # Compatibility lists are test-local; restored after the test so later tests
        # do not inherit them.
        self._patch_attr(SampleObject2, "_compatible_classes", ["SampleObject"])
        self._patch_attr(SampleObject, "_compatible_classes", ["SampleObject2"])

        obj = SampleObject()

        # mock add obj class to `diffusers` (undone after the test)
        self._patch_attr(diffusers, "SampleObject", SampleObject)
        self._patch_attr(diffusers, "SampleObject2", SampleObject2)
        logger = logging.get_logger("diffusers.configuration_utils")

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)

            # now save a config parameter that is expected by another class, but not origin class
            with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
                data = json.load(f)
                data["f"] = [0, 0]
                data["unexpected"] = True

            with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
                json.dump(data, f)

            with CaptureLogger(logger) as cap_logger:
                new_obj = SampleObject.from_config(tmpdirname)

        assert new_obj.__class__ == SampleObject

        assert (
            cap_logger.out
            == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
            " be ignored. Please verify your config.json configuration file.\n"
        )

    def test_save_load_from_different_config_comp_schedulers(self):
        # Compatibility lists are test-local; restored after the test.
        self._patch_attr(SampleObject3, "_compatible_classes", ["SampleObject", "SampleObject2"])
        self._patch_attr(SampleObject2, "_compatible_classes", ["SampleObject", "SampleObject3"])
        self._patch_attr(SampleObject, "_compatible_classes", ["SampleObject2", "SampleObject3"])

        obj = SampleObject()

        # mock add obj class to `diffusers` (undone after the test)
        self._patch_attr(diffusers, "SampleObject", SampleObject)
        self._patch_attr(diffusers, "SampleObject2", SampleObject2)
        self._patch_attr(diffusers, "SampleObject3", SampleObject3)
        logger = logging.get_logger("diffusers.configuration_utils")
        # INFO is needed to capture the "was not found in config" messages asserted below.
        # The previous level is restored after the test, so other tests that expect a
        # silent logger are unaffected.
        self._set_logger_level(logger, diffusers.logging.INFO)

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)

            with CaptureLogger(logger) as cap_logger_1:
                new_obj_1 = SampleObject.from_config(tmpdirname)

            with CaptureLogger(logger) as cap_logger_2:
                new_obj_2 = SampleObject2.from_config(tmpdirname)

            with CaptureLogger(logger) as cap_logger_3:
                new_obj_3 = SampleObject3.from_config(tmpdirname)

        assert new_obj_1.__class__ == SampleObject
        assert new_obj_2.__class__ == SampleObject2
        assert new_obj_3.__class__ == SampleObject3

        assert cap_logger_1.out == ""
        assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
        assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"

    def test_load_ddim_from_pndm(self):
        logger = logging.get_logger("diffusers.configuration_utils")

        with CaptureLogger(logger) as cap_logger:
            ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")

        assert ddim.__class__ == DDIMScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_load_euler_from_pndm(self):
        logger = logging.get_logger("diffusers.configuration_utils")

        with CaptureLogger(logger) as cap_logger:
            euler = EulerDiscreteScheduler.from_config(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert euler.__class__ == EulerDiscreteScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_load_euler_ancestral_from_pndm(self):
        logger = logging.get_logger("diffusers.configuration_utils")

        with CaptureLogger(logger) as cap_logger:
            euler = EulerAncestralDiscreteScheduler.from_config(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert euler.__class__ == EulerAncestralDiscreteScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_load_pndm(self):
        logger = logging.get_logger("diffusers.configuration_utils")

        with CaptureLogger(logger) as cap_logger:
            pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")

        assert pndm.__class__ == PNDMScheduler
        # no warning should be thrown
        assert cap_logger.out == ""
|